Skip to content

Commit

Permalink
fix(snowflake): opt-in denormalization of column names (#24982)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro authored and michael-s-molina committed Aug 16, 2023
1 parent 1569f01 commit 387549f
Show file tree
Hide file tree
Showing 22 changed files with 159 additions and 8 deletions.
1 change: 1 addition & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ assists people when migrating to a new version.

### Other

- [24982](https://github.com/apache/superset/pull/24982): By default, physical datasets on Oracle-like dialects such as Snowflake will now use denormalized column names. However, existing datasets won't be affected. To change this behavior, the "Advanced" section of the dataset modal has a "Normalize column names" flag which can be toggled to restore the previous normalization behavior.
- [23888](https://github.com/apache/superset/pull/23888): Database Migration for json serialization instead of pickle should upgrade/downgrade correctly when bumping to/from this patch version

## 2.1.0
Expand Down
10 changes: 10 additions & 0 deletions superset-frontend/src/components/Datasource/DatasourceEditor.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ class DatasourceEditor extends React.PureComponent {
table_name: datasource.table_name
? encodeURIComponent(datasource.table_name)
: datasource.table_name,
normalize_columns: datasource.normalize_columns,
};
Object.entries(params).forEach(([key, value]) => {
// rison can't encode the undefined value
Expand Down Expand Up @@ -1034,6 +1035,15 @@ class DatasourceEditor extends React.PureComponent {
control={<TextControl controlId="template_params" />}
/>
)}
<Field
inline
fieldKey="normalize_columns"
label={t('Normalize column names')}
description={t(
'Allow column names to be changed to case insensitive format, if supported (e.g. Oracle, Snowflake).',
)}
control={<CheckboxControl controlId="normalize_columns" />}
/>
</Fieldset>
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ const DatasourceModal: FunctionComponent<DatasourceModalProps> = ({
schema,
description: currentDatasource.description,
main_dttm_col: currentDatasource.main_dttm_col,
normalize_columns: currentDatasource.normalize_columns,
offset: currentDatasource.offset,
default_endpoint: currentDatasource.default_endpoint,
cache_timeout:
Expand Down
1 change: 1 addition & 0 deletions superset-frontend/src/features/datasets/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,5 @@ export type DatasetObject = {
metrics: MetricObject[];
extra?: string;
is_managed_externally: boolean;
normalize_columns: boolean;
};
3 changes: 3 additions & 0 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ class SqlaTable(
is_sqllab_view = Column(Boolean, default=False)
template_params = Column(Text)
extra = Column(Text)
normalize_columns = Column(Boolean, default=False)

baselink = "tablemodelview"

Expand All @@ -564,6 +565,7 @@ class SqlaTable(
"filter_select_enabled",
"fetch_values_predicate",
"extra",
"normalize_columns",
]
update_from_object_fields = [f for f in export_fields if f != "database_id"]
export_parent = "database"
Expand Down Expand Up @@ -717,6 +719,7 @@ def external_metadata(self) -> list[ResultSetColumnType]:
database=self.database,
table_name=self.table_name,
schema_name=self.schema,
normalize_columns=self.normalize_columns,
)

@property
Expand Down
6 changes: 5 additions & 1 deletion superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
def get_physical_table_metadata(
database: Database,
table_name: str,
normalize_columns: bool,
schema_name: str | None = None,
) -> list[ResultSetColumnType]:
"""Use SQLAlchemy inspector to get table metadata"""
Expand All @@ -67,7 +68,10 @@ def get_physical_table_metadata(
for col in cols:
try:
if isinstance(col["type"], TypeEngine):
name = db_engine_spec.denormalize_name(db_dialect, col["column_name"])
name = col["column_name"]
if not normalize_columns:
name = db_engine_spec.denormalize_name(db_dialect, name)

db_type = db_engine_spec.column_datatype_to_string(
col["type"], db_dialect
)
Expand Down
5 changes: 5 additions & 0 deletions superset/connectors/sqla/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ class TableModelView( # pylint: disable=too-many-ancestors
"is_sqllab_view",
"template_params",
"extra",
"normalize_columns",
]
base_filters = [["id", DatasourceFilter, lambda: []]]
show_columns = edit_columns + ["perm", "slices"]
Expand Down Expand Up @@ -379,6 +380,10 @@ class TableModelView( # pylint: disable=too-many-ancestors
'}, "warning_markdown": "This is a warning." }`.',
True,
),
"normalize_columns": _(
"Allow column names to be changed to case insensitive format, "
"if supported (e.g. Oracle, Snowflake)."
),
}
label_columns = {
"slices": _("Associated Charts"),
Expand Down
1 change: 1 addition & 0 deletions superset/dashboards/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ class DashboardDatasetSchema(Schema):
verbose_map = fields.Dict(fields.Str(), fields.Str())
time_grain_sqla = fields.List(fields.List(fields.Str()))
granularity_sqla = fields.List(fields.List(fields.Str()))
normalize_columns = fields.Bool()


class BaseDashboardSchema(Schema):
Expand Down
2 changes: 2 additions & 0 deletions superset/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"schema",
"description",
"main_dttm_col",
"normalize_columns",
"offset",
"default_endpoint",
"cache_timeout",
Expand Down Expand Up @@ -218,6 +219,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"schema",
"description",
"main_dttm_col",
"normalize_columns",
"offset",
"default_endpoint",
"cache_timeout",
Expand Down
1 change: 1 addition & 0 deletions superset/datasets/commands/duplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def run(self) -> Model:
table.database = database
table.schema = self._base_model.schema
table.template_params = self._base_model.template_params
table.normalize_columns = self._base_model.normalize_columns
table.is_sqllab_view = True
table.sql = ParsedQuery(self._base_model.sql).stripped()
db.session.add(table)
Expand Down
4 changes: 4 additions & 0 deletions superset/datasets/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class DatasetPostSchema(Schema):
owners = fields.List(fields.Integer())
is_managed_externally = fields.Boolean(allow_none=True, dump_default=False)
external_url = fields.String(allow_none=True)
normalize_columns = fields.Boolean(load_default=False)


class DatasetPutSchema(Schema):
Expand All @@ -96,6 +97,7 @@ class DatasetPutSchema(Schema):
schema = fields.String(allow_none=True, validate=Length(0, 255))
description = fields.String(allow_none=True)
main_dttm_col = fields.String(allow_none=True)
normalize_columns = fields.Boolean(allow_none=True, dump_default=False)
offset = fields.Integer(allow_none=True)
default_endpoint = fields.String(allow_none=True)
cache_timeout = fields.Integer(allow_none=True)
Expand Down Expand Up @@ -234,6 +236,7 @@ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
data = fields.URL()
is_managed_externally = fields.Boolean(allow_none=True, dump_default=False)
external_url = fields.String(allow_none=True)
normalize_columns = fields.Boolean(load_default=False)


class GetOrCreateDatasetSchema(Schema):
Expand All @@ -249,6 +252,7 @@ class GetOrCreateDatasetSchema(Schema):
template_params = fields.String(
metadata={"description": "Template params for the table"}
)
normalize_columns = fields.Boolean(load_default=False)


class DatasetSchema(SQLAlchemyAutoSchema):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_normalize_columns_to_sqla_model
Revision ID: 9f4a086c2676
Revises: 4448fa6deeb1
Create Date: 2023-08-14 09:38:11.897437
"""

# revision identifiers, used by Alembic.
revision = "9f4a086c2676"
down_revision = "4448fa6deeb1"

import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from superset import db
from superset.migrations.shared.utils import paginated_update

Base = declarative_base()


class SqlaTable(Base):
    # Minimal shadow of the application's "tables" model, declaring only
    # the columns this migration reads/writes. Migrations define their own
    # lightweight models so they stay independent of the evolving app code.
    __tablename__ = "tables"

    # Primary key of the dataset row.
    id = sa.Column(sa.Integer, primary_key=True)
    # Flag added by this migration; see upgrade() for the backfill logic.
    normalize_columns = sa.Column(sa.Boolean())


def upgrade():
    """Add the ``normalize_columns`` flag to the ``tables`` table.

    The column is created with a server default of ``false`` so that new
    datasets get the new behavior (denormalized column names). Existing
    rows are then backfilled to ``True`` so previously created datasets
    keep their current, normalized column names.
    """
    op.add_column(
        "tables",
        sa.Column(
            "normalize_columns",
            sa.Boolean(),
            nullable=True,
            # default applies on the Python/ORM side; server_default covers
            # inserts that bypass the ORM and the values for existing rows.
            default=False,
            server_default=sa.false(),
        ),
    )

    bind = op.get_bind()
    session = db.Session(bind=bind)

    # Backfill existing datasets to preserve their legacy behavior.
    # NOTE(review): paginated_update is a shared migration helper that
    # appears to iterate and persist rows in batches — confirm it commits,
    # as no explicit session.commit() follows this loop.
    for table in paginated_update(session.query(SqlaTable)):
        table.normalize_columns = True


def downgrade():
    """Drop the ``normalize_columns`` column added by :func:`upgrade`."""
    op.drop_column("tables", "normalize_columns")
6 changes: 5 additions & 1 deletion superset/views/datasource/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from typing import Any, Optional

from marshmallow import fields, post_load, pre_load, Schema, validate
from typing_extensions import TypedDict
Expand All @@ -29,13 +29,15 @@ class ExternalMetadataParams(TypedDict):
database_name: str
schema_name: str
table_name: str
normalize_columns: Optional[bool]


get_external_metadata_schema = {
"datasource_type": "string",
"database_name": "string",
"schema_name": "string",
"table_name": "string",
"normalize_columns": "boolean",
}


Expand All @@ -44,6 +46,7 @@ class ExternalMetadataSchema(Schema):
database_name = fields.Str(required=True)
schema_name = fields.Str(allow_none=True)
table_name = fields.Str(required=True)
normalize_columns = fields.Bool(allow_none=True)

# pylint: disable=no-self-use,unused-argument
@post_load
Expand All @@ -57,6 +60,7 @@ def normalize(
database_name=data["database_name"],
schema_name=data.get("schema_name", ""),
table_name=data["table_name"],
normalize_columns=data["normalize_columns"],
)


Expand Down
3 changes: 3 additions & 0 deletions superset/views/datasource/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def save(self) -> FlaskResponse:
return json_error_response(_("Request missing data field."), status=500)

datasource_dict = json.loads(data)
normalize_columns = datasource_dict.get("normalize_columns", False)
datasource_dict["normalize_columns"] = normalize_columns
datasource_id = datasource_dict.get("id")
datasource_type = datasource_dict.get("type")
database_id = datasource_dict["database"].get("id")
Expand Down Expand Up @@ -196,6 +198,7 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse:
database=database,
table_name=params["table_name"],
schema_name=params["schema_name"],
normalize_columns=params.get("normalize_columns") or False,
)
except (NoResultFound, NoSuchTableError) as ex:
raise DatasetNotFoundError() from ex
Expand Down
3 changes: 1 addition & 2 deletions tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.exceptions import QueryObjectValidationError, SupersetException
from superset.exceptions import SupersetException
from superset.extensions import async_query_manager, cache_manager
from superset.models import core as models
from superset.models.annotations import Annotation, AnnotationLayer
from superset.models.cache import CacheKey
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/dashboard_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def create_table_metadata(

table = get_table(table_name, database, schema)
if not table:
table = SqlaTable(schema=schema, table_name=table_name)
table = SqlaTable(schema=schema, table_name=table_name, normalize_columns=False)
if fetch_values_predicate:
table.fetch_values_predicate = fetch_values_predicate
table.database = database
Expand Down
35 changes: 33 additions & 2 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,8 @@ def test_create_dataset_item(self):
model = db.session.query(SqlaTable).get(table_id)
assert model.table_name == table_data["table_name"]
assert model.database_id == table_data["database"]
# normalize_columns should default to False
assert model.normalize_columns is False

# Assert that columns were created
columns = (
Expand All @@ -563,6 +565,34 @@ def test_create_dataset_item(self):
db.session.delete(model)
db.session.commit()

def test_create_dataset_item_normalize(self):
"""
Dataset API: Test create dataset item with column normalization enabled
"""
if backend() == "sqlite":
return

main_db = get_main_database()
self.login(username="admin")
table_data = {
"database": main_db.id,
"schema": None,
"table_name": "ab_permission",
"normalize_columns": True,
}
uri = "api/v1/dataset/"
rv = self.post_assert_metric(uri, table_data, "post")
assert rv.status_code == 201
data = json.loads(rv.data.decode("utf-8"))
table_id = data.get("id")
model = db.session.query(SqlaTable).get(table_id)
assert model.table_name == table_data["table_name"]
assert model.database_id == table_data["database"]
assert model.normalize_columns is True

db.session.delete(model)
db.session.commit()

def test_create_dataset_item_gamma(self):
"""
Dataset API: Test create dataset item gamma
Expand Down Expand Up @@ -2494,8 +2524,9 @@ def test_get_or_create_dataset_creates_table(self):
.filter_by(table_name="test_create_sqla_table_api")
.one()
)
self.assertEqual(response["result"], {"table_id": table.id})
self.assertEqual(table.template_params, '{"param": 1}')
assert response["result"] == {"table_id": table.id}
assert table.template_params == '{"param": 1}'
assert table.normalize_columns is False

db.session.delete(table)
with examples_db.get_sqla_engine_with_context() as engine:
Expand Down
2 changes: 2 additions & 0 deletions tests/integration_tests/datasets/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def test_export_dataset_command(self, mock_g):
"warning_text": None,
},
],
"normalize_columns": False,
"offset": 0,
"params": None,
"schema": get_example_default_schema(),
Expand Down Expand Up @@ -229,6 +230,7 @@ def test_export_dataset_command_key_order(self, mock_g):
"filter_select_enabled",
"fetch_values_predicate",
"extra",
"normalize_columns",
"uuid",
"metrics",
"columns",
Expand Down
Loading

0 comments on commit 387549f

Please sign in to comment.