From 387549f69c705dc5d2d459ebb5ef1f2a51b1076a Mon Sep 17 00:00:00 2001
From: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Date: Tue, 15 Aug 2023 16:32:54 -0700
Subject: [PATCH] fix(snowflake): opt-in denormalization of column names
(#24982)
---
UPDATING.md | 1 +
.../Datasource/DatasourceEditor.jsx | 10 +++
.../components/Datasource/DatasourceModal.tsx | 1 +
.../src/features/datasets/types.ts | 1 +
superset/connectors/sqla/models.py | 3 +
superset/connectors/sqla/utils.py | 6 +-
superset/connectors/sqla/views.py | 5 ++
superset/dashboards/schemas.py | 1 +
superset/datasets/api.py | 2 +
superset/datasets/commands/duplicate.py | 1 +
superset/datasets/schemas.py | 4 ++
...676_add_normalize_columns_to_sqla_model.py | 67 +++++++++++++++++++
superset/views/datasource/schemas.py | 6 +-
superset/views/datasource/views.py | 3 +
tests/integration_tests/core_tests.py | 3 +-
tests/integration_tests/dashboard_utils.py | 2 +-
tests/integration_tests/datasets/api_tests.py | 35 +++++++++-
.../datasets/commands_tests.py | 2 +
tests/integration_tests/datasource_tests.py | 5 ++
.../fixtures/importexport.py | 2 +
.../integration_tests/import_export_tests.py | 5 +-
.../datasets/commands/export_test.py | 2 +
22 files changed, 159 insertions(+), 8 deletions(-)
create mode 100644 superset/migrations/versions/2023-08-14_09-38_9f4a086c2676_add_normalize_columns_to_sqla_model.py
diff --git a/UPDATING.md b/UPDATING.md
index 39a1654996943..5a29c43dfad99 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -88,6 +88,7 @@ assists people when migrating to a new version.
### Other
+- [24982](https://github.com/apache/superset/pull/24982): By default, physical datasets on Oracle-like dialects like Snowflake will now use denormalized column names. However, existing datasets won't be affected. To change this behavior, the "Advanced" section on the dataset modal has a "Normalize column names" flag which can be toggled to change this behavior.
- [23888](https://github.com/apache/superset/pull/23888): Database Migration for json serialization instead of pickle should upgrade/downgrade correctly when bumping to/from this patch version
## 2.1.0
diff --git a/superset-frontend/src/components/Datasource/DatasourceEditor.jsx b/superset-frontend/src/components/Datasource/DatasourceEditor.jsx
index 5977f44058f66..ead572422c1cd 100644
--- a/superset-frontend/src/components/Datasource/DatasourceEditor.jsx
+++ b/superset-frontend/src/components/Datasource/DatasourceEditor.jsx
@@ -806,6 +806,7 @@ class DatasourceEditor extends React.PureComponent {
table_name: datasource.table_name
? encodeURIComponent(datasource.table_name)
: datasource.table_name,
+ normalize_columns: datasource.normalize_columns,
};
Object.entries(params).forEach(([key, value]) => {
// rison can't encode the undefined value
@@ -1034,6 +1035,15 @@ class DatasourceEditor extends React.PureComponent {
control={}
/>
)}
+ }
+ />
);
}
diff --git a/superset-frontend/src/components/Datasource/DatasourceModal.tsx b/superset-frontend/src/components/Datasource/DatasourceModal.tsx
index 78859f4a2fe47..f9c40c47ba02e 100644
--- a/superset-frontend/src/components/Datasource/DatasourceModal.tsx
+++ b/superset-frontend/src/components/Datasource/DatasourceModal.tsx
@@ -128,6 +128,7 @@ const DatasourceModal: FunctionComponent = ({
schema,
description: currentDatasource.description,
main_dttm_col: currentDatasource.main_dttm_col,
+ normalize_columns: currentDatasource.normalize_columns,
offset: currentDatasource.offset,
default_endpoint: currentDatasource.default_endpoint,
cache_timeout:
diff --git a/superset-frontend/src/features/datasets/types.ts b/superset-frontend/src/features/datasets/types.ts
index 4c2c5b8a95516..9163306267b89 100644
--- a/superset-frontend/src/features/datasets/types.ts
+++ b/superset-frontend/src/features/datasets/types.ts
@@ -63,4 +63,5 @@ export type DatasetObject = {
metrics: MetricObject[];
extra?: string;
is_managed_externally: boolean;
+ normalize_columns: boolean;
};
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index ea5c0c8de0dfe..aeda42cf8cf91 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -546,6 +546,7 @@ class SqlaTable(
is_sqllab_view = Column(Boolean, default=False)
template_params = Column(Text)
extra = Column(Text)
+ normalize_columns = Column(Boolean, default=False)
baselink = "tablemodelview"
@@ -564,6 +565,7 @@ class SqlaTable(
"filter_select_enabled",
"fetch_values_predicate",
"extra",
+ "normalize_columns",
]
update_from_object_fields = [f for f in export_fields if f != "database_id"]
export_parent = "database"
@@ -717,6 +719,7 @@ def external_metadata(self) -> list[ResultSetColumnType]:
database=self.database,
table_name=self.table_name,
schema_name=self.schema,
+ normalize_columns=self.normalize_columns,
)
@property
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 82d8f90f224b4..c8a5f9f260572 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -48,6 +48,7 @@
def get_physical_table_metadata(
database: Database,
table_name: str,
+ normalize_columns: bool,
schema_name: str | None = None,
) -> list[ResultSetColumnType]:
"""Use SQLAlchemy inspector to get table metadata"""
@@ -67,7 +68,10 @@ def get_physical_table_metadata(
for col in cols:
try:
if isinstance(col["type"], TypeEngine):
- name = db_engine_spec.denormalize_name(db_dialect, col["column_name"])
+ name = col["column_name"]
+ if not normalize_columns:
+ name = db_engine_spec.denormalize_name(db_dialect, name)
+
db_type = db_engine_spec.column_datatype_to_string(
col["type"], db_dialect
)
diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py
index 65c6f110e4fe2..f72261eff8964 100644
--- a/superset/connectors/sqla/views.py
+++ b/superset/connectors/sqla/views.py
@@ -313,6 +313,7 @@ class TableModelView( # pylint: disable=too-many-ancestors
"is_sqllab_view",
"template_params",
"extra",
+ "normalize_columns",
]
base_filters = [["id", DatasourceFilter, lambda: []]]
show_columns = edit_columns + ["perm", "slices"]
@@ -379,6 +380,10 @@ class TableModelView( # pylint: disable=too-many-ancestors
'}, "warning_markdown": "This is a warning." }`.',
True,
),
+ "normalize_columns": _(
+ "Allow column names to be changed to case insensitive format, "
+ "if supported (e.g. Oracle, Snowflake)."
+ ),
}
label_columns = {
"slices": _("Associated Charts"),
diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py
index 7905641f80735..f72aeae58296a 100644
--- a/superset/dashboards/schemas.py
+++ b/superset/dashboards/schemas.py
@@ -248,6 +248,7 @@ class DashboardDatasetSchema(Schema):
verbose_map = fields.Dict(fields.Str(), fields.Str())
time_grain_sqla = fields.List(fields.List(fields.Str()))
granularity_sqla = fields.List(fields.List(fields.Str()))
+ normalize_columns = fields.Bool()
class BaseDashboardSchema(Schema):
diff --git a/superset/datasets/api.py b/superset/datasets/api.py
index d5a0478c5d2fa..bb1a4ffd2909f 100644
--- a/superset/datasets/api.py
+++ b/superset/datasets/api.py
@@ -141,6 +141,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"schema",
"description",
"main_dttm_col",
+ "normalize_columns",
"offset",
"default_endpoint",
"cache_timeout",
@@ -218,6 +219,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"schema",
"description",
"main_dttm_col",
+ "normalize_columns",
"offset",
"default_endpoint",
"cache_timeout",
diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py
index dc3ccb85d4b0c..9fc05c0960a5f 100644
--- a/superset/datasets/commands/duplicate.py
+++ b/superset/datasets/commands/duplicate.py
@@ -67,6 +67,7 @@ def run(self) -> Model:
table.database = database
table.schema = self._base_model.schema
table.template_params = self._base_model.template_params
+ table.normalize_columns = self._base_model.normalize_columns
table.is_sqllab_view = True
table.sql = ParsedQuery(self._base_model.sql).stripped()
db.session.add(table)
diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py
index dcac648148ed3..f229604d47259 100644
--- a/superset/datasets/schemas.py
+++ b/superset/datasets/schemas.py
@@ -85,6 +85,7 @@ class DatasetPostSchema(Schema):
owners = fields.List(fields.Integer())
is_managed_externally = fields.Boolean(allow_none=True, dump_default=False)
external_url = fields.String(allow_none=True)
+ normalize_columns = fields.Boolean(load_default=False)
class DatasetPutSchema(Schema):
@@ -96,6 +97,7 @@ class DatasetPutSchema(Schema):
schema = fields.String(allow_none=True, validate=Length(0, 255))
description = fields.String(allow_none=True)
main_dttm_col = fields.String(allow_none=True)
+ normalize_columns = fields.Boolean(allow_none=True, dump_default=False)
offset = fields.Integer(allow_none=True)
default_endpoint = fields.String(allow_none=True)
cache_timeout = fields.Integer(allow_none=True)
@@ -234,6 +236,7 @@ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
data = fields.URL()
is_managed_externally = fields.Boolean(allow_none=True, dump_default=False)
external_url = fields.String(allow_none=True)
+ normalize_columns = fields.Boolean(load_default=False)
class GetOrCreateDatasetSchema(Schema):
@@ -249,6 +252,7 @@ class GetOrCreateDatasetSchema(Schema):
template_params = fields.String(
metadata={"description": "Template params for the table"}
)
+ normalize_columns = fields.Boolean(load_default=False)
class DatasetSchema(SQLAlchemyAutoSchema):
diff --git a/superset/migrations/versions/2023-08-14_09-38_9f4a086c2676_add_normalize_columns_to_sqla_model.py b/superset/migrations/versions/2023-08-14_09-38_9f4a086c2676_add_normalize_columns_to_sqla_model.py
new file mode 100644
index 0000000000000..8eaee8207ce0b
--- /dev/null
+++ b/superset/migrations/versions/2023-08-14_09-38_9f4a086c2676_add_normalize_columns_to_sqla_model.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""add_normalize_columns_to_sqla_model
+
+Revision ID: 9f4a086c2676
+Revises: 4448fa6deeb1
+Create Date: 2023-08-14 09:38:11.897437
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = "9f4a086c2676"
+down_revision = "4448fa6deeb1"
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Session
+
+from superset import db
+from superset.migrations.shared.utils import paginated_update
+
+Base = declarative_base()
+
+
+class SqlaTable(Base):
+ __tablename__ = "tables"
+
+ id = sa.Column(sa.Integer, primary_key=True)
+ normalize_columns = sa.Column(sa.Boolean())
+
+
+def upgrade():
+ op.add_column(
+ "tables",
+ sa.Column(
+ "normalize_columns",
+ sa.Boolean(),
+ nullable=True,
+ default=False,
+ server_default=sa.false(),
+ ),
+ )
+
+ bind = op.get_bind()
+ session = db.Session(bind=bind)
+
+ for table in paginated_update(session.query(SqlaTable)):
+ table.normalize_columns = True
+
+
+def downgrade():
+ op.drop_column("tables", "normalize_columns")
diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py
index 5b1700708ad82..0fcdb452ebee4 100644
--- a/superset/views/datasource/schemas.py
+++ b/superset/views/datasource/schemas.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any
+from typing import Any, Optional
from marshmallow import fields, post_load, pre_load, Schema, validate
from typing_extensions import TypedDict
@@ -29,6 +29,7 @@ class ExternalMetadataParams(TypedDict):
database_name: str
schema_name: str
table_name: str
+ normalize_columns: Optional[bool]
get_external_metadata_schema = {
@@ -36,6 +37,7 @@ class ExternalMetadataParams(TypedDict):
"database_name": "string",
"schema_name": "string",
"table_name": "string",
+ "normalize_columns": "boolean",
}
@@ -44,6 +46,7 @@ class ExternalMetadataSchema(Schema):
database_name = fields.Str(required=True)
schema_name = fields.Str(allow_none=True)
table_name = fields.Str(required=True)
+ normalize_columns = fields.Bool(allow_none=True)
# pylint: disable=no-self-use,unused-argument
@post_load
@@ -57,6 +60,7 @@ def normalize(
database_name=data["database_name"],
schema_name=data.get("schema_name", ""),
table_name=data["table_name"],
+ normalize_columns=data["normalize_columns"],
)
diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py
index f1086acd47330..b2fd387379fd9 100644
--- a/superset/views/datasource/views.py
+++ b/superset/views/datasource/views.py
@@ -77,6 +77,8 @@ def save(self) -> FlaskResponse:
return json_error_response(_("Request missing data field."), status=500)
datasource_dict = json.loads(data)
+ normalize_columns = datasource_dict.get("normalize_columns", False)
+ datasource_dict["normalize_columns"] = normalize_columns
datasource_id = datasource_dict.get("id")
datasource_type = datasource_dict.get("type")
database_id = datasource_dict["database"].get("id")
@@ -196,6 +198,7 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse:
database=database,
table_name=params["table_name"],
schema_name=params["schema_name"],
+ normalize_columns=params.get("normalize_columns") or False,
)
except (NoResultFound, NoSuchTableError) as ex:
raise DatasetNotFoundError() from ex
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index e036602d0f973..9b22db7013b76 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -43,10 +43,9 @@
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
-from superset.exceptions import QueryObjectValidationError, SupersetException
+from superset.exceptions import SupersetException
from superset.extensions import async_query_manager, cache_manager
from superset.models import core as models
-from superset.models.annotations import Annotation, AnnotationLayer
from superset.models.cache import CacheKey
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py
index 6c3d000051f35..e284e21ca4bd9 100644
--- a/tests/integration_tests/dashboard_utils.py
+++ b/tests/integration_tests/dashboard_utils.py
@@ -53,7 +53,7 @@ def create_table_metadata(
table = get_table(table_name, database, schema)
if not table:
- table = SqlaTable(schema=schema, table_name=table_name)
+ table = SqlaTable(schema=schema, table_name=table_name, normalize_columns=False)
if fetch_values_predicate:
table.fetch_values_predicate = fetch_values_predicate
table.database = database
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 027002507aece..8884b1171aa3f 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -540,6 +540,8 @@ def test_create_dataset_item(self):
model = db.session.query(SqlaTable).get(table_id)
assert model.table_name == table_data["table_name"]
assert model.database_id == table_data["database"]
+ # normalize_columns should default to False
+ assert model.normalize_columns is False
# Assert that columns were created
columns = (
@@ -563,6 +565,34 @@ def test_create_dataset_item(self):
db.session.delete(model)
db.session.commit()
+ def test_create_dataset_item_normalize(self):
+ """
+ Dataset API: Test create dataset item with column normalization enabled
+ """
+ if backend() == "sqlite":
+ return
+
+ main_db = get_main_database()
+ self.login(username="admin")
+ table_data = {
+ "database": main_db.id,
+ "schema": None,
+ "table_name": "ab_permission",
+ "normalize_columns": True,
+ }
+ uri = "api/v1/dataset/"
+ rv = self.post_assert_metric(uri, table_data, "post")
+ assert rv.status_code == 201
+ data = json.loads(rv.data.decode("utf-8"))
+ table_id = data.get("id")
+ model = db.session.query(SqlaTable).get(table_id)
+ assert model.table_name == table_data["table_name"]
+ assert model.database_id == table_data["database"]
+ assert model.normalize_columns is True
+
+ db.session.delete(model)
+ db.session.commit()
+
def test_create_dataset_item_gamma(self):
"""
Dataset API: Test create dataset item gamma
@@ -2494,8 +2524,9 @@ def test_get_or_create_dataset_creates_table(self):
.filter_by(table_name="test_create_sqla_table_api")
.one()
)
- self.assertEqual(response["result"], {"table_id": table.id})
- self.assertEqual(table.template_params, '{"param": 1}')
+ assert response["result"] == {"table_id": table.id}
+ assert table.template_params == '{"param": 1}'
+ assert table.normalize_columns is False
db.session.delete(table)
with examples_db.get_sqla_engine_with_context() as engine:
diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py
index b3b5084e35cb0..19caa9e1a111a 100644
--- a/tests/integration_tests/datasets/commands_tests.py
+++ b/tests/integration_tests/datasets/commands_tests.py
@@ -170,6 +170,7 @@ def test_export_dataset_command(self, mock_g):
"warning_text": None,
},
],
+ "normalize_columns": False,
"offset": 0,
"params": None,
"schema": get_example_default_schema(),
@@ -229,6 +230,7 @@ def test_export_dataset_command_key_order(self, mock_g):
"filter_select_enabled",
"fetch_values_predicate",
"extra",
+ "normalize_columns",
"uuid",
"metrics",
"columns",
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index 5de1cf6ef85e1..4c05898cfe75e 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -105,6 +105,7 @@ def test_external_metadata_by_name_for_physical_table(self):
"database_name": tbl.database.database_name,
"schema_name": tbl.schema,
"table_name": tbl.table_name,
+ "normalize_columns": tbl.normalize_columns,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
@@ -133,6 +134,7 @@ def test_external_metadata_by_name_for_virtual_table(self):
"database_name": tbl.database.database_name,
"schema_name": tbl.schema,
"table_name": tbl.table_name,
+ "normalize_columns": tbl.normalize_columns,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
@@ -151,6 +153,7 @@ def test_external_metadata_by_name_from_sqla_inspector(self):
"database_name": example_database.database_name,
"table_name": "test_table",
"schema_name": get_example_default_schema(),
+ "normalize_columns": False,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
@@ -164,6 +167,7 @@ def test_external_metadata_by_name_from_sqla_inspector(self):
"datasource_type": "table",
"database_name": "foo",
"table_name": "bar",
+ "normalize_columns": False,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
@@ -180,6 +184,7 @@ def test_external_metadata_by_name_from_sqla_inspector(self):
"datasource_type": "table",
"database_name": example_database.database_name,
"table_name": "fooooooooobarrrrrr",
+ "normalize_columns": False,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py
index d0fa04e97dfc7..cda9dd0bcc379 100644
--- a/tests/integration_tests/fixtures/importexport.py
+++ b/tests/integration_tests/fixtures/importexport.py
@@ -312,6 +312,7 @@
"sql": None,
"table_name": "birth_names_2",
"template_params": None,
+ "normalize_columns": False,
}
}
],
@@ -494,6 +495,7 @@
"sql": "",
"params": None,
"template_params": {},
+ "normalize_columns": False,
"filter_select_enabled": True,
"fetch_values_predicate": None,
"extra": '{ "certification": { "certified_by": "Data Platform Team", "details": "This table is the source of truth." }, "warning_markdown": "This is a warning." }',
diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py
index d44745377f562..f3a2a09eef1d2 100644
--- a/tests/integration_tests/import_export_tests.py
+++ b/tests/integration_tests/import_export_tests.py
@@ -122,7 +122,10 @@ def create_dashboard(self, title, id=0, slcs=[]):
def create_table(self, name, schema=None, id=0, cols_names=[], metric_names=[]):
params = {"remote_id": id, "database_name": "examples"}
table = SqlaTable(
- id=id, schema=schema, table_name=name, params=json.dumps(params)
+ id=id,
+ schema=schema,
+ table_name=name,
+ params=json.dumps(params),
)
for col_name in cols_names:
table.columns.append(TableColumn(column_name=col_name))
diff --git a/tests/unit_tests/datasets/commands/export_test.py b/tests/unit_tests/datasets/commands/export_test.py
index 17913c2ca4bd4..e9c217a1916e4 100644
--- a/tests/unit_tests/datasets/commands/export_test.py
+++ b/tests/unit_tests/datasets/commands/export_test.py
@@ -81,6 +81,7 @@ def test_export(session: Session) -> None:
is_sqllab_view=0, # no longer used?
template_params=json.dumps({"answer": "42"}),
schema_perm=None,
+ normalize_columns=False,
extra=json.dumps({"warning_markdown": "*WARNING*"}),
)
@@ -108,6 +109,7 @@ def test_export(session: Session) -> None:
fetch_values_predicate: foo IN (1, 2)
extra:
warning_markdown: '*WARNING*'
+normalize_columns: false
uuid: null
metrics:
- metric_name: cnt