diff --git a/superset-frontend/src/explore/controlPanels/BigNumber.js b/superset-frontend/src/explore/controlPanels/BigNumber.js index ab01866ce10eb..4a708f92b8078 100644 --- a/superset-frontend/src/explore/controlPanels/BigNumber.js +++ b/superset-frontend/src/explore/controlPanels/BigNumber.js @@ -17,6 +17,7 @@ * under the License. */ import { t } from '@superset-ui/translation'; +import React from 'react'; export default { controlPanelSections: [ @@ -43,6 +44,14 @@ export default { ['subheader_font_size'], ], }, + { + label: t('Advanced Analytics'), + expanded: false, + controlSetRows: [ + [

{t('Rolling Window')}

], + ['rolling_type', 'rolling_periods', 'min_periods'], + ], + }, ], controlOverrides: { y_axis_format: { diff --git a/superset-frontend/src/explore/controlPanels/sections.jsx b/superset-frontend/src/explore/controlPanels/sections.jsx index 148b6a92c3d2f..ef63dcae7ed2b 100644 --- a/superset-frontend/src/explore/controlPanels/sections.jsx +++ b/superset-frontend/src/explore/controlPanels/sections.jsx @@ -75,7 +75,7 @@ export const NVD3TimeSeries = [ 'of query results', ), controlSetRows: [ - [

{t('Moving Average')}

], + [

{t('Rolling Window')}

], ['rolling_type', 'rolling_periods', 'min_periods'], [

{t('Time Comparison')}

], ['time_compare', 'comparison_type'], diff --git a/superset-frontend/src/explore/controls.jsx b/superset-frontend/src/explore/controls.jsx index b477e70b1e9c5..a553a88300a25 100644 --- a/superset-frontend/src/explore/controls.jsx +++ b/superset-frontend/src/explore/controls.jsx @@ -1126,7 +1126,7 @@ export const controls = { rolling_type: { type: 'SelectControl', - label: t('Rolling'), + label: t('Rolling Function'), default: 'None', choices: formatSelectOptions(['None', 'mean', 'sum', 'std', 'cumsum']), description: t( diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 674a6bb5d5f92..e20c235dd026e 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -106,22 +106,23 @@ def load_birth_names(only_metadata=False, force=False): obj.fetch_metadata() tbl = obj + metrics = [ + { + "expressionType": "SIMPLE", + "column": {"column_name": "num", "type": "BIGINT"}, + "aggregate": "SUM", + "label": "Births", + "optionName": "metric_11", + } + ] + metric = "sum__num" + defaults = { "compare_lag": "10", "compare_suffix": "o10Y", "limit": "25", "granularity_sqla": "ds", "groupby": [], - "metric": "sum__num", - "metrics": [ - { - "expressionType": "SIMPLE", - "column": {"column_name": "num", "type": "BIGINT"}, - "aggregate": "SUM", - "label": "Births", - "optionName": "metric_11", - } - ], "row_limit": config["ROW_LIMIT"], "since": "100 years ago", "until": "now", @@ -144,6 +145,7 @@ def load_birth_names(only_metadata=False, force=False): granularity_sqla="ds", compare_lag="5", compare_suffix="over 5Y", + metric=metric, ), ), Slice( @@ -151,7 +153,9 @@ def load_birth_names(only_metadata=False, force=False): viz_type="pie", datasource_type="table", datasource_id=tbl.id, - params=get_slice_json(defaults, viz_type="pie", groupby=["gender"]), + params=get_slice_json( + defaults, viz_type="pie", groupby=["gender"], metric=metric + ), ), Slice( slice_name="Trends", @@ -165,6 +169,7 @@ def load_birth_names(only_metadata=False, force=False): granularity_sqla="ds", rich_tooltip=True, show_legend=True, + metrics=metrics, ), ), Slice( @@ -215,6 +220,7 @@ def load_birth_names(only_metadata=False, force=False): adhoc_filters=[gen_filter("gender", "girl")], row_limit=50, timeseries_limit_metric="sum__num", + metrics=metrics, ), ), Slice( @@ -231,6 +237,7 @@ def load_birth_names(only_metadata=False, force=False): rotation="square", limit="100", adhoc_filters=[gen_filter("gender", "girl")], + metric=metric, ), ), Slice( @@ -243,6 +250,7 @@ def load_birth_names(only_metadata=False, force=False): groupby=["name"], adhoc_filters=[gen_filter("gender", "boy")], row_limit=50, + metrics=metrics, ), ), Slice( @@ -259,6 +267,7 @@ def load_birth_names(only_metadata=False, force=False): rotation="square", limit="100", adhoc_filters=[gen_filter("gender", "boy")], + metric=metric, ), ), Slice( @@ -276,6 +285,7 @@ def load_birth_names(only_metadata=False, force=False): time_grain_sqla="P1D", viz_type="area", x_axis_forma="smart_date", + metrics=metrics, ), ), Slice( @@ -293,6 +303,7 @@ def load_birth_names(only_metadata=False, force=False): time_grain_sqla="P1D", viz_type="area", x_axis_forma="smart_date", + metrics=metrics, ), ), ] @@ -314,6 +325,7 @@ def load_birth_names(only_metadata=False, force=False): }, metric_2="sum__num", granularity_sqla="ds", + metrics=metrics, ), ), Slice( @@ -321,7 +333,7 @@ def load_birth_names(only_metadata=False, force=False): viz_type="line", datasource_type="table", datasource_id=tbl.id, - params=get_slice_json(defaults, viz_type="line"), + params=get_slice_json(defaults, viz_type="line", metrics=metrics), ), Slice( slice_name="Daily Totals", @@ -335,6 +347,7 @@ def load_birth_names(only_metadata=False, force=False): since="40 years ago", until="now", viz_type="table", + metrics=metrics, ), ), Slice( @@ -397,6 +410,7 @@ def load_birth_names(only_metadata=False, force=False): datasource_id=tbl.id, params=get_slice_json( defaults, + metrics=metrics, groupby=["name"], row_limit=50, timeseries_limit_metric={ @@ -417,6 +431,7 @@ def load_birth_names(only_metadata=False, force=False): datasource_id=tbl.id, params=get_slice_json( defaults, + metric=metric, viz_type="big_number_total", granularity_sqla="ds", adhoc_filters=[gen_filter("gender", "girl")], @@ -429,7 +444,11 @@ def load_birth_names(only_metadata=False, force=False): datasource_type="table", datasource_id=tbl.id, params=get_slice_json( - defaults, viz_type="pivot_table", groupby=["name"], columns=["state"] + defaults, + viz_type="pivot_table", + groupby=["name"], + columns=["state"], + metrics=metrics, ), ), ] diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 695a80a0bedaa..b30d07fe0f320 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -97,31 +97,32 @@ def load_world_bank_health_n_pop( db.session.commit() tbl.fetch_metadata() + metric = "sum__SP_POP_TOTL" + metrics = ["sum__SP_POP_TOTL"] + secondary_metric = { + "aggregate": "SUM", + "column": { + "column_name": "SP_RUR_TOTL", + "optionName": "_col_SP_RUR_TOTL", + "type": "DOUBLE", + }, + "expressionType": "SIMPLE", + "hasCustomLabel": True, + "label": "Rural Population", + } + defaults = { "compare_lag": "10", "compare_suffix": "o10Y", "limit": "25", "granularity_sqla": "year", "groupby": [], - "metric": "sum__SP_POP_TOTL", - "metrics": ["sum__SP_POP_TOTL"], "row_limit": config["ROW_LIMIT"], "since": "2014-01-01", "until": "2014-01-02", "time_range": "2014-01-01 : 2014-01-02", "markup_type": "markdown", "country_fieldtype": "cca3", - "secondary_metric": { - "aggregate": "SUM", - "column": { - "column_name": "SP_RUR_TOTL", - "optionName": "_col_SP_RUR_TOTL", - "type": "DOUBLE", - }, - "expressionType": "SIMPLE", - "hasCustomLabel": True, - "label": "Rural Population", - }, "entity": "country_code", "show_bubbles": True, } @@ -207,6 +208,7 @@ def load_world_bank_health_n_pop( viz_type="world_map", metric="sum__SP_RUR_TOTL_ZS", num_period_compare="10", + secondary_metric=secondary_metric, ), ), Slice( @@ -264,6 +266,8 @@ def load_world_bank_health_n_pop( groupby=["region", "country_name"], since="2011-01-01", until="2011-01-01", + metric=metric, + secondary_metric=secondary_metric, ), ), Slice( @@ -277,6 +281,7 @@ def load_world_bank_health_n_pop( until="now", viz_type="area", groupby=["region"], + metrics=metrics, ), ), Slice( @@ -292,6 +297,7 @@ def load_world_bank_health_n_pop( x_ticks_layout="staggered", viz_type="box_plot", groupby=["region"], + metrics=metrics, ), ), Slice( diff --git a/superset/viz.py b/superset/viz.py index 80e6d8a6f8332..433fede19d5c4 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -178,6 +178,26 @@ def run_extra_queries(self): """ pass + def apply_rolling(self, df): + fd = self.form_data + rolling_type = fd.get("rolling_type") + rolling_periods = int(fd.get("rolling_periods") or 0) + min_periods = int(fd.get("min_periods") or 0) + + if rolling_type in ("mean", "std", "sum") and rolling_periods: + kwargs = dict(window=rolling_periods, min_periods=min_periods) + if rolling_type == "mean": + df = df.rolling(**kwargs).mean() + elif rolling_type == "std": + df = df.rolling(**kwargs).std() + elif rolling_type == "sum": + df = df.rolling(**kwargs).sum() + elif rolling_type == "cumsum": + df = df.cumsum() + if min_periods: + df = df[min_periods:] + return df + def get_samples(self): query_obj = self.query_obj() query_obj.update( @@ -1101,6 +1121,18 @@ def query_obj(self): self.form_data["metric"] = metric return d + def get_data(self, df: pd.DataFrame) -> VizData: + df = df.pivot_table( + index=DTTM_ALIAS, + columns=[], + values=self.metric_labels, + fill_value=0, + aggfunc=sum, + ) + df = self.apply_rolling(df) + df[DTTM_ALIAS] = df.index + return super().get_data(df) + class BigNumberTotalViz(BaseViz): @@ -1225,23 +1257,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData: dfs.sort_values(ascending=False, inplace=True) df = df[dfs.index] - rolling_type = fd.get("rolling_type") - rolling_periods = int(fd.get("rolling_periods") or 0) - min_periods = int(fd.get("min_periods") or 0) - - if rolling_type in ("mean", "std", "sum") and rolling_periods: - kwargs = dict(window=rolling_periods, min_periods=min_periods) - if rolling_type == "mean": - df = df.rolling(**kwargs).mean() - elif rolling_type == "std": - df = df.rolling(**kwargs).std() - elif rolling_type == "sum": - df = df.rolling(**kwargs).sum() - elif rolling_type == "cumsum": - df = df.cumsum() - if min_periods: - df = df[min_periods:] - + df = self.apply_rolling(df) if fd.get("contribution"): dft = df.T df = (dft / dft.sum()).T diff --git a/tests/viz_tests.py b/tests/viz_tests.py index 6f23c6a610326..ec318e606f695 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -1192,3 +1192,54 @@ def test_process_data_resample(self): .tolist(), [1.0, 2.0, np.nan, np.nan, 5.0, np.nan, 7.0], ) + + def test_apply_rolling(self): + datasource = self.get_datasource_mock() + df = pd.DataFrame( + index=pd.to_datetime( + ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"] + ), + data={"y": [1.0, 2.0, 3.0, 4.0]}, + ) + self.assertEqual( + viz.BigNumberViz( + datasource, + { + "metrics": ["y"], + "rolling_type": "cumsum", + "rolling_periods": 0, + "min_periods": 0, + }, + ) + .apply_rolling(df)["y"] + .tolist(), + [1.0, 3.0, 6.0, 10.0], + ) + self.assertEqual( + viz.BigNumberViz( + datasource, + { + "metrics": ["y"], + "rolling_type": "sum", + "rolling_periods": 2, + "min_periods": 0, + }, + ) + .apply_rolling(df)["y"] + .tolist(), + [1.0, 3.0, 5.0, 7.0], + ) + self.assertEqual( + viz.BigNumberViz( + datasource, + { + "metrics": ["y"], + "rolling_type": "mean", + "rolling_periods": 10, + "min_periods": 0, + }, + ) + .apply_rolling(df)["y"] + .tolist(), + [1.0, 1.5, 2.0, 2.5], + )