Skip to content

Commit

Permalink
fix: mypy static analysis issues
Browse files Browse the repository at this point in the history
Co-authored-by: Ghaith Kdimati <[email protected]>"
  • Loading branch information
Ofahmy143 committed Oct 1, 2024
1 parent 7f3b711 commit 6ba0802
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 262 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ upload_to_vcs_release = true





[project]
name = "rls"
version = "0.3.0"
Expand Down
8 changes: 4 additions & 4 deletions rls/alembic_rls.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def disable_rls(operations, operation):

@renderers.dispatch_for(EnableRlsOp)
def render_enable_rls(autogen_context, op):
return "op.enable_rls(%r)" % (op.tablename)
return "op.enable_rls(%r) # type: ignore" % (op.tablename)


@renderers.dispatch_for(DisableRlsOp)
def render_disable_rls(autogen_context, op):
return "op.disable_rls(%r)" % (op.tablename)
return "op.disable_rls(%r) # type: ignore" % (op.tablename)


############################
Expand Down Expand Up @@ -294,12 +294,12 @@ def drop_policy(operations, operation):

@renderers.dispatch_for(CreatePolicyOp)
def render_create_policy(autogen_context, op):
return f"op.create_policy(table_name={op.table_name!r}, policy_name={op.policy_name!r}, cmd={op.cmd!r}, definition='{op.definition}', expr=\"{op.expr}\")"
return f"op.create_policy(table_name={op.table_name!r}, policy_name={op.policy_name!r}, cmd={op.cmd!r}, definition='{op.definition}', expr=\"{op.expr}\") # type: ignore"


@renderers.dispatch_for(DropPolicyOp)
def render_drop_policy(autogen_context, op):
return f"op.drop_policy(tablename={op.table_name!r}, policyname={op.policy_name!r})"
return f"op.drop_policy(tablename={op.table_name!r}, policyname={op.policy_name!r}) # type: ignore"


def set_metadata_info(Base: Type[DeclarativeMeta]):
Expand Down
28 changes: 18 additions & 10 deletions rls/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Literal, Union, TypedDict, Optional, NotRequired
from typing import List, Literal, Union, TypedDict, Optional, NotRequired, Type

from pydantic import BaseModel
from .utils import generate_rls_policy
Expand Down Expand Up @@ -34,9 +34,9 @@ class Operation(str, Enum):

class ConditionArgs(TypedDict):
comparator_name: str
type: sqltypes
operation: NotRequired[Optional[Operation]] = None
column_name: NotRequired[Optional[str]] = None
type: Type[sqltypes.TypeEngine]
operation: NotRequired[Optional[Operation]]
column_name: NotRequired[Optional[str]]


class Policy(BaseModel):
Expand All @@ -47,8 +47,8 @@ class Policy(BaseModel):
custom_expr: Optional[str] = None

__policy_names: List[str] = []
__expr: str = None
__policy_suffix: int = 0
__expr: str = ""
__policy_suffix: str = ""

class Config:
arbitrary_types_allowed = True
Expand All @@ -60,7 +60,15 @@ def _get_safe_variable_name(self, idx: int = 0):
def _get_expr_from_params(self, table_name: str, idx: int = 0):
safe_variable_name = self._get_safe_variable_name(idx=idx)

expr = f"{self.condition_args[idx]['column_name']} {self.condition_args[idx]['operation'].value} {safe_variable_name}"
operation_obj = self.condition_args[idx]

# Check if "operation" exists and is not None before accessing its value
if "operation" in operation_obj and operation_obj["operation"] is not None:
operation_value = operation_obj["operation"].value
else:
operation_value = ""

expr = f"{self.condition_args[idx]['column_name']} {operation_value} {safe_variable_name}"

return expr

Expand All @@ -69,15 +77,15 @@ def _get_expr_from_joined_expr(self, table_name: str):
for idx in range(len(self.condition_args)):
pattern = rf"\{{{idx}\}}" # Escaped curly braces
parsed_expr = self._get_expr_from_params(table_name, idx)
expr = re.sub(pattern, parsed_expr, expr)
expr = re.sub(pattern, parsed_expr, str(expr))
return expr

def _get_expr_from_custom_expr(self, table_name: str):
expr = self.custom_expr
for idx in range(len(self.condition_args)):
safe_variable_name = self._get_safe_variable_name(idx=idx)
pattern = rf"\{{{idx}\}}"
expr = re.sub(pattern, safe_variable_name, expr)
expr = re.sub(pattern, safe_variable_name, str(expr))
return expr

def _validate_joining_operations_in_expr(self):
Expand Down Expand Up @@ -120,7 +128,7 @@ def _validate_state(self):
)

@property
def policy_names(self) -> str:
def policy_names(self) -> list[str]:
"""Getter for the private __policy_name field."""
return self.__policy_names

Expand Down
File renamed without changes.
36 changes: 36 additions & 0 deletions test/alembic/versions/479bd311af91_initial_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""initial updates
Revision ID: 479bd311af91
Revises:
Create Date: 2024-09-30 15:47:22.747484
"""

from typing import Sequence, Union

from alembic import op


# revision identifiers, used by Alembic.
revision: str = "479bd311af91"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_policy(
table_name="items",
policy_name="items_permissive_all_policy_0",
cmd="ALL",
definition="PERMISSIVE",
expr="owner_id > NULLIF(current_setting('rls.account_id', true),'')::INTEGER",
) # type: ignore
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_policy(tablename="items", policyname="items_permissive_all_policy_0") # type: ignore
# ### end Alembic commands ###
67 changes: 0 additions & 67 deletions test/alembic/versions/4e47974b6aaf_initial_updates.py

This file was deleted.

5 changes: 3 additions & 2 deletions test/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from rls.schemas import (
Permissive,
Command,
)


# Base = register_rls(declarative_base())
Base = declarative_base()
Base = declarative_base() # type: Any


class User(Base):
Expand Down
28 changes: 14 additions & 14 deletions test/setup.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# sync_setup.py
from .engines import setup_engine as engine
from .models import RlsBase
# # sync_setup.py
# from .engines import setup_engine as engine
# from .models import RlsBase

print("Setting up database")
# print("Setting up database")


def setup_database():
print("Setting up database")
# Create tables
RlsBase.metadata.create_all(engine)
# def setup_database():
# print("Setting up database")
# # Create tables
# RlsBase.metadata.create_all(engine)


def teardown_database():
print("Tearing down database")
# Drop tables
RlsBase.metadata.drop_all(engine)
# def teardown_database():
# print("Tearing down database")
# # Drop tables
# RlsBase.metadata.drop_all(engine)


if __name__ == "__main__":
setup_database()
# if __name__ == "__main__":
# setup_database()
Loading

0 comments on commit 6ba0802

Please sign in to comment.