Skip to content

Commit

Permalink
fix: Update migration logic in apache#27119 (apache#28422)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored and michael-s-molina committed May 13, 2024
1 parent 2c7982c commit f855a5e
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
import { QueryFormData, QueryObject } from '@superset-ui/core';

export interface PostProcessingFactory<T> {
(formData: QueryFormData, queryObject: QueryObject): T;
(formData: QueryFormData, queryObject: QueryObject, options?: any): T;
}
33 changes: 26 additions & 7 deletions superset/migrations/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,40 @@
DEFAULT_BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 1000))


def table_has_column(table: str, column: str) -> bool:
def get_table_column(
table_name: str,
column_name: str,
) -> Optional[list[dict[str, Any]]]:
"""
Checks if a column exists in a given table.
Get the specified column.
:param table: A table name
:param column: A column name
:returns: True iff the column exists in the table
:param table_name: The Table name
:param column_name: The column name
:returns: The column
"""

insp = inspect(op.get_context().bind)

try:
return any(col["name"] == column for col in insp.get_columns(table))
for column in insp.get_columns(table_name):
if column["name"] == column_name:
return column
except NoSuchTableError:
return False
pass

return None


def table_has_column(table_name: str, column_name: str) -> bool:
"""
Checks if a column exists in a given table.
:param table_name: A table name
:param column_name: A column name
:returns: True iff the column exists in the table
"""

return bool(get_table_column(table_name, column_name))


uuid_by_dialect = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.mysql import MEDIUMTEXT, TEXT
from sqlalchemy.dialects.mysql.base import MySQLDialect

from superset.migrations.shared.utils import get_table_column
from superset.utils.core import MediumText

TABLE_COLUMNS = [
Expand All @@ -38,8 +40,6 @@
"dashboards.css",
"keyvalue.value",
"query.extra_json",
"query.executed_sql",
"query.select_sql",
"report_execution_log.value_row_json",
"report_recipient.recipient_config_json",
"report_schedule.sql",
Expand All @@ -65,23 +65,35 @@

def upgrade():
if isinstance(op.get_bind().dialect, MySQLDialect):
for column in TABLE_COLUMNS:
with op.batch_alter_table(column.split(".")[0]) as batch_op:
batch_op.alter_column(
column.split(".")[1],
existing_type=sa.Text(),
type_=MediumText(),
existing_nullable=column not in NOT_NULL_COLUMNS,
)
for item in TABLE_COLUMNS:
table_name, column_name = item.split(".")

if (column := get_table_column(table_name, column_name)) and isinstance(
column["type"],
TEXT,
):
with op.batch_alter_table(table_name) as batch_op:
batch_op.alter_column(
column_name,
existing_type=sa.Text(),
type_=MediumText(),
existing_nullable=item not in NOT_NULL_COLUMNS,
)


def downgrade():
if isinstance(op.get_bind().dialect, MySQLDialect):
for column in TABLE_COLUMNS:
with op.batch_alter_table(column.split(".")[0]) as batch_op:
batch_op.alter_column(
column.split(".")[1],
existing_type=MediumText(),
type_=sa.Text(),
existing_nullable=column not in NOT_NULL_COLUMNS,
)
for item in TABLE_COLUMNS:
table_name, column_name = item.split(".")

if (column := get_table_column(table_name, column_name)) and isinstance(
column["type"],
MEDIUMTEXT,
):
with op.batch_alter_table(table_name) as batch_op:
batch_op.alter_column(
column_name,
existing_type=MediumText(),
type_=sa.Text(),
existing_nullable=item not in NOT_NULL_COLUMNS,
)
14 changes: 10 additions & 4 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@
)
from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label
from superset.utils.core import (
get_column_name,
LongText,
MediumText,
QueryStatus,
user_label,
)

if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
Expand Down Expand Up @@ -107,11 +113,11 @@ class Query(
tab_name = Column(String(256))
sql_editor_id = Column(String(256))
schema = Column(String(256))
sql = Column(MediumText())
sql = Column(LongText())
# Query to retrieve the results,
# used only in case of select_as_cta_used is true.
select_sql = Column(MediumText())
executed_sql = Column(MediumText())
select_sql = Column(LongText())
executed_sql = Column(LongText())
# Could be configured in the superset config.
limit = Column(Integer)
limiting_factor = Column(
Expand Down
6 changes: 5 additions & 1 deletion superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from pandas.api.types import infer_dtype
from pandas.core.dtypes.common import is_numeric_dtype
from sqlalchemy import event, exc, inspect, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.dialects.mysql import LONGTEXT, MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant
Expand Down Expand Up @@ -1469,6 +1469,10 @@ def MediumText() -> Variant: # pylint:disable=invalid-name
return Text().with_variant(MEDIUMTEXT(), "mysql")


def LongText() -> Variant: # pylint:disable=invalid-name
return Text().with_variant(LONGTEXT(), "mysql")


def shortid() -> str:
return f"{uuid.uuid4()}"[-12:]

Expand Down
16 changes: 9 additions & 7 deletions superset/utils/pandas_postprocessing/contribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def contribution(
if len(rename_columns) != len(actual_columns):
raise InvalidPostProcessingError(
_(
"`rename_columns` must have the same length as `columns` + `time_shift_columns`."
"`rename_columns` must have the same length as "
+ "`columns` + `time_shift_columns`."
)
)
# limit to selected columns
Expand Down Expand Up @@ -105,10 +106,10 @@ def get_column_groups(
:param df: DataFrame to group columns from
:param time_shifts: List of time shifts to group by
:param rename_columns: List of new column names
:return: Dictionary with two keys: 'non_time_shift' and 'time_shifts'. 'non_time_shift'
maps to a tuple of original and renamed columns without a time shift. 'time_shifts' maps
to a dictionary where each key is a time shift and each value is a tuple of original and
renamed columns with that time shift.
:return: Dictionary with two keys: 'non_time_shift' and 'time_shifts'.
'non_time_shift' maps to a tuple of original and renamed columns without a time shift.
'time_shifts' maps to a dictionary where each key is a time shift and each value is a
tuple of original and renamed columns with that time shift.
"""
result: dict[str, Any] = {
"non_time_shift": ([], []), # take the form of ([A, B, C], [X, Y, Z])
Expand Down Expand Up @@ -139,8 +140,9 @@ def calculate_row_contribution(
"""
Calculate the contribution of each column to the row total and update the DataFrame.
This function calculates the contribution of each selected column to the total of the row,
and updates the DataFrame with these contribution percentages in place of the original values.
This function calculates the contribution of each selected column to the total of
the row, and updates the DataFrame with these contribution percentages in place of
the original values.
:param df: The DataFrame to calculate contributions for.
:param columns: A list of column names to calculate contributions for.
Expand Down

0 comments on commit f855a5e

Please sign in to comment.