Skip to content

Commit

Permalink
fix: Ensure table uniqueness on update (apache#15909)
Browse files Browse the repository at this point in the history
* fix: Ensure table uniqueness on update

* Update models.py

* Update slice.py

* Update datasource_tests.py

Co-authored-by: John Bodley <[email protected]>
  • Loading branch information
2 people authored and cccs-RyanS committed Dec 17, 2021
1 parent 3134185 commit 156fdd8
Show file tree
Hide file tree
Showing 20 changed files with 344 additions and 274 deletions.
3 changes: 3 additions & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ This file documents any backwards-incompatible changes in Superset and
assists people when migrating to a new version.

## Next
- [15909](https://github.com/apache/incubator-superset/pull/15909): a change which
drops a uniqueness criterion (which may or may not have existed) to the tables table. This constraint was obsolete as it is handled by the ORM due to differences in how MySQL, PostgreSQL, etc. handle uniqueness for NULL values.

- [15927](https://github.com/apache/superset/pull/15927): Upgrades Celery to 5.x. Per the [upgrading](https://docs.celeryproject.org/en/stable/history/whatsnew-5.0.html#upgrading-from-celery-4-x) instructions Celery 5.0 introduces a new CLI implementation which is not completely backwards compatible. Please ensure global options are positioned before the sub-command.

- [13772](https://github.com/apache/superset/pull/13772): Row level security (RLS) is now enabled by default. To activate the feature, please run `superset init` to expose the RLS menus to Admin users.
Expand Down
52 changes: 51 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
owner_class = security_manager.user_model

__tablename__ = "tables"
__table_args__ = (UniqueConstraint("database_id", "table_name"),)

# Note this uniqueness constraint is not part of the physical schema, i.e., it does
# not exist in the migrations, but is required by `import_from_dict` to ensure the
# correct filters are applied in order to identify uniqueness.
#
# The reason it does not physically exist is MySQL, PostgreSQL, etc. have a
# different interpretation of uniqueness when it comes to NULL which is problematic
# given the schema is optional.
__table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),)

table_name = Column(String(250), nullable=False)
main_dttm_col = Column(String(250))
Expand Down Expand Up @@ -1606,6 +1614,47 @@ class and any keys added via `ExtraCache`.
extra_cache_keys += sqla_query.extra_cache_keys
return extra_cache_keys

@staticmethod
def before_update(
mapper: Mapper, # pylint: disable=unused-argument
connection: Connection, # pylint: disable=unused-argument
target: "SqlaTable",
) -> None:
"""
Check whether before update if the target table already exists.
Note this listener is called when any fields are being updated and thus it is
necessary to first check whether the reference table is being updated.
Note this logic is temporary, given uniqueness is handled via the dataset DAO,
but is necessary until both the legacy datasource editor and datasource/save
endpoints are deprecated.
:param mapper: The table mapper
:param connection: The DB-API connection
:param target: The mapped instance being persisted
:raises Exception: If the target table is not unique
"""

from superset.datasets.commands.exceptions import get_dataset_exist_error_msg
from superset.datasets.dao import DatasetDAO

# Check whether the relevant attributes have changed.
state = db.inspect(target) # pylint: disable=no-member

for attr in ["database_id", "schema", "table_name"]:
history = state.get_history(attr, True)

if history.has_changes():
break
else:
return None

if not DatasetDAO.validate_uniqueness(
target.database_id, target.schema, target.table_name
):
raise Exception(get_dataset_exist_error_msg(target.full_name))


def update_table(
_mapper: Mapper, _connection: Connection, obj: Union[SqlMetric, TableColumn]
Expand All @@ -1623,6 +1672,7 @@ def update_table(

sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
sa.event.listen(SqlaTable, "after_update", security_manager.set_perm)
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlMetric, "after_update", update_table)
sa.event.listen(TableColumn, "after_update", update_table)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""drop tables constraint
Revision ID: 31b2a1039d4a
Revises: ae1ed299413b
Create Date: 2021-07-27 08:25:20.755453
"""

from alembic import op
from sqlalchemy import engine
from sqlalchemy.exc import OperationalError, ProgrammingError

from superset.utils.core import generic_find_uq_constraint_name

# revision identifiers, used by Alembic.
revision = "31b2a1039d4a"
down_revision = "ae1ed299413b"

conv = {"uq": "uq_%(table_name)s_%(column_0_name)s"}


def upgrade():
bind = op.get_bind()
insp = engine.reflection.Inspector.from_engine(bind)

# Drop the uniqueness constraint if it exists.
constraint = generic_find_uq_constraint_name("tables", {"table_name"}, insp)

if constraint:
with op.batch_alter_table("tables", naming_convention=conv) as batch_op:
batch_op.drop_constraint(constraint, type_="unique")


def downgrade():

# One cannot simply re-add the uniqueness constraint as it may not have previously
# existed.
pass
5 changes: 2 additions & 3 deletions superset/models/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@
logger = logging.getLogger(__name__)


class Slice(
class Slice( # pylint: disable=too-many-instance-attributes,too-many-public-methods
Model, AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods, too-many-instance-attributes

):
"""A slice is essentially a report or a view on data"""

__tablename__ = "slices"
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/access_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_override_role_permissions_1_table(self):

updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table_by_name("birth_names")
birth_names = self.get_table(name="birth_names")
self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name
)
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_override_role_permissions_druid_and_table(self):
"datasource_access", updated_role.permissions[1].permission.name
)

birth_names = self.get_table_by_name("birth_names")
birth_names = self.get_table(name="birth_names")
self.assertEqual(birth_names.perm, perms[2].view_menu.name)
self.assertEqual(
"datasource_access", updated_role.permissions[2].permission.name
Expand All @@ -204,7 +204,7 @@ def test_override_role_permissions_drops_absent_perms(self):
override_me = security_manager.find_role("override_me")
override_me.permissions.append(
security_manager.find_permission_view_menu(
view_menu_name=self.get_table_by_name("energy_usage").perm,
view_menu_name=self.get_table(name="energy_usage").perm,
permission_name="datasource_access",
)
)
Expand All @@ -218,7 +218,7 @@ def test_override_role_permissions_drops_absent_perms(self):
self.assertEqual(201, response.status_code)
updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table_by_name("birth_names")
birth_names = self.get_table(name="birth_names")
self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name
)
Expand Down
43 changes: 23 additions & 20 deletions tests/integration_tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ def post_assert_metric(
return rv


def get_table_by_name(name: str) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(table_name=name).one()


@pytest.fixture
def logged_in_admin():
"""Fixture with app context and logged in admin user."""
Expand Down Expand Up @@ -132,12 +128,7 @@ def get_nonexistent_numeric_id(model):

@staticmethod
def get_birth_names_dataset() -> SqlaTable:
example_db = get_example_database()
return (
db.session.query(SqlaTable)
.filter_by(database=example_db, table_name="birth_names")
.one()
)
return SupersetTestCase.get_table(name="birth_names")

@staticmethod
def create_user_with_roles(
Expand Down Expand Up @@ -254,13 +245,31 @@ def get_slice(
return slc

@staticmethod
def get_table_by_name(name: str) -> SqlaTable:
return get_table_by_name(name)
def get_table(
name: str, database_id: Optional[int] = None, schema: Optional[str] = None
) -> SqlaTable:
return (
db.session.query(SqlaTable)
.filter_by(
database_id=database_id
or SupersetTestCase.get_database_by_name("examples").id,
schema=schema,
table_name=name,
)
.one()
)

@staticmethod
def get_database_by_id(db_id: int) -> Database:
return db.session.query(Database).filter_by(id=db_id).one()

@staticmethod
def get_database_by_name(database_name: str = "main") -> Database:
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")

@staticmethod
def get_druid_ds_by_name(name: str) -> DruidDatasource:
return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
Expand Down Expand Up @@ -340,12 +349,6 @@ def revoke_role_access_to_table(self, role_name, table):
):
security_manager.del_permission_role(public_role, perm)

def _get_database_by_name(self, database_name="main"):
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")

def run_sql(
self,
sql,
Expand All @@ -364,7 +367,7 @@ def run_sql(
if user_name:
self.logout()
self.login(username=(user_name or "admin"))
dbid = self._get_database_by_name(database_name).id
dbid = SupersetTestCase.get_database_by_name(database_name).id
json_payload = {
"database_id": dbid,
"sql": sql,
Expand Down Expand Up @@ -448,7 +451,7 @@ def validate_sql(
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
dbid = self._get_database_by_name(database_name).id
dbid = SupersetTestCase.get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/validate_sql_json/",
raise_on_error=False,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def test_update_chart(self):
"""
admin = self.get_user("admin")
gamma = self.get_user("gamma")
birth_names_table_id = SupersetTestCase.get_table_by_name("birth_names").id
birth_names_table_id = SupersetTestCase.get_table(name="birth_names").id
chart_id = self.insert_chart(
"title", [admin.id], birth_names_table_id, admin
).id
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/csv_upload_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_import_csv_explore_database(setup_csv_upload, create_csv_files):
f'CSV file "{CSV_FILENAME1}" uploaded to table "{CSV_UPLOAD_TABLE_W_EXPLORE}"'
in resp
)
table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE_W_EXPLORE)
table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE_W_EXPLORE)
assert table.database_id == utils.get_example_database().id


