diff --git a/nvtabular/ops/groupby.py b/nvtabular/ops/groupby.py index da343a5608..1b6084b15e 100644 --- a/nvtabular/ops/groupby.py +++ b/nvtabular/ops/groupby.py @@ -20,7 +20,7 @@ from merlin.core.dispatch import DataFrameType, annotate from merlin.dtypes.shape import DefaultShapes from merlin.schema import Schema -from nvtabular.ops.operator import ColumnSelector, Operator +from nvtabular.ops.operator import ColumnSelector, DataFormats, Operator class Groupby(Operator): @@ -109,6 +109,7 @@ def __init__( self.list_aggs[col] = list(_list_aggs) self.name_sep = name_sep + self.supported_formats = DataFormats.PANDAS_DATAFRAME | DataFormats.CUDF_DATAFRAME super().__init__() @annotate("Groupby_op", color="darkgreen", domain="nvt_python") diff --git a/nvtabular/ops/operator.py b/nvtabular/ops/operator.py index 0757557b12..41b3621643 100644 --- a/nvtabular/ops/operator.py +++ b/nvtabular/ops/operator.py @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from merlin.dag import BaseOperator, ColumnSelector # noqa pylint: disable=unused-import +from merlin.dag import ( # noqa pylint: disable=unused-import + BaseOperator, + ColumnSelector, + DataFormats, +) Operator = BaseOperator