From f855a5e20b58f57916d4b82e6e753775a2197c0c Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Mon, 13 May 2024 11:55:59 -0700 Subject: [PATCH] fix: Update migration logic in #27119 (#28422) --- .../src/operators/types.ts | 2 +- superset/migrations/shared/utils.py | 33 ++++++++++--- ..._17fcea065655_change_text_to_mediumtext.py | 48 ++++++++++++------- superset/models/sql_lab.py | 14 ++++-- superset/utils/core.py | 6 ++- .../pandas_postprocessing/contribution.py | 16 ++++--- 6 files changed, 81 insertions(+), 38 deletions(-) diff --git a/superset-frontend/packages/superset-ui-chart-controls/src/operators/types.ts b/superset-frontend/packages/superset-ui-chart-controls/src/operators/types.ts index 34f632ff8f38f..0c5285a2a1e8a 100644 --- a/superset-frontend/packages/superset-ui-chart-controls/src/operators/types.ts +++ b/superset-frontend/packages/superset-ui-chart-controls/src/operators/types.ts @@ -19,5 +19,5 @@ import { QueryFormData, QueryObject } from '@superset-ui/core'; export interface PostProcessingFactory { - (formData: QueryFormData, queryObject: QueryObject): T; + (formData: QueryFormData, queryObject: QueryObject, options?: any): T; } diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py index 2ae0dfeac158a..d6a664f330f28 100644 --- a/superset/migrations/shared/utils.py +++ b/superset/migrations/shared/utils.py @@ -35,21 +35,40 @@ DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000)) -def table_has_column(table: str, column: str) -> bool: +def get_table_column( + table_name: str, + column_name: str, +) -> Optional[list[dict[str, Any]]]: """ - Checks if a column exists in a given table. + Get the specified column. - :param table: A table name - :param column: A column name - :returns: True iff the column exists in the table + :param table_name: The Table name + :param column_name: The column name + :returns: The column """ insp = inspect(op.get_context().bind) try: - return any(col["name"] == column for col in insp.get_columns(table)) + for column in insp.get_columns(table_name): + if column["name"] == column_name: + return column except NoSuchTableError: - return False + pass + + return None + + +def table_has_column(table_name: str, column_name: str) -> bool: + """ + Checks if a column exists in a given table. + + :param table_name: A table name + :param column_name: A column name + :returns: True iff the column exists in the table + """ + + return bool(get_table_column(table_name, column_name)) uuid_by_dialect = { diff --git a/superset/migrations/versions/2024-02-14_14-43_17fcea065655_change_text_to_mediumtext.py b/superset/migrations/versions/2024-02-14_14-43_17fcea065655_change_text_to_mediumtext.py index e63ab6ac5644a..1f4474eeed9fb 100644 --- a/superset/migrations/versions/2024-02-14_14-43_17fcea065655_change_text_to_mediumtext.py +++ b/superset/migrations/versions/2024-02-14_14-43_17fcea065655_change_text_to_mediumtext.py @@ -28,8 +28,10 @@ import sqlalchemy as sa from alembic import op +from sqlalchemy.dialects.mysql import MEDIUMTEXT, TEXT from sqlalchemy.dialects.mysql.base import MySQLDialect +from superset.migrations.shared.utils import get_table_column from superset.utils.core import MediumText TABLE_COLUMNS = [ @@ -38,8 +40,6 @@ "dashboards.css", "keyvalue.value", "query.extra_json", - "query.executed_sql", - "query.select_sql", "report_execution_log.value_row_json", "report_recipient.recipient_config_json", "report_schedule.sql", @@ -65,23 +65,35 @@ def upgrade(): if isinstance(op.get_bind().dialect, MySQLDialect): - for column in TABLE_COLUMNS: - with op.batch_alter_table(column.split(".")[0]) as batch_op: - batch_op.alter_column( - column.split(".")[1], - existing_type=sa.Text(), - type_=MediumText(), - existing_nullable=column not in NOT_NULL_COLUMNS, - ) + for item in TABLE_COLUMNS: + table_name, column_name = item.split(".") + + if (column := get_table_column(table_name, column_name)) and isinstance( + column["type"], + TEXT, + ): + with op.batch_alter_table(table_name) as batch_op: + batch_op.alter_column( + column_name, + existing_type=sa.Text(), + type_=MediumText(), + existing_nullable=item not in NOT_NULL_COLUMNS, + ) def downgrade(): if isinstance(op.get_bind().dialect, MySQLDialect): - for column in TABLE_COLUMNS: - with op.batch_alter_table(column.split(".")[0]) as batch_op: - batch_op.alter_column( - column.split(".")[1], - existing_type=MediumText(), - type_=sa.Text(), - existing_nullable=column not in NOT_NULL_COLUMNS, - ) + for item in TABLE_COLUMNS: + table_name, column_name = item.split(".") + + if (column := get_table_column(table_name, column_name)) and isinstance( + column["type"], + MEDIUMTEXT, + ): + with op.batch_alter_table(table_name) as batch_op: + batch_op.alter_column( + column_name, + existing_type=MediumText(), + type_=sa.Text(), + existing_nullable=item not in NOT_NULL_COLUMNS, + ) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 25c21cdfc883c..78c29bd2bbc50 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -57,7 +57,13 @@ ) from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table from superset.sqllab.limiting_factor import LimitingFactor -from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label +from superset.utils.core import ( + get_column_name, + LongText, + MediumText, + QueryStatus, + user_label, +) if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn @@ -107,11 +113,11 @@ class Query( tab_name = Column(String(256)) sql_editor_id = Column(String(256)) schema = Column(String(256)) - sql = Column(MediumText()) + sql = Column(LongText()) # Query to retrieve the results, # used only in case of select_as_cta_used is true. - select_sql = Column(MediumText()) - executed_sql = Column(MediumText()) + select_sql = Column(LongText()) + executed_sql = Column(LongText()) # Could be configured in the superset config. limit = Column(Integer) limiting_factor = Column( diff --git a/superset/utils/core.py b/superset/utils/core.py index f605ef99c1b4d..6649f3471799b 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -68,7 +68,7 @@ from pandas.api.types import infer_dtype from pandas.core.dtypes.common import is_numeric_dtype from sqlalchemy import event, exc, inspect, select, Text -from sqlalchemy.dialects.mysql import MEDIUMTEXT +from sqlalchemy.dialects.mysql import LONGTEXT, MEDIUMTEXT from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql.type_api import Variant @@ -1469,6 +1469,10 @@ def MediumText() -> Variant: # pylint:disable=invalid-name return Text().with_variant(MEDIUMTEXT(), "mysql") +def LongText() -> Variant: # pylint:disable=invalid-name + return Text().with_variant(LONGTEXT(), "mysql") + + def shortid() -> str: return f"{uuid.uuid4()}"[-12:] diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py index 46144ec019402..e0deb08322542 100644 --- a/superset/utils/pandas_postprocessing/contribution.py +++ b/superset/utils/pandas_postprocessing/contribution.py @@ -74,7 +74,8 @@ def contribution( if len(rename_columns) != len(actual_columns): raise InvalidPostProcessingError( _( - "`rename_columns` must have the same length as `columns` + `time_shift_columns`." + "`rename_columns` must have the same length as " + + "`columns` + `time_shift_columns`." ) ) # limit to selected columns @@ -105,10 +106,10 @@ def get_column_groups( :param df: DataFrame to group columns from :param time_shifts: List of time shifts to group by :param rename_columns: List of new column names - :return: Dictionary with two keys: 'non_time_shift' and 'time_shifts'. 'non_time_shift' - maps to a tuple of original and renamed columns without a time shift. 'time_shifts' maps - to a dictionary where each key is a time shift and each value is a tuple of original and - renamed columns with that time shift. + :return: Dictionary with two keys: 'non_time_shift' and 'time_shifts'. + 'non_time_shift' maps to a tuple of original and renamed columns without a time shift. + 'time_shifts' maps to a dictionary where each key is a time shift and each value is a + tuple of original and renamed columns with that time shift. """ result: dict[str, Any] = { "non_time_shift": ([], []), # take the form of ([A, B, C], [X, Y, Z]) @@ -139,8 +140,9 @@ def calculate_row_contribution( """ Calculate the contribution of each column to the row total and update the DataFrame. - This function calculates the contribution of each selected column to the total of the row, - and updates the DataFrame with these contribution percentages in place of the original values. + This function calculates the contribution of each selected column to the total of + the row, and updates the DataFrame with these contribution percentages in place of + the original values. :param df: The DataFrame to calculate contributions for. :param columns: A list of column names to calculate contributions for.