From 387549f69c705dc5d2d459ebb5ef1f2a51b1076a Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Tue, 15 Aug 2023 16:32:54 -0700 Subject: [PATCH] fix(snowflake): opt-in denormalization of column names (#24982) --- UPDATING.md | 1 + .../Datasource/DatasourceEditor.jsx | 10 +++ .../components/Datasource/DatasourceModal.tsx | 1 + .../src/features/datasets/types.ts | 1 + superset/connectors/sqla/models.py | 3 + superset/connectors/sqla/utils.py | 6 +- superset/connectors/sqla/views.py | 5 ++ superset/dashboards/schemas.py | 1 + superset/datasets/api.py | 2 + superset/datasets/commands/duplicate.py | 1 + superset/datasets/schemas.py | 4 ++ ...676_add_normalize_columns_to_sqla_model.py | 67 +++++++++++++++++++ superset/views/datasource/schemas.py | 6 +- superset/views/datasource/views.py | 3 + tests/integration_tests/core_tests.py | 3 +- tests/integration_tests/dashboard_utils.py | 2 +- tests/integration_tests/datasets/api_tests.py | 35 +++++++++- .../datasets/commands_tests.py | 2 + tests/integration_tests/datasource_tests.py | 5 ++ .../fixtures/importexport.py | 2 + .../integration_tests/import_export_tests.py | 5 +- .../datasets/commands/export_test.py | 2 + 22 files changed, 159 insertions(+), 8 deletions(-) create mode 100644 superset/migrations/versions/2023-08-14_09-38_9f4a086c2676_add_normalize_columns_to_sqla_model.py diff --git a/UPDATING.md b/UPDATING.md index 39a1654996943..5a29c43dfad99 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -88,6 +88,7 @@ assists people when migrating to a new version. ### Other +- [24982](https://github.com/apache/superset/pull/24982): By default, physical datasets on Oracle-like dialects like Snowflake will now use denormalized column names. However, existing datasets won't be affected. To change this behavior, the "Advanced" section on the dataset modal has a "Normalize column names" flag which can be changed to change this behavior. - [23888](https://github.com/apache/superset/pull/23888): Database Migration for json serialization instead of pickle should upgrade/downgrade correctly when bumping to/from this patch version ## 2.1.0 diff --git a/superset-frontend/src/components/Datasource/DatasourceEditor.jsx b/superset-frontend/src/components/Datasource/DatasourceEditor.jsx index 5977f44058f66..ead572422c1cd 100644 --- a/superset-frontend/src/components/Datasource/DatasourceEditor.jsx +++ b/superset-frontend/src/components/Datasource/DatasourceEditor.jsx @@ -806,6 +806,7 @@ class DatasourceEditor extends React.PureComponent { table_name: datasource.table_name ? encodeURIComponent(datasource.table_name) : datasource.table_name, + normalize_columns: datasource.normalize_columns, }; Object.entries(params).forEach(([key, value]) => { // rison can't encode the undefined value @@ -1034,6 +1035,15 @@ class DatasourceEditor extends React.PureComponent { control={} /> )} + } + /> ); } diff --git a/superset-frontend/src/components/Datasource/DatasourceModal.tsx b/superset-frontend/src/components/Datasource/DatasourceModal.tsx index 78859f4a2fe47..f9c40c47ba02e 100644 --- a/superset-frontend/src/components/Datasource/DatasourceModal.tsx +++ b/superset-frontend/src/components/Datasource/DatasourceModal.tsx @@ -128,6 +128,7 @@ const DatasourceModal: FunctionComponent = ({ schema, description: currentDatasource.description, main_dttm_col: currentDatasource.main_dttm_col, + normalize_columns: currentDatasource.normalize_columns, offset: currentDatasource.offset, default_endpoint: currentDatasource.default_endpoint, cache_timeout: diff --git a/superset-frontend/src/features/datasets/types.ts b/superset-frontend/src/features/datasets/types.ts index 4c2c5b8a95516..9163306267b89 100644 --- a/superset-frontend/src/features/datasets/types.ts +++ b/superset-frontend/src/features/datasets/types.ts @@ -63,4 +63,5 @@ export type DatasetObject = { metrics: MetricObject[]; extra?: string; is_managed_externally: boolean; + normalize_columns: boolean; }; diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ea5c0c8de0dfe..aeda42cf8cf91 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -546,6 +546,7 @@ class SqlaTable( is_sqllab_view = Column(Boolean, default=False) template_params = Column(Text) extra = Column(Text) + normalize_columns = Column(Boolean, default=False) baselink = "tablemodelview" @@ -564,6 +565,7 @@ class SqlaTable( "filter_select_enabled", "fetch_values_predicate", "extra", + "normalize_columns", ] update_from_object_fields = [f for f in export_fields if f != "database_id"] export_parent = "database" @@ -717,6 +719,7 @@ def external_metadata(self) -> list[ResultSetColumnType]: database=self.database, table_name=self.table_name, schema_name=self.schema, + normalize_columns=self.normalize_columns, ) @property diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 82d8f90f224b4..c8a5f9f260572 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -48,6 +48,7 @@ def get_physical_table_metadata( database: Database, table_name: str, + normalize_columns: bool, schema_name: str | None = None, ) -> list[ResultSetColumnType]: """Use SQLAlchemy inspector to get table metadata""" @@ -67,7 +68,10 @@ def get_physical_table_metadata( for col in cols: try: if isinstance(col["type"], TypeEngine): - name = db_engine_spec.denormalize_name(db_dialect, col["column_name"]) + name = col["column_name"] + if not normalize_columns: + name = db_engine_spec.denormalize_name(db_dialect, name) + db_type = db_engine_spec.column_datatype_to_string( col["type"], db_dialect ) diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 65c6f110e4fe2..f72261eff8964 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -313,6 +313,7 @@ class TableModelView( # pylint: disable=too-many-ancestors "is_sqllab_view", "template_params", "extra", + "normalize_columns", ] base_filters = [["id", DatasourceFilter, lambda: []]] show_columns = edit_columns + ["perm", "slices"] @@ -379,6 +380,10 @@ class TableModelView( # pylint: disable=too-many-ancestors '}, "warning_markdown": "This is a warning." }`.', True, ), + "normalize_columns": _( + "Allow column names to be changed to case insensitive format, " + "if supported (e.g. Oracle, Snowflake)." + ), } label_columns = { "slices": _("Associated Charts"), diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index 7905641f80735..f72aeae58296a 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -248,6 +248,7 @@ class DashboardDatasetSchema(Schema): verbose_map = fields.Dict(fields.Str(), fields.Str()) time_grain_sqla = fields.List(fields.List(fields.Str())) granularity_sqla = fields.List(fields.List(fields.Str())) + normalize_columns = fields.Bool() class BaseDashboardSchema(Schema): diff --git a/superset/datasets/api.py b/superset/datasets/api.py index d5a0478c5d2fa..bb1a4ffd2909f 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -141,6 +141,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): "schema", "description", "main_dttm_col", + "normalize_columns", "offset", "default_endpoint", "cache_timeout", @@ -218,6 +219,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): "schema", "description", "main_dttm_col", + "normalize_columns", "offset", "default_endpoint", "cache_timeout", diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py index dc3ccb85d4b0c..9fc05c0960a5f 100644 --- a/superset/datasets/commands/duplicate.py +++ b/superset/datasets/commands/duplicate.py @@ -67,6 +67,7 @@ def run(self) -> Model: table.database = database table.schema = self._base_model.schema table.template_params = self._base_model.template_params + table.normalize_columns = self._base_model.normalize_columns table.is_sqllab_view = True table.sql = ParsedQuery(self._base_model.sql).stripped() db.session.add(table) diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index dcac648148ed3..f229604d47259 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -85,6 +85,7 @@ class DatasetPostSchema(Schema): owners = fields.List(fields.Integer()) is_managed_externally = fields.Boolean(allow_none=True, dump_default=False) external_url = fields.String(allow_none=True) + normalize_columns = fields.Boolean(load_default=False) class DatasetPutSchema(Schema): @@ -96,6 +97,7 @@ class DatasetPutSchema(Schema): schema = fields.String(allow_none=True, validate=Length(0, 255)) description = fields.String(allow_none=True) main_dttm_col = fields.String(allow_none=True) + normalize_columns = fields.Boolean(allow_none=True, dump_default=False) offset = fields.Integer(allow_none=True) default_endpoint = fields.String(allow_none=True) cache_timeout = fields.Integer(allow_none=True) @@ -234,6 +236,7 @@ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: data = fields.URL() is_managed_externally = fields.Boolean(allow_none=True, dump_default=False) external_url = fields.String(allow_none=True) + normalize_columns = fields.Boolean(load_default=False) class GetOrCreateDatasetSchema(Schema): @@ -249,6 +252,7 @@ class GetOrCreateDatasetSchema(Schema): template_params = fields.String( metadata={"description": "Template params for the table"} ) + normalize_columns = fields.Boolean(load_default=False) class DatasetSchema(SQLAlchemyAutoSchema): diff --git a/superset/migrations/versions/2023-08-14_09-38_9f4a086c2676_add_normalize_columns_to_sqla_model.py b/superset/migrations/versions/2023-08-14_09-38_9f4a086c2676_add_normalize_columns_to_sqla_model.py new file mode 100644 index 0000000000000..8eaee8207ce0b --- /dev/null +++ b/superset/migrations/versions/2023-08-14_09-38_9f4a086c2676_add_normalize_columns_to_sqla_model.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add_normalize_columns_to_sqla_model + +Revision ID: 9f4a086c2676 +Revises: 4448fa6deeb1 +Create Date: 2023-08-14 09:38:11.897437 + +""" + +# revision identifiers, used by Alembic. +revision = "9f4a086c2676" +down_revision = "4448fa6deeb1" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session + +from superset import db +from superset.migrations.shared.utils import paginated_update + +Base = declarative_base() + + +class SqlaTable(Base): + __tablename__ = "tables" + + id = sa.Column(sa.Integer, primary_key=True) + normalize_columns = sa.Column(sa.Boolean()) + + +def upgrade(): + op.add_column( + "tables", + sa.Column( + "normalize_columns", + sa.Boolean(), + nullable=True, + default=False, + server_default=sa.false(), + ), + ) + + bind = op.get_bind() + session = db.Session(bind=bind) + + for table in paginated_update(session.query(SqlaTable)): + table.normalize_columns = True + + +def downgrade(): + op.drop_column("tables", "normalize_columns") diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py index 5b1700708ad82..0fcdb452ebee4 100644 --- a/superset/views/datasource/schemas.py +++ b/superset/views/datasource/schemas.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any +from typing import Any, Optional from marshmallow import fields, post_load, pre_load, Schema, validate from typing_extensions import TypedDict @@ -29,6 +29,7 @@ class ExternalMetadataParams(TypedDict): database_name: str schema_name: str table_name: str + normalize_columns: Optional[bool] get_external_metadata_schema = { @@ -36,6 +37,7 @@ class ExternalMetadataParams(TypedDict): "database_name": "string", "schema_name": "string", "table_name": "string", + "normalize_columns": "boolean", } @@ -44,6 +46,7 @@ class ExternalMetadataSchema(Schema): database_name = fields.Str(required=True) schema_name = fields.Str(allow_none=True) table_name = fields.Str(required=True) + normalize_columns = fields.Bool(allow_none=True) # pylint: disable=no-self-use,unused-argument @post_load @@ -57,6 +60,7 @@ def normalize( database_name=data["database_name"], schema_name=data.get("schema_name", ""), table_name=data["table_name"], + normalize_columns=data["normalize_columns"], ) diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index f1086acd47330..b2fd387379fd9 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -77,6 +77,8 @@ def save(self) -> FlaskResponse: return json_error_response(_("Request missing data field."), status=500) datasource_dict = json.loads(data) + normalize_columns = datasource_dict.get("normalize_columns", False) + datasource_dict["normalize_columns"] = normalize_columns datasource_id = datasource_dict.get("id") datasource_type = datasource_dict.get("type") database_id = datasource_dict["database"].get("id") @@ -196,6 +198,7 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse: database=database, table_name=params["table_name"], schema_name=params["schema_name"], + normalize_columns=params.get("normalize_columns") or False, ) except (NoResultFound, NoSuchTableError) as ex: raise DatasetNotFoundError() from ex diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index e036602d0f973..9b22db7013b76 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -43,10 +43,9 @@ from superset.connectors.sqla.models import SqlaTable from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.mssql import MssqlEngineSpec -from superset.exceptions import QueryObjectValidationError, SupersetException +from superset.exceptions import SupersetException from superset.extensions import async_query_manager, cache_manager from superset.models import core as models -from superset.models.annotations import Annotation, AnnotationLayer from superset.models.cache import CacheKey from superset.models.dashboard import Dashboard from superset.models.slice import Slice diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index 6c3d000051f35..e284e21ca4bd9 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -53,7 +53,7 @@ def create_table_metadata( table = get_table(table_name, database, schema) if not table: - table = SqlaTable(schema=schema, table_name=table_name) + table = SqlaTable(schema=schema, table_name=table_name, normalize_columns=False) if fetch_values_predicate: table.fetch_values_predicate = fetch_values_predicate table.database = database diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 027002507aece..8884b1171aa3f 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -540,6 +540,8 @@ def test_create_dataset_item(self): model = db.session.query(SqlaTable).get(table_id) assert model.table_name == table_data["table_name"] assert model.database_id == table_data["database"] + # normalize_columns should default to False + assert model.normalize_columns is False # Assert that columns were created columns = ( @@ -563,6 +565,34 @@ def test_create_dataset_item(self): db.session.delete(model) db.session.commit() + def test_create_dataset_item_normalize(self): + """ + Dataset API: Test create dataset item with column normalization enabled + """ + if backend() == "sqlite": + return + + main_db = get_main_database() + self.login(username="admin") + table_data = { + "database": main_db.id, + "schema": None, + "table_name": "ab_permission", + "normalize_columns": True, + } + uri = "api/v1/dataset/" + rv = self.post_assert_metric(uri, table_data, "post") + assert rv.status_code == 201 + data = json.loads(rv.data.decode("utf-8")) + table_id = data.get("id") + model = db.session.query(SqlaTable).get(table_id) + assert model.table_name == table_data["table_name"] + assert model.database_id == table_data["database"] + assert model.normalize_columns is True + + db.session.delete(model) + db.session.commit() + def test_create_dataset_item_gamma(self): """ Dataset API: Test create dataset item gamma @@ -2494,8 +2524,9 @@ def test_get_or_create_dataset_creates_table(self): .filter_by(table_name="test_create_sqla_table_api") .one() ) - self.assertEqual(response["result"], {"table_id": table.id}) - self.assertEqual(table.template_params, '{"param": 1}') + assert response["result"] == {"table_id": table.id} + assert table.template_params == '{"param": 1}' + assert table.normalize_columns is False db.session.delete(table) with examples_db.get_sqla_engine_with_context() as engine: diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index b3b5084e35cb0..19caa9e1a111a 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -170,6 +170,7 @@ def test_export_dataset_command(self, mock_g): "warning_text": None, }, ], + "normalize_columns": False, "offset": 0, "params": None, "schema": get_example_default_schema(), @@ -229,6 +230,7 @@ def test_export_dataset_command_key_order(self, mock_g): "filter_select_enabled", "fetch_values_predicate", "extra", + "normalize_columns", "uuid", "metrics", "columns", diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 5de1cf6ef85e1..4c05898cfe75e 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -105,6 +105,7 @@ def test_external_metadata_by_name_for_physical_table(self): "database_name": tbl.database.database_name, "schema_name": tbl.schema, "table_name": tbl.table_name, + "normalize_columns": tbl.normalize_columns, } ) url = f"/datasource/external_metadata_by_name/?q={params}" @@ -133,6 +134,7 @@ def test_external_metadata_by_name_for_virtual_table(self): "database_name": tbl.database.database_name, "schema_name": tbl.schema, "table_name": tbl.table_name, + "normalize_columns": tbl.normalize_columns, } ) url = f"/datasource/external_metadata_by_name/?q={params}" @@ -151,6 +153,7 @@ def test_external_metadata_by_name_from_sqla_inspector(self): "database_name": example_database.database_name, "table_name": "test_table", "schema_name": get_example_default_schema(), + "normalize_columns": False, } ) url = f"/datasource/external_metadata_by_name/?q={params}" @@ -164,6 +167,7 @@ def test_external_metadata_by_name_from_sqla_inspector(self): "datasource_type": "table", "database_name": "foo", "table_name": "bar", + "normalize_columns": False, } ) url = f"/datasource/external_metadata_by_name/?q={params}" @@ -180,6 +184,7 @@ def test_external_metadata_by_name_from_sqla_inspector(self): "datasource_type": "table", "database_name": example_database.database_name, "table_name": "fooooooooobarrrrrr", + "normalize_columns": False, } ) url = f"/datasource/external_metadata_by_name/?q={params}" diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py index d0fa04e97dfc7..cda9dd0bcc379 100644 --- a/tests/integration_tests/fixtures/importexport.py +++ b/tests/integration_tests/fixtures/importexport.py @@ -312,6 +312,7 @@ "sql": None, "table_name": "birth_names_2", "template_params": None, + "normalize_columns": False, } } ], @@ -494,6 +495,7 @@ "sql": "", "params": None, "template_params": {}, + "normalize_columns": False, "filter_select_enabled": True, "fetch_values_predicate": None, "extra": '{ "certification": { "certified_by": "Data Platform Team", "details": "This table is the source of truth." }, "warning_markdown": "This is a warning." }', diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index d44745377f562..f3a2a09eef1d2 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -122,7 +122,10 @@ def create_dashboard(self, title, id=0, slcs=[]): def create_table(self, name, schema=None, id=0, cols_names=[], metric_names=[]): params = {"remote_id": id, "database_name": "examples"} table = SqlaTable( - id=id, schema=schema, table_name=name, params=json.dumps(params) + id=id, + schema=schema, + table_name=name, + params=json.dumps(params), ) for col_name in cols_names: table.columns.append(TableColumn(column_name=col_name)) diff --git a/tests/unit_tests/datasets/commands/export_test.py b/tests/unit_tests/datasets/commands/export_test.py index 17913c2ca4bd4..e9c217a1916e4 100644 --- a/tests/unit_tests/datasets/commands/export_test.py +++ b/tests/unit_tests/datasets/commands/export_test.py @@ -81,6 +81,7 @@ def test_export(session: Session) -> None: is_sqllab_view=0, # no longer used? template_params=json.dumps({"answer": "42"}), schema_perm=None, + normalize_columns=False, extra=json.dumps({"warning_markdown": "*WARNING*"}), ) @@ -108,6 +109,7 @@ def test_export(session: Session) -> None: fetch_values_predicate: foo IN (1, 2) extra: warning_markdown: '*WARNING*' +normalize_columns: false uuid: null metrics: - metric_name: cnt