Expand Down Expand Up @@ -267,7 +267,7 @@ def test_import_csv(setup_csv_upload, create_csv_files):
)
assert success_msg_f2 in resp

table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE)
table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE)
# make sure the new column name is reflected in the table metadata
assert "d" in table.column_names

Expand Down
8 changes: 6 additions & 2 deletions tests/integration_tests/dashboard_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def create_table_for_dashboard(
dtype: Dict[str, Any],
table_description: str = "",
fetch_values_predicate: Optional[str] = None,
schema: Optional[str] = None,
) -> SqlaTable:
df.to_sql(
table_name,
Expand All @@ -44,14 +45,17 @@ def create_table_for_dashboard(
dtype=dtype,
index=False,
method="multi",
schema=schema,
)

table_source = ConnectorRegistry.sources["table"]
table = (
db.session.query(table_source).filter_by(table_name=table_name).one_or_none()
db.session.query(table_source)
.filter_by(database_id=database.id, schema=schema, table_name=table_name)
.one_or_none()
)
if not table:
table = table_source(table_name=table_name)
table = table_source(schema=schema, table_name=table_name)
if fetch_values_predicate:
table.fetch_values_predicate = fetch_values_predicate
table.database = database
Expand Down
Loading

0 comments on commit 156fdd8

Please sign in to comment.