diff --git a/superset-frontend/src/explore/controlPanels/BigNumber.js b/superset-frontend/src/explore/controlPanels/BigNumber.js
index ab01866ce10eb..4a708f92b8078 100644
--- a/superset-frontend/src/explore/controlPanels/BigNumber.js
+++ b/superset-frontend/src/explore/controlPanels/BigNumber.js
@@ -17,6 +17,7 @@
* under the License.
*/
import { t } from '@superset-ui/translation';
+import React from 'react';
export default {
controlPanelSections: [
@@ -43,6 +44,14 @@ export default {
['subheader_font_size'],
],
},
+ {
+ label: t('Advanced Analytics'),
+ expanded: false,
+ controlSetRows: [
+ [
{t('Rolling Window')}
],
+ ['rolling_type', 'rolling_periods', 'min_periods'],
+ ],
+ },
],
controlOverrides: {
y_axis_format: {
diff --git a/superset-frontend/src/explore/controlPanels/sections.jsx b/superset-frontend/src/explore/controlPanels/sections.jsx
index 148b6a92c3d2f..ef63dcae7ed2b 100644
--- a/superset-frontend/src/explore/controlPanels/sections.jsx
+++ b/superset-frontend/src/explore/controlPanels/sections.jsx
@@ -75,7 +75,7 @@ export const NVD3TimeSeries = [
'of query results',
),
controlSetRows: [
- [<h1 className="section-header">{t('Moving Average')}</h1>],
+ [<h1 className="section-header">{t('Rolling Window')}</h1>],
['rolling_type', 'rolling_periods', 'min_periods'],
[<h1 className="section-header">{t('Time Comparison')}</h1>],
['time_compare', 'comparison_type'],
diff --git a/superset-frontend/src/explore/controls.jsx b/superset-frontend/src/explore/controls.jsx
index b477e70b1e9c5..a553a88300a25 100644
--- a/superset-frontend/src/explore/controls.jsx
+++ b/superset-frontend/src/explore/controls.jsx
@@ -1126,7 +1126,7 @@ export const controls = {
rolling_type: {
type: 'SelectControl',
- label: t('Rolling'),
+ label: t('Rolling Function'),
default: 'None',
choices: formatSelectOptions(['None', 'mean', 'sum', 'std', 'cumsum']),
description: t(
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index 674a6bb5d5f92..e20c235dd026e 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -106,22 +106,23 @@ def load_birth_names(only_metadata=False, force=False):
obj.fetch_metadata()
tbl = obj
+ metrics = [
+ {
+ "expressionType": "SIMPLE",
+ "column": {"column_name": "num", "type": "BIGINT"},
+ "aggregate": "SUM",
+ "label": "Births",
+ "optionName": "metric_11",
+ }
+ ]
+ metric = "sum__num"
+
defaults = {
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity_sqla": "ds",
"groupby": [],
- "metric": "sum__num",
- "metrics": [
- {
- "expressionType": "SIMPLE",
- "column": {"column_name": "num", "type": "BIGINT"},
- "aggregate": "SUM",
- "label": "Births",
- "optionName": "metric_11",
- }
- ],
"row_limit": config["ROW_LIMIT"],
"since": "100 years ago",
"until": "now",
@@ -144,6 +145,7 @@ def load_birth_names(only_metadata=False, force=False):
granularity_sqla="ds",
compare_lag="5",
compare_suffix="over 5Y",
+ metric=metric,
),
),
Slice(
@@ -151,7 +153,9 @@ def load_birth_names(only_metadata=False, force=False):
viz_type="pie",
datasource_type="table",
datasource_id=tbl.id,
- params=get_slice_json(defaults, viz_type="pie", groupby=["gender"]),
+ params=get_slice_json(
+ defaults, viz_type="pie", groupby=["gender"], metric=metric
+ ),
),
Slice(
slice_name="Trends",
@@ -165,6 +169,7 @@ def load_birth_names(only_metadata=False, force=False):
granularity_sqla="ds",
rich_tooltip=True,
show_legend=True,
+ metrics=metrics,
),
),
Slice(
@@ -215,6 +220,7 @@ def load_birth_names(only_metadata=False, force=False):
adhoc_filters=[gen_filter("gender", "girl")],
row_limit=50,
timeseries_limit_metric="sum__num",
+ metrics=metrics,
),
),
Slice(
@@ -231,6 +237,7 @@ def load_birth_names(only_metadata=False, force=False):
rotation="square",
limit="100",
adhoc_filters=[gen_filter("gender", "girl")],
+ metric=metric,
),
),
Slice(
@@ -243,6 +250,7 @@ def load_birth_names(only_metadata=False, force=False):
groupby=["name"],
adhoc_filters=[gen_filter("gender", "boy")],
row_limit=50,
+ metrics=metrics,
),
),
Slice(
@@ -259,6 +267,7 @@ def load_birth_names(only_metadata=False, force=False):
rotation="square",
limit="100",
adhoc_filters=[gen_filter("gender", "boy")],
+ metric=metric,
),
),
Slice(
@@ -276,6 +285,7 @@ def load_birth_names(only_metadata=False, force=False):
time_grain_sqla="P1D",
viz_type="area",
x_axis_forma="smart_date",
+ metrics=metrics,
),
),
Slice(
@@ -293,6 +303,7 @@ def load_birth_names(only_metadata=False, force=False):
time_grain_sqla="P1D",
viz_type="area",
x_axis_forma="smart_date",
+ metrics=metrics,
),
),
]
@@ -314,6 +325,7 @@ def load_birth_names(only_metadata=False, force=False):
},
metric_2="sum__num",
granularity_sqla="ds",
+ metrics=metrics,
),
),
Slice(
@@ -321,7 +333,7 @@ def load_birth_names(only_metadata=False, force=False):
viz_type="line",
datasource_type="table",
datasource_id=tbl.id,
- params=get_slice_json(defaults, viz_type="line"),
+ params=get_slice_json(defaults, viz_type="line", metrics=metrics),
),
Slice(
slice_name="Daily Totals",
@@ -335,6 +347,7 @@ def load_birth_names(only_metadata=False, force=False):
since="40 years ago",
until="now",
viz_type="table",
+ metrics=metrics,
),
),
Slice(
@@ -397,6 +410,7 @@ def load_birth_names(only_metadata=False, force=False):
datasource_id=tbl.id,
params=get_slice_json(
defaults,
+ metrics=metrics,
groupby=["name"],
row_limit=50,
timeseries_limit_metric={
@@ -417,6 +431,7 @@ def load_birth_names(only_metadata=False, force=False):
datasource_id=tbl.id,
params=get_slice_json(
defaults,
+ metric=metric,
viz_type="big_number_total",
granularity_sqla="ds",
adhoc_filters=[gen_filter("gender", "girl")],
@@ -429,7 +444,11 @@ def load_birth_names(only_metadata=False, force=False):
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
- defaults, viz_type="pivot_table", groupby=["name"], columns=["state"]
+ defaults,
+ viz_type="pivot_table",
+ groupby=["name"],
+ columns=["state"],
+ metrics=metrics,
),
),
]
diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py
index 695a80a0bedaa..b30d07fe0f320 100644
--- a/superset/examples/world_bank.py
+++ b/superset/examples/world_bank.py
@@ -97,31 +97,32 @@ def load_world_bank_health_n_pop(
db.session.commit()
tbl.fetch_metadata()
+ metric = "sum__SP_POP_TOTL"
+ metrics = ["sum__SP_POP_TOTL"]
+ secondary_metric = {
+ "aggregate": "SUM",
+ "column": {
+ "column_name": "SP_RUR_TOTL",
+ "optionName": "_col_SP_RUR_TOTL",
+ "type": "DOUBLE",
+ },
+ "expressionType": "SIMPLE",
+ "hasCustomLabel": True,
+ "label": "Rural Population",
+ }
+
defaults = {
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity_sqla": "year",
"groupby": [],
- "metric": "sum__SP_POP_TOTL",
- "metrics": ["sum__SP_POP_TOTL"],
"row_limit": config["ROW_LIMIT"],
"since": "2014-01-01",
"until": "2014-01-02",
"time_range": "2014-01-01 : 2014-01-02",
"markup_type": "markdown",
"country_fieldtype": "cca3",
- "secondary_metric": {
- "aggregate": "SUM",
- "column": {
- "column_name": "SP_RUR_TOTL",
- "optionName": "_col_SP_RUR_TOTL",
- "type": "DOUBLE",
- },
- "expressionType": "SIMPLE",
- "hasCustomLabel": True,
- "label": "Rural Population",
- },
"entity": "country_code",
"show_bubbles": True,
}
@@ -207,6 +208,7 @@ def load_world_bank_health_n_pop(
viz_type="world_map",
metric="sum__SP_RUR_TOTL_ZS",
num_period_compare="10",
+ secondary_metric=secondary_metric,
),
),
Slice(
@@ -264,6 +266,8 @@ def load_world_bank_health_n_pop(
groupby=["region", "country_name"],
since="2011-01-01",
until="2011-01-01",
+ metric=metric,
+ secondary_metric=secondary_metric,
),
),
Slice(
@@ -277,6 +281,7 @@ def load_world_bank_health_n_pop(
until="now",
viz_type="area",
groupby=["region"],
+ metrics=metrics,
),
),
Slice(
@@ -292,6 +297,7 @@ def load_world_bank_health_n_pop(
x_ticks_layout="staggered",
viz_type="box_plot",
groupby=["region"],
+ metrics=metrics,
),
),
Slice(
diff --git a/superset/viz.py b/superset/viz.py
index 80e6d8a6f8332..433fede19d5c4 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -178,6 +178,26 @@ def run_extra_queries(self):
"""
pass
+ def apply_rolling(self, df):
+ fd = self.form_data
+ rolling_type = fd.get("rolling_type")
+ rolling_periods = int(fd.get("rolling_periods") or 0)
+ min_periods = int(fd.get("min_periods") or 0)
+
+ if rolling_type in ("mean", "std", "sum") and rolling_periods:
+ kwargs = dict(window=rolling_periods, min_periods=min_periods)
+ if rolling_type == "mean":
+ df = df.rolling(**kwargs).mean()
+ elif rolling_type == "std":
+ df = df.rolling(**kwargs).std()
+ elif rolling_type == "sum":
+ df = df.rolling(**kwargs).sum()
+ elif rolling_type == "cumsum":
+ df = df.cumsum()
+ if min_periods:
+ df = df[min_periods:]
+ return df
+
def get_samples(self):
query_obj = self.query_obj()
query_obj.update(
@@ -1101,6 +1121,18 @@ def query_obj(self):
self.form_data["metric"] = metric
return d
+ def get_data(self, df: pd.DataFrame) -> VizData:
+ df = df.pivot_table(
+ index=DTTM_ALIAS,
+ columns=[],
+ values=self.metric_labels,
+ fill_value=0,
+ aggfunc=sum,
+ )
+ df = self.apply_rolling(df)
+ df[DTTM_ALIAS] = df.index
+ return super().get_data(df)
+
class BigNumberTotalViz(BaseViz):
@@ -1225,23 +1257,7 @@ def process_data(self, df: pd.DataFrame, aggregate: bool = False) -> VizData:
dfs.sort_values(ascending=False, inplace=True)
df = df[dfs.index]
- rolling_type = fd.get("rolling_type")
- rolling_periods = int(fd.get("rolling_periods") or 0)
- min_periods = int(fd.get("min_periods") or 0)
-
- if rolling_type in ("mean", "std", "sum") and rolling_periods:
- kwargs = dict(window=rolling_periods, min_periods=min_periods)
- if rolling_type == "mean":
- df = df.rolling(**kwargs).mean()
- elif rolling_type == "std":
- df = df.rolling(**kwargs).std()
- elif rolling_type == "sum":
- df = df.rolling(**kwargs).sum()
- elif rolling_type == "cumsum":
- df = df.cumsum()
- if min_periods:
- df = df[min_periods:]
-
+ df = self.apply_rolling(df)
if fd.get("contribution"):
dft = df.T
df = (dft / dft.sum()).T
diff --git a/tests/viz_tests.py b/tests/viz_tests.py
index 6f23c6a610326..ec318e606f695 100644
--- a/tests/viz_tests.py
+++ b/tests/viz_tests.py
@@ -1192,3 +1192,54 @@ def test_process_data_resample(self):
.tolist(),
[1.0, 2.0, np.nan, np.nan, 5.0, np.nan, 7.0],
)
+
+ def test_apply_rolling(self):
+ datasource = self.get_datasource_mock()
+ df = pd.DataFrame(
+ index=pd.to_datetime(
+ ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]
+ ),
+ data={"y": [1.0, 2.0, 3.0, 4.0]},
+ )
+ self.assertEqual(
+ viz.BigNumberViz(
+ datasource,
+ {
+ "metrics": ["y"],
+ "rolling_type": "cumsum",
+ "rolling_periods": 0,
+ "min_periods": 0,
+ },
+ )
+ .apply_rolling(df)["y"]
+ .tolist(),
+ [1.0, 3.0, 6.0, 10.0],
+ )
+ self.assertEqual(
+ viz.BigNumberViz(
+ datasource,
+ {
+ "metrics": ["y"],
+ "rolling_type": "sum",
+ "rolling_periods": 2,
+ "min_periods": 0,
+ },
+ )
+ .apply_rolling(df)["y"]
+ .tolist(),
+ [1.0, 3.0, 5.0, 7.0],
+ )
+ self.assertEqual(
+ viz.BigNumberViz(
+ datasource,
+ {
+ "metrics": ["y"],
+ "rolling_type": "mean",
+ "rolling_periods": 10,
+ "min_periods": 0,
+ },
+ )
+ .apply_rolling(df)["y"]
+ .tolist(),
+ [1.0, 1.5, 2.0, 2.5],
+ )