Skip to content

Commit

Permalink
Update AggregateRel.CopyWithExpressionRewrite to rewrite measure func…
Browse files Browse the repository at this point in the history
…tions in addition to filters (#112)
  • Loading branch information
EpsilonPrime authored Jan 29, 2025
1 parent d6e63d9 commit 60cc74f
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 15 deletions.
22 changes: 22 additions & 0 deletions expr/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,28 @@ func (a *AggregateFunction) IntermediateType() (types.FuncDefArgType, error) {
return a.declaration.Intermediate()
}

// Clone returns a shallow copy of a, or nil when a is nil.
//
// The struct is copied by value and the args, options and Sorts slices are
// duplicated into fresh backing arrays, so replacing an element of one of
// those slices on the copy (for example through SetArg) leaves a untouched.
// The elements themselves are copied by assignment, so anything they point to
// is still shared between the original and the copy. Nil slices stay nil.
func (a *AggregateFunction) Clone() *AggregateFunction {
	// Guard the dereference below so that cloning a nil function yields nil
	// instead of panicking.
	if a == nil {
		return nil
	}
	newA := *a
	if a.args != nil {
		newA.args = make([]types.FuncArg, len(a.args))
		copy(newA.args, a.args)
	}
	if a.options != nil {
		newA.options = make([]*types.FunctionOption, len(a.options))
		copy(newA.options, a.options)
	}
	if a.Sorts != nil {
		newA.Sorts = make([]SortField, len(a.Sorts))
		copy(newA.Sorts, a.Sorts)
	}
	return &newA
}

// SetArg replaces the argument at index i with arg, modifying the function in place.
// The index is not checked for validity: an out-of-range i causes a runtime panic.
func (a *AggregateFunction) SetArg(i int, arg types.FuncArg) {
a.args[i] = arg
}

func (a *AggregateFunction) String() string {
var b strings.Builder

Expand Down
6 changes: 6 additions & 0 deletions expr/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ type ExtensionRegistry struct {
c *extensions.Collection
}

// NewExtensionRegistry builds a registry that pairs the extension set extSet
// with the collection c. For an existing plan, GetExtensionSet() can supply
// the extensions.Set to pass in.
//
// It panics if c is nil.
func NewExtensionRegistry(extSet extensions.Set, c *extensions.Collection) ExtensionRegistry {
	if c == nil {
		panic("cannot create registry with nil collection")
	}
	var registry ExtensionRegistry
	registry.Set = extSet
	registry.c = c
	return registry
}

// NewEmptyExtensionRegistry creates a registry backed by a new, empty extension set and the
// collection c, which is useful when starting from scratch. It delegates to NewExtensionRegistry
// and therefore panics if c is nil.
func NewEmptyExtensionRegistry(c *extensions.Collection) ExtensionRegistry {
return NewExtensionRegistry(extensions.NewSet(), c)
}
Expand All @@ -28,14 +31,17 @@ func (e *ExtensionRegistry) LookupType(anchor uint32) (extensions.Type, bool) {
return e.Set.LookupType(anchor, e.c)
}

// LookupScalarFunction returns the ScalarFunctionVariant associated with a previously used function's anchor.
// The lookup is delegated to the registry's extension Set, passing along the registry's collection.
// NOTE(review): the boolean result presumably reports whether the anchor was found — confirm against extensions.Set.
func (e *ExtensionRegistry) LookupScalarFunction(anchor uint32) (*extensions.ScalarFunctionVariant, bool) {
return e.Set.LookupScalarFunction(anchor, e.c)
}

// LookupAggregateFunction returns the AggregateFunctionVariant associated with a previously used function's anchor.
// The lookup is delegated to the registry's extension Set, passing along the registry's collection.
// NOTE(review): the boolean result presumably reports whether the anchor was found — confirm against extensions.Set.
func (e *ExtensionRegistry) LookupAggregateFunction(anchor uint32) (*extensions.AggregateFunctionVariant, bool) {
return e.Set.LookupAggregateFunction(anchor, e.c)
}

// LookupWindowFunction returns the WindowFunctionVariant associated with a previously used function's anchor.
// The lookup is delegated to the registry's extension Set, passing along the registry's collection.
// NOTE(review): the boolean result presumably reports whether the anchor was found — confirm against extensions.Set.
func (e *ExtensionRegistry) LookupWindowFunction(anchor uint32) (*extensions.WindowFunctionVariant, bool) {
return e.Set.LookupWindowFunction(anchor, e.c)
}
30 changes: 28 additions & 2 deletions plan/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,30 @@ func (ar *AggregateRel) Copy(newInputs ...Rel) (Rel, error) {
return &aggregate, nil
}

// rewriteAggregateFunc applies rewriteFunc to every argument of f that is an
// expr.Expression and returns the resulting aggregate function. Arguments of
// any other kind are left as they are.
//
// A nil f is returned unchanged. When rewriteFunc hands back each expression
// it was given, the original f is returned, so callers can detect "nothing
// changed" by comparing pointers. Otherwise a clone of f carrying the
// rewritten arguments is returned. The first error reported by rewriteFunc
// aborts the rewrite and is returned as is.
func (ar *AggregateRel) rewriteAggregateFunc(rewriteFunc RewriteFunc, f *expr.AggregateFunction) (*expr.AggregateFunction, error) {
	if f == nil {
		return f, nil
	}
	// newF stays nil until the first argument actually changes; this avoids
	// allocating a clone that would only be thrown away in the no-op case.
	var newF *expr.AggregateFunction
	for i := 0; i < f.NArgs(); i++ {
		exp, ok := f.Arg(i).(expr.Expression)
		if !ok {
			continue
		}
		newExp, err := rewriteFunc(exp)
		if err != nil {
			return nil, err
		}
		if newF == nil {
			if newExp == exp {
				continue
			}
			// First difference found: start from a copy that already holds
			// all of f's current arguments.
			newF = f.Clone()
		}
		newF.SetArg(i, newExp)
	}
	if newF == nil {
		return f, nil
	}
	return newF, nil
}

func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs ...Rel) (Rel, error) {
if len(newInputs) != 1 {
return nil, substraitgo.ErrInvalidInputCount
Expand All @@ -1187,8 +1211,10 @@ func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn
if newMeasures[i].filter, err = rewriteFunc(m.filter); err != nil {
return nil, err
}
measuresAreEqual = measuresAreEqual && newMeasures[i].filter == m.filter
newMeasures[i].measure = m.measure
if newMeasures[i].measure, err = ar.rewriteAggregateFunc(rewriteFunc, m.measure); err != nil {
return nil, err
}
measuresAreEqual = measuresAreEqual && newMeasures[i].filter == m.filter && newMeasures[i].measure == m.measure
}
if groupsAreEqual && measuresAreEqual && newInputs[0] == ar.input {
return ar, nil
Expand Down
52 changes: 39 additions & 13 deletions plan/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/v3/expr"
"github.com/substrait-io/substrait-go/v3/extensions"
"github.com/substrait-io/substrait-go/v3/proto"
"github.com/substrait-io/substrait-go/v3/types"
)
Expand All @@ -26,8 +28,24 @@ func createPrimitiveBool(value bool) expr.Expression {
}

func TestRelations_Copy(t *testing.T) {
aggregateRel := &AggregateRel{input: createVirtualTableReadRel(1), groupingExpressions: []expr.Expression{createPrimitiveFloat(1.0)}, groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{filter: expr.NewPrimitiveLiteral(false, false)}}}
extReg := expr.NewExtensionRegistry(extensions.NewSet(), &extensions.DefaultCollection)
aggregateFnID := extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
Name: "avg",
}
aggregateFn, err := expr.NewAggregateFunc(extReg,
aggregateFnID, nil, types.AggInvocationAll,
types.AggPhaseInitialToResult, nil, createPrimitiveFloat(1.0))
require.NoError(t, err)
aggregateFnRevised, err := expr.NewAggregateFunc(extReg,
aggregateFnID, nil, types.AggInvocationAll,
types.AggPhaseInitialToResult, nil, createPrimitiveFloat(9.0))
require.NoError(t, err)

aggregateRel := &AggregateRel{input: createVirtualTableReadRel(1),
groupingExpressions: []expr.Expression{createPrimitiveFloat(1.0)},
groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{measure: aggregateFn, filter: expr.NewPrimitiveLiteral(false, false)}}}
crossRel := &CrossRel{left: createVirtualTableReadRel(1), right: createVirtualTableReadRel(2)}
extensionLeafRel := &ExtensionLeafRel{}
extensionMultiRel := &ExtensionMultiRel{inputs: []Rel{createVirtualTableReadRel(1), createVirtualTableReadRel(2)}}
Expand Down Expand Up @@ -60,10 +78,13 @@ func TestRelations_Copy(t *testing.T) {
}
testCases := []relationTestCase{
{
name: "AggregateRel Copy with new inputs",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(6)},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(6), groupingReferences: aggregateRel.groupingReferences, groupingExpressions: aggregateRel.groupingExpressions, measures: aggregateRel.measures},
name: "AggregateRel Copy with new inputs",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(6)},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(6),
groupingReferences: aggregateRel.groupingReferences,
groupingExpressions: aggregateRel.groupingExpressions,
measures: aggregateRel.measures},
},
{
name: "AggregateRel Copy with same inputs and noOpRewrite",
Expand All @@ -73,13 +94,16 @@ func TestRelations_Copy(t *testing.T) {
expectedSameRel: true,
},
{
name: "AggregateRel Copy with new Inputs and noOpReWrite",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(7)},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(7), groupingExpressions: aggregateRel.groupingExpressions, groupingReferences: aggregateRel.groupingReferences, measures: aggregateRel.measures},
name: "AggregateRel Copy with new Inputs and noOpReWrite",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(7)},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(7),
groupingExpressions: aggregateRel.groupingExpressions,
groupingReferences: aggregateRel.groupingReferences,
measures: aggregateRel.measures},
},
{
name: "AggregateRel Copy with new Inputs and reWriteFunc",
name: "AggregateRel Copy with new Inputs and rewriteFunc",
relation: aggregateRel,
newInputs: []Rel{createVirtualTableReadRel(8)},
rewriteFunc: func(expression expr.Expression) (expr.Expression, error) {
Expand All @@ -91,8 +115,10 @@ func TestRelations_Copy(t *testing.T) {
}
panic("unexpected expression type")
},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(8), groupingExpressions: []expr.Expression{createPrimitiveFloat(9.0)}, groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{filter: expr.NewPrimitiveLiteral(true, false)}}},
expectedRel: &AggregateRel{input: createVirtualTableReadRel(8),
groupingExpressions: []expr.Expression{createPrimitiveFloat(9.0)},
groupingReferences: [][]uint32{{0}},
measures: []AggRelMeasure{{measure: aggregateFnRevised, filter: expr.NewPrimitiveLiteral(true, false)}}},
},
{
name: "ExtensionLeafRel Copy with new inputs",
Expand Down

0 comments on commit 60cc74f

Please sign in to comment.