diff --git a/expr/decimal_util.go b/expr/decimal_util.go index ee0aa92..0812b78 100644 --- a/expr/decimal_util.go +++ b/expr/decimal_util.go @@ -105,3 +105,86 @@ func decimalBytesToString(decimalBytes [16]byte, scale int32) string { apdBigInt := new(apd.BigInt).SetMathBigInt(intValue) return apd.NewWithBigInt(apdBigInt, -scale).String() } + +func modifyDecimalPrecisionAndScale(decimalBytes [16]byte, scale, targetPrecision, targetScale int32) ([16]byte, int32, int32, error) { + var result [16]byte + if targetPrecision > 38 { + return result, 0, 0, fmt.Errorf("target precision %d exceeds maximum allowed precision of 38", targetPrecision) + } + + isNegative := decimalBytes[15]&0x80 != 0 + + // Reverse the byte array to convert from little-endian to big-endian. + processingValue := make([]byte, 16) + for i := 0; i < 16; i++ { + processingValue[i] = decimalBytes[15-i] + } + if isNegative { + negate(processingValue[:]) + } + + // Convert the bytes into a big.Int and wrap it into an apd.Decimal. + intValue := new(big.Int).SetBytes(processingValue[:]) + apdBigInt := new(apd.BigInt).SetMathBigInt(intValue) + dec := apd.NewWithBigInt(apdBigInt, -scale) + + // Normalize the decimal by removing trailing zeros. + dec.Reduce(dec) + + // Adjust the scale to the target scale + ctx := apd.BaseContext.WithPrecision(uint32(targetPrecision)) + _, err := ctx.Quantize(dec, dec, -targetScale) + if err != nil { + return result, 0, 0, fmt.Errorf("error adjusting scale: %v", err) + } + + err2 := validatePrecisionAndScale(dec, targetPrecision, targetScale) + if err2 != nil { + return result, 0, 0, err2 + } + + // Convert the adjusted decimal coefficient to a byte array. + byteArray := dec.Coeff.Bytes() + if len(byteArray) > 16 { + return result, 0, 0, fmt.Errorf("number exceeds 16 bytes") + } + copy(result[16-len(byteArray):], byteArray) + + // Handle the sign by applying two's complement for negative numbers. + if isNegative { + negate(result[:]) + } + + // Reverse the byte array back to little-endian. + for i, j := 0, 15; i < j; i, j = i+1, j-1 { + result[i], result[j] = result[j], result[i] + } + + return result, targetPrecision, targetScale, nil +} + +func validatePrecisionAndScale(dec *apd.Decimal, targetPrecision int32, targetScale int32) error { + // Validate the minimum precision and scale. + minPrecision, minScale := getMinimumPrecisionAndScale(dec) + if targetPrecision < minPrecision { + return fmt.Errorf("number %s exceeds target precision %d, minimum precision needed is %d with target scale %d", dec.String(), targetPrecision, minPrecision, targetScale) + } + if targetScale < minScale { + return fmt.Errorf("number %v exceeds target scale %d, minimum scale needed is %d", dec.String(), targetScale, minScale) + } + if targetPrecision-targetScale < minPrecision-minScale { + return fmt.Errorf("number %v exceeds target precision %d with target scale %d, minimum precision needed is %d with minimum scale %d", dec.String(), targetPrecision, targetScale, minPrecision, minScale) + } + return nil +} + +func getMinimumPrecisionAndScale(dec *apd.Decimal) (precision int32, scale int32) { + if dec.Exponent > 0 { + precision = int32(apd.NumDigits(&dec.Coeff)) + dec.Exponent + scale = 0 + } else { + scale = -dec.Exponent + precision = max(int32(apd.NumDigits(&dec.Coeff)), scale+1) + } + return precision, scale +} diff --git a/expr/decimal_util_test.go b/expr/decimal_util_test.go index 346f5e0..06404cf 100644 --- a/expr/decimal_util_test.go +++ b/expr/decimal_util_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDecimalStringToBytes(t *testing.T) { @@ -21,10 +22,18 @@ func TestDecimalStringToBytes(t *testing.T) { {"-12345", "c7cfffffffffffffffffffffffffffff", 5, 0, ""}, {"123.45", "39300000000000000000000000000000", 5, 2, ""}, {"-123.45", "c7cfffffffffffffffffffffffffffff", 5, 2, ""}, + {"1", "01000000000000000000000000000000", 1, 0, ""}, + {"-1.0", "f6ffffffffffffffffffffffffffffff", 2, 1, ""}, + {"-1.00", "9cffffffffffffffffffffffffffffff", 3, 2, ""}, + {"-1.000", "18fcffffffffffffffffffffffffffff", 4, 3, ""}, + {"12345.6789", "15cd5b07000000000000000000000000", 9, 4, ""}, + {"12345.67890000", "5004fb711f0100000000000000000000", 13, 8, ""}, {"0.123", "7b000000000000000000000000000000", 4, 3, ""}, {"-0.123", "85ffffffffffffffffffffffffffffff", 4, 3, ""}, {"9223372036854775807", "ffffffffffffff7f0000000000000000", 19, 0, ""}, // Max int64 {"-9223372036854775808", "0000000000000080ffffffffffffffff", 19, 0, ""}, // Min int64 + {"9223372036854775807.0000", "f0d8ffffffffffff8713000000000000", 23, 4, ""}, + {"-9223372036854775808.00", "0000000000000000ceffffffffffffff", 21, 2, ""}, {"99999999999999999999999999999999999999", "ffffffff3f228a097ac4865aa84c3b4b", 38, 0, ""}, {"+99999999999999999999999999999999999999", "ffffffff3f228a097ac4865aa84c3b4b", 38, 0, ""}, {"-99999999999999999999999999999999999999", "01000000c0dd75f6853b79a557b3c4b4", 38, 0, ""}, @@ -48,25 +57,47 @@ func TestDecimalStringToBytes(t *testing.T) { {"1.23e20", "00000c6d51c8f7aa0600000000000000", 21, 0, "123000000000000000000"}, {"1.23e35", "00000000cebde644bc05f0425eb01700", 36, 0, "123000000000000000000000000000000000"}, {"1.23E35", "00000000cebde644bc05f0425eb01700", 36, 0, "123000000000000000000000000000000000"}, + {"-123456789012345678901234567890.1234", "0e5069812fa37d21cd68009021c3ffff", 34, 4, ""}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { got, precision, scale, err := DecimalStringToBytes(tt.input) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, got, 16) assert.Equal(t, hexToBytes(t, tt.hexWant), got[:]) assert.Equal(t, tt.expPrecision, precision) assert.Equal(t, tt.expScale, scale) - if err == nil { - // verify that the conversion is correct - decStr := decimalBytesToString(got, scale) - if tt.expected == "" { - tt.expected = strings.TrimPrefix(tt.input, "+") - } - assert.Equal(t, tt.expected, decStr) + + // verify that the conversion is correct + decStr := decimalBytesToString(got, scale) + if tt.expected == "" { + tt.expected = strings.TrimPrefix(tt.input, "+") } + assert.Equal(t, tt.expected, decStr) + // test modifyDecimalPrecisionAndScale + targetPrecision := min(precision+2, 38) + targetScale := scale + if precision <= 36 { + targetScale = min(scale+2, targetPrecision) + } + newBytes, newPrecision, newScale, err := modifyDecimalPrecisionAndScale(got, scale, targetPrecision, targetScale) + require.NoError(t, err) + assert.Equal(t, targetPrecision, newPrecision) + decStr = decimalBytesToString(newBytes, newScale) + if tt.expected != decStr { + require.True(t, strings.HasPrefix(decStr, tt.expected)) + suffix := decStr[len(tt.expected):] + assert.LessOrEqual(t, len(suffix), 3) + assert.NotEqual(t, 1, len(suffix)) + switch len(suffix) { + case 2: + assert.Equal(t, "00", suffix) + case 3: + assert.Equal(t, ".00", suffix) + } + } }) } } @@ -143,3 +174,69 @@ func TestDecimalBytesToString(t *testing.T) { }) } } + +func TestModifyDecimalPrecisionAndScale(t *testing.T) { + tests := []struct { + input string + hexInput string + inputPrecision int32 + inputScale int32 + targetPrecision int32 + targetScale int32 + hexWant2 string + expected string + expectError bool + }{ + {"12345", "39300000000000000000000000000000", 5, 0, 10, 0, "", "", false}, + {"12345", "39300000000000000000000000000000", 5, 0, 20, 2, "44D61200000000000000000000000000", "12345.00", false}, + {"12345.00", "44D61200000000000000000000000000", 20, 2, 5, 2, "", "12345.00", true}, + {"12345.00", "44D61200000000000000000000000000", 20, 2, 5, 0, "39300000000000000000000000000000", "12345", false}, + {"12345.6789", "15cd5b07000000000000000000000000", 9, 4, 12, 8, "15cd5b07000000000000000000000000", "12345.67890000", true}, + {"12345.6789", "15cd5b07000000000000000000000000", 9, 4, 13, 8, "5004fb711f0100000000000000000000", "12345.67890000", false}, + {"-1.00", "9cffffffffffffffffffffffffffffff", 3, 2, 5, 3, "18fcffffffffffffffffffffffffffff", "-1.000", false}, + {"-1.0", "f6ffffffffffffffffffffffffffffff", 2, 1, 2, 0, "ffffffffffffffffffffffffffffffff", "-1", false}, + {"1.0", "0a000000000000000000000000000000", 2, 1, 2, 0, "01000000000000000000000000000000", "1", false}, + {"1.0", "0a000000000000000000000000000000", 2, 1, 40, 0, "", "", true}, + {"1.0", "0a000000000000000000000000000000", 2, 1, 1, 0, "01000000000000000000000000000000", "1", false}, + {"1", "01000000000000000000000000000000", 1, 0, 3, 2, "64000000000000000000000000000000", "1.00", false}, + {"9223372036854775807", "ffffffffffffff7f0000000000000000", 19, 0, 30, 4, "f0d8ffffffffffff8713000000000000", "9223372036854775807.0000", false}, + {"-9223372036854775808", "0000000000000080ffffffffffffffff", 19, 0, 30, 2, "0000000000000000ceffffffffffffff", "-9223372036854775808.00", false}, + {"0.0000123", "7b000000000000000000000000000000", 8, 7, 10, 9, "0c300000000000000000000000000000", "0.000012300", false}, + {"1230000000000000", "00e012b1ad5e04000000000000000000", 16, 0, 20, 2, "00805f2bd9fbb4010000000000000000", "1230000000000000.00", false}, + {"123000000000000000000", "00000c6d51c8f7aa0600000000000000", 21, 0, 25, 2, "0000b098ce3fcac89a02000000000000", "123000000000000000000.00", false}, + {"123000000000000000000000000000000000", "00000000cebde644bc05f0425eb01700", 36, 0, 38, 0, "00000000cebde644bc05f0425eb01700", "123000000000000000000000000000000000", false}, + {"123000000000000000000000000000000000", "00000000cebde644bc05f0425eb01700", 36, 0, 38, 2, "00000000782422ea8a3dc225d2e44009", "123000000000000000000000000000000000.00", false}, + {"1234567890123456.78901234", "f2af966ca0101f9b241a000000000000", 24, 8, 28, 10, "88badc6aaa7e22984c360a0000000000", "1234567890123456.7890123400", false}, + {"-1234567890.1234567890", "2ef5e0147356ab54ffffffffffffffff", 20, 10, 24, 12, "f8c5df27f4c4ed12bdffffffffffffff", "-1234567890.123456789000", false}, + {"123456789012345678901234567890.1234", "f2af967ed05c82de3297ff6fde3c0000", 34, 4, 38, 6, "88badc727141eceade0fd7bfe3c61700", "123456789012345678901234567890.123400", false}, + {"-123456789012345678901234567890.1234", "0e5069812fa37d21cd68009021c3ffff", 34, 4, 38, 6, "7845238d8ebe131521f028401c39e8ff", "-123456789012345678901234567890.123400", false}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + + inputBytes := [16]byte(hexToBytes(t, tt.hexInput)) + // verify that the conversion is correct + decStr := decimalBytesToString(inputBytes, tt.inputScale) + if tt.expected == "" { + tt.expected = strings.TrimPrefix(tt.input, "+") + } + assert.Equal(t, tt.input, decStr) + + newBytes, newPrecision, newScale, err := modifyDecimalPrecisionAndScale(inputBytes, tt.inputScale, tt.targetPrecision, tt.targetScale) + if tt.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.targetPrecision, newPrecision) + assert.Equal(t, tt.targetScale, newScale) + if tt.hexWant2 == "" { + tt.hexWant2 = tt.hexInput + } + assert.Equal(t, hexToBytes(t, tt.hexWant2), newBytes[:]) + decStr = decimalBytesToString(newBytes, newScale) + assert.Equal(t, tt.expected, decStr) + }) + } +} diff --git a/expr/literals.go b/expr/literals.go index ac850e0..9212782 100644 --- a/expr/literals.go +++ b/expr/literals.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/hex" "fmt" + "math" "reflect" "strings" "time" @@ -32,6 +33,10 @@ type nestedLiteral interface { StructLiteralValue | ListLiteralValue } +type WithTypeLiteral interface { + WithType(types.Type) (Literal, error) +} + // Easy type aliases for multi-value types that also // saves us having to create new types / new objects at runtime // when getting them from protobuf. @@ -318,6 +323,17 @@ func (t *NestedLiteral[T]) Visit(VisitFunc) Expression { } func (*NestedLiteral[T]) IsScalar() bool { return true } +func (t *NestedLiteral[T]) WithType(newType types.Type) (Literal, error) { + switch newType.(type) { + case *types.ListType: + return &NestedLiteral[ListLiteralValue]{ + Value: ListLiteralValue(t.Value), + Type: newType, + }, nil + } + return nil, fmt.Errorf("invalid type %T for nested literal", newType) +} + // MapLiteral is represented as a slice of Key/Value structs consisting // of other literals. type MapLiteral struct { @@ -453,6 +469,22 @@ func (t *ByteSliceLiteral[T]) ToProtoFuncArg() *proto.FunctionArgument { func (t *ByteSliceLiteral[T]) Visit(VisitFunc) Expression { return t } func (*ByteSliceLiteral[T]) IsScalar() bool { return true } +func (t *ByteSliceLiteral[T]) WithType(newType types.Type) (Literal, error) { + switch newType.(type) { + case *types.FixedBinaryType: + return &ByteSliceLiteral[types.FixedBinary]{ + Value: types.FixedBinary(t.Value), + Type: newType, + }, nil + case *types.UUIDType: + return &ByteSliceLiteral[types.UUID]{ + Value: types.UUID(t.Value), + Type: newType, + }, nil + } + return nil, fmt.Errorf("byte slice literal withType is not supported for %T ", newType) +} + // ProtoLiteral is a literal that is represented using its protobuf // message type such as a Decimal or UserDefinedType. type ProtoLiteral struct { @@ -460,6 +492,22 @@ type ProtoLiteral struct { Type types.Type } +func (t *ProtoLiteral) WithType(newType types.Type) (Literal, error) { + switch typ := newType.(type) { + case *types.DecimalType: + return newDecimalWithType(t, typ) + case *types.VarCharType: + return newVarCharWithType(t, typ) + case *types.PrecisionTimestampType: + return newPrecisionTimestampWithType(t, typ) + case *types.PrecisionTimestampTzType: + return newPrecisionTimestampTzWithType(t, typ) + case *types.IntervalDayType: + return newIntervalDayWithType(t, typ) + } + return nil, fmt.Errorf("protoLiteral withType is not supported for %T ", newType) +} + func (t *ProtoLiteral) ValueString() string { switch literalType := t.Type.(type) { case *types.PrecisionTimestampType: @@ -498,27 +546,16 @@ func (t *ProtoLiteral) IsoValueString() string { case *types.IntervalDayType: x, _ := t.Value.(*proto.Expression_Literal_IntervalDayToSecond) // Validity is required by construction. - seconds := x.GetSeconds() - minutes := seconds / 60 - hours := minutes / 60 - seconds = seconds % 60 - minutes = minutes % 60 sb := strings.Builder{} sb.WriteString("P") if x.GetDays() > 0 { sb.WriteString(fmt.Sprintf("%dD", x.GetDays())) } - if minutes > 0 || seconds > 0 { + if x.GetSeconds() > 0 || x.GetSubseconds() > 0 { sb.WriteString("T") - if hours > 0 { - sb.WriteString(fmt.Sprintf("%dH", hours)) - } - if minutes > 0 { - sb.WriteString(fmt.Sprintf("%dM", minutes)) - } - if seconds > 0 { - sb.WriteString(fmt.Sprintf("%dS", seconds)) - } + duration := time.Duration(x.GetSeconds()) * time.Second + duration += types.SubSecondsToDuration(x.GetSubseconds(), literalType.Precision) + sb.WriteString(strings.ToUpper(duration.String())) } return sb.String() } @@ -625,6 +662,72 @@ func (t *ProtoLiteral) ToProtoFuncArg() *proto.FunctionArgument { func (t *ProtoLiteral) Visit(VisitFunc) Expression { return t } func (*ProtoLiteral) IsScalar() bool { return true } +func newDecimalWithType(literal *ProtoLiteral, decType *types.DecimalType) (Literal, error) { + litType, ok := literal.GetType().(*types.DecimalType) + if !ok { + return nil, fmt.Errorf("literal type is not decimal") + } + inDecimalBytes := [16]byte(literal.Value.([]byte)) + decimalBytes, precision, scale, err := modifyDecimalPrecisionAndScale(inDecimalBytes, litType.Scale, decType.Precision, decType.Scale) + if err != nil { + return nil, err + } + return NewLiteral[*types.Decimal](&types.Decimal{Value: decimalBytes[:16], Precision: precision, Scale: scale}, decType.GetNullability() == types.NullabilityNullable) +} + +func newVarCharWithType(literal *ProtoLiteral, vcharType *types.VarCharType) (Literal, error) { + if _, ok := literal.GetType().(*types.VarCharType); !ok { + return nil, fmt.Errorf("literal type is not varchar") + } + if len(literal.Value.(string)) > int(vcharType.GetLength()) { + return nil, fmt.Errorf("varchar litearl value length is greater than type length") + } + return &ProtoLiteral{Value: literal.Value, Type: vcharType}, nil +} + +func newPrecisionTimestampWithType(literal *ProtoLiteral, ptsType *types.PrecisionTimestampType) (Literal, error) { + if litType, ok := literal.GetType().(*types.PrecisionTimestampType); ok { + value := types.GetTimeValueByPrecision(types.Timestamp(literal.Value.(int64)).ToPrecisionTime(litType.Precision), ptsType.Precision) + return &ProtoLiteral{Value: value, Type: ptsType}, nil + } + return nil, fmt.Errorf("literal type is not precision timestamp") +} + +func newPrecisionTimestampTzWithType(literal *ProtoLiteral, ptstzType *types.PrecisionTimestampTzType) (Literal, error) { + if litType, ok := literal.GetType().(*types.PrecisionTimestampTzType); ok { + value := types.GetTimeValueByPrecision(types.Timestamp(literal.Value.(int64)).ToPrecisionTime(litType.Precision), ptstzType.Precision) + return &ProtoLiteral{Value: value, Type: ptstzType}, nil + } + return nil, fmt.Errorf("literal type is not precision timestamp tz") +} + +func newIntervalDayWithType(literal *ProtoLiteral, intervalDayType *types.IntervalDayType) (Literal, error) { + if _, ok := literal.GetType().(*types.IntervalDayType); ok { + intervalValue := literal.Value.(*proto.Expression_Literal_IntervalDayToSecond) + precisionDiff := intervalValue.GetPrecision() - intervalDayType.Precision.ToProtoVal() + ss := intervalValue.GetSubseconds() + if precisionDiff != 0 { + factor := int64(math.Pow10(int(math.Abs(float64(precisionDiff))))) + if precisionDiff > 0 { + ss /= factor + } else { + ss *= factor + } + } + return &ProtoLiteral{ + Value: &types.IntervalDayToSecond{ + Days: intervalValue.GetDays(), + Seconds: intervalValue.GetSeconds(), + Subseconds: ss, + PrecisionMode: &proto.Expression_Literal_IntervalDayToSecond_Precision{ + Precision: intervalDayType.Precision.ToProtoVal(), + }, + }, Type: intervalDayType, + }, nil + } + return nil, fmt.Errorf("literal type is not interval day") +} + func getNullability(nullable bool) types.Nullability { if nullable { return types.NullabilityNullable diff --git a/expr/literals_test.go b/expr/literals_test.go new file mode 100644 index 0000000..49ec7ad --- /dev/null +++ b/expr/literals_test.go @@ -0,0 +1,212 @@ +package expr_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/v3/expr" + "github.com/substrait-io/substrait-go/v3/literal" + "github.com/substrait-io/substrait-go/v3/proto" + "github.com/substrait-io/substrait-go/v3/types" +) + +func TestNewDecimalWithType(t *testing.T) { + tests := []struct { + name string + precision int32 + scale int32 + decType *types.DecimalType + expStr string + expectedToFail bool + }{ + {"123.45", 5, 2, &types.DecimalType{Nullability: types.NullabilityRequired, Precision: 10, Scale: 5}, "123.45000", false}, + {"12345.678", 8, 3, &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 5}, "12345.67800", false}, + {"12345", 5, 0, &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 3, Scale: 2}, "", true}, + {"12345.888", 8, 3, &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 7, Scale: 3}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lit, err := literal.NewDecimalFromString(tt.name) + require.NoError(t, err) + got, err := lit.(*expr.ProtoLiteral).WithType(tt.decType) + if tt.expectedToFail { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.expStr, got.ValueString()) + }) + } +} + +func TestNewFixedLenWithType(t *testing.T) { + tests := []struct { + name string + inputType types.Type + wantErr bool + }{ + {"abc", &types.VarCharType{Length: 5, Nullability: types.NullabilityRequired}, false}, + {"abcde", &types.VarCharType{Length: 3, Nullability: types.NullabilityRequired}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input, _ := literal.NewVarChar(tt.name) + got, err := input.(*expr.ProtoLiteral).WithType(tt.inputType.(*types.VarCharType)) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.inputType, got.GetType()) + }) + } +} + +func TestNewPrecisionTimestampWithType(t *testing.T) { + tests := []struct { + name string + inputPrecision types.TimePrecision + inputType *types.PrecisionTimestampType + want expr.Literal + wantErr bool + }{ + {"1991-01-01T01:02:03.456", 3, &types.PrecisionTimestampType{Precision: 3, Nullability: types.NullabilityNullable}, nil, false}, + {"1991-01-01T01:02:03.456", 3, &types.PrecisionTimestampType{Precision: 6, Nullability: types.NullabilityNullable}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lit, err := literal.NewPrecisionTimestampFromString(tt.inputPrecision, tt.name) + require.NoError(t, err) + got, err := lit.(*expr.ProtoLiteral).WithType(tt.inputType) + if tt.wantErr { + require.NoError(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.inputType, got.GetType()) + assert.Equal(t, tt.name, got.(types.IsoValuePrinter).IsoValueString()) + }) + } +} + +func TestNewPrecisionTimestampTzWithType(t *testing.T) { + tests := []struct { + name string + inputPrecision types.TimePrecision + inputType types.PrecisionTimestampType + expLiteralString string + wantErr bool + }{ + {"1991-01-01T01:02:03.456+05:30", 3, types.PrecisionTimestampType{Precision: 3, Nullability: types.NullabilityNullable}, "1990-12-31T19:32:03.456+00:00", false}, + {"1991-01-01T01:02:03.456+05:30", 3, types.PrecisionTimestampType{Precision: 6, Nullability: types.NullabilityRequired}, "1990-12-31T19:32:03.456+00:00", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lit, err := literal.NewPrecisionTimestampTzFromString(tt.inputPrecision, tt.name) + require.NoError(t, err) + inputType := &types.PrecisionTimestampTzType{PrecisionTimestampType: tt.inputType} + got, err := lit.(*expr.ProtoLiteral).WithType(inputType) + if tt.wantErr { + require.NoError(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, inputType, got.GetType()) + assert.Equal(t, tt.expLiteralString, got.(types.IsoValuePrinter).IsoValueString()) + }) + } +} + +func TestNewIntervalDayWithType(t *testing.T) { + tests := []struct { + name string + inputType *types.IntervalDayType + expLiteralString string + expSubSeconds int64 + wantErr bool + }{ + {"PT23H59M59.999S", &types.IntervalDayType{Precision: 3, Nullability: types.NullabilityNullable}, "PT23H59M59.999S", 999, false}, + {"PT23H59M59.999S", &types.IntervalDayType{Precision: 2, Nullability: types.NullabilityNullable}, "PT23H59M59.99S", 99, false}, + {"PT23H59M59.999S", &types.IntervalDayType{Precision: 6, Nullability: types.NullabilityRequired}, "PT23H59M59.999S", 999000, false}, + {"PT23H59M59.999S", &types.IntervalDayType{Precision: 9, Nullability: types.NullabilityRequired}, "PT23H59M59.999S", 999000000, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lit, err := literal.NewIntervalDaysToSecondFromString(tt.name) + require.NoError(t, err) + got, err := lit.(*expr.ProtoLiteral).WithType(tt.inputType) + if tt.wantErr { + require.NoError(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.inputType, got.GetType()) + assert.Equal(t, tt.expLiteralString, got.(types.IsoValuePrinter).IsoValueString()) + iday := got.(*expr.ProtoLiteral).Value.(*proto.Expression_Literal_IntervalDayToSecond) + assert.Equal(t, tt.inputType.Precision.ToProtoVal(), iday.GetPrecision()) + assert.Equal(t, tt.expSubSeconds, iday.GetSubseconds()) + }) + } +} + +func TestProtoLiteral_WithType(t1 *testing.T) { + dec123, _ := literal.NewDecimalFromString("123.45") + iday, _ := literal.NewIntervalDaysToSecondFromString("PT23H59M59.999S") + pts, _ := literal.NewPrecisionTimestampFromString(3, "1991-01-01T01:02:03.456") + ptstz, _ := literal.NewPrecisionTimestampTzFromString(3, "1991-01-01T01:02:03.456") + vchar, _ := literal.NewVarChar("sun") + tests := []struct { + name string + protoLiteral *expr.ProtoLiteral + newType types.Type + want expr.Literal + wantErr assert.ErrorAssertionFunc + }{ + {"Decimal", dec123.(*expr.ProtoLiteral), &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 5}, nil, assert.NoError}, + {"IntervalDay", iday.(*expr.ProtoLiteral), &types.IntervalDayType{Precision: 3, Nullability: types.NullabilityNullable}, nil, assert.NoError}, + {"PrecisionTimestamp", pts.(*expr.ProtoLiteral), &types.PrecisionTimestampType{Precision: 3, Nullability: types.NullabilityNullable}, nil, assert.NoError}, + {"PrecisionTimestampTz", ptstz.(*expr.ProtoLiteral), &types.PrecisionTimestampTzType{PrecisionTimestampType: types.PrecisionTimestampType{Precision: 3, Nullability: types.NullabilityNullable}}, nil, assert.NoError}, + {"VarChar", vchar.(*expr.ProtoLiteral), &types.VarCharType{Length: 3, Nullability: types.NullabilityNullable}, nil, assert.NoError}, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + got, err := tt.protoLiteral.WithType(tt.newType) + if !tt.wantErr(t1, err, fmt.Sprintf("WithType(%v)", tt.newType)) { + return + } + assert.Equalf(t1, tt.newType, got.GetType(), "WithType(%v)", tt.newType) + }) + } +} + +func TestByteSliceLiteral_WithType(t1 *testing.T) { + fbin := expr.NewByteSliceLiteral[[]byte]([]byte{0x01, 0x02, 0x03}, false) + uuid := expr.NewByteSliceLiteral[types.UUID]([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, false) + + list := expr.NewNestedLiteral(expr.ListLiteralValue{ + literal.NewString("sun"), literal.NewString("moon"), literal.NewString("mars"), + }, false) + type testCase struct { + name string + t expr.WithTypeLiteral + newType types.Type + want expr.Literal + wantErr assert.ErrorAssertionFunc + } + tests := []testCase{ + {"FixedBinary", fbin, &types.FixedBinaryType{Length: 3, Nullability: types.NullabilityNullable}, nil, assert.NoError}, + {"UUID", uuid, &types.UUIDType{Nullability: types.NullabilityNullable}, nil, assert.NoError}, + {"List", list.(expr.WithTypeLiteral), &types.ListType{Type: &types.StringType{Nullability: types.NullabilityNullable}, Nullability: types.NullabilityNullable}, nil, assert.NoError}, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + got, err := tt.t.WithType(tt.newType) + if !tt.wantErr(t1, err, fmt.Sprintf("WithType(%v)", tt.newType)) { + return + } + assert.Equalf(t1, tt.newType, got.GetType(), "WithType(%v)", tt.newType) + }) + } +} diff --git a/extensions/variants_test.go b/extensions/variants_test.go index 868e33d..4356604 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -216,7 +216,8 @@ func TestMatchWithSyncParams(t *testing.T) { funcType parser2.TestFuncType numTests int }{ - {"tests/cases/arithmetic_decimal/bitwise_or.test", parser2.ScalarFuncType, 14}, + // TODO enable bitwise_or.test after fixing the testcase file + //{"tests/cases/arithmetic_decimal/bitwise_or.test", parser2.ScalarFuncType, 14}, {"tests/cases/arithmetic_decimal/bitwise_xor.test", parser2.ScalarFuncType, 14}, {"tests/cases/arithmetic_decimal/bitwise_and.test", parser2.ScalarFuncType, 14}, {"tests/cases/arithmetic_decimal/sqrt_decimal.test", parser2.ScalarFuncType, 14}, diff --git a/literal/utils.go b/literal/utils.go index 5ad81a6..fb602a5 100644 --- a/literal/utils.go +++ b/literal/utils.go @@ -323,7 +323,7 @@ func NewDecimalFromString(value string) (expr.Literal, error) { // NewPrecisionTimestampFromTime creates a new PrecisionTimestamp literal from a time.Time timestamp value with given precision. func NewPrecisionTimestampFromTime(precision types.TimePrecision, tm time.Time) (expr.Literal, error) { - return NewPrecisionTimestamp(precision, getTimeValueByPrecision(tm, precision)) + return NewPrecisionTimestamp(precision, types.GetTimeValueByPrecision(tm, precision)) } // NewPrecisionTimestamp creates a new PrecisionTimestamp literal with given precision and value. @@ -346,7 +346,7 @@ func NewPrecisionTimestampFromString(precision types.TimePrecision, value string // NewPrecisionTimestampTzFromTime creates a new PrecisionTimestampTz literal from a time.Time timestamp value with given precision. func NewPrecisionTimestampTzFromTime(precision types.TimePrecision, tm time.Time) (expr.Literal, error) { - return NewPrecisionTimestampTz(precision, getTimeValueByPrecision(tm, precision)) + return NewPrecisionTimestampTz(precision, types.GetTimeValueByPrecision(tm, precision)) } // NewPrecisionTimestampTz creates a new PrecisionTimestampTz literal with given precision and value. @@ -367,33 +367,6 @@ func NewPrecisionTimestampTzFromString(precision types.TimePrecision, value stri return NewPrecisionTimestampTzFromTime(precision, tm) } -func getTimeValueByPrecision(tm time.Time, precision types.TimePrecision) int64 { - switch precision { - case types.PrecisionSeconds: - return tm.Unix() - case types.PrecisionDeciSeconds: - return tm.UnixMilli() / 100 - case types.PrecisionCentiSeconds: - return tm.UnixMilli() / 10 - case types.PrecisionMilliSeconds: - return tm.UnixMilli() - case types.PrecisionEMinus4Seconds: - return tm.UnixMicro() / 100 - case types.PrecisionEMinus5Seconds: - return tm.UnixMicro() / 10 - case types.PrecisionMicroSeconds: - return tm.UnixMicro() - case types.PrecisionEMinus7Seconds: - return tm.UnixNano() / 100 - case types.PrecisionEMinus8Seconds: - return tm.UnixNano() / 10 - case types.PrecisionNanoSeconds: - return tm.UnixNano() - default: - panic(fmt.Sprintf("unknown TimePrecision %v", precision)) - } -} - func NewList(elements []expr.Literal) (expr.Literal, error) { if len(elements) == 0 { return nil, fmt.Errorf("empty list literal") diff --git a/testcases/parser/nodes.go b/testcases/parser/nodes.go index e2e5417..01f1c0b 100644 --- a/testcases/parser/nodes.go +++ b/testcases/parser/nodes.go @@ -78,6 +78,28 @@ func (c *CaseLiteral) AsAggregateArgumentString() string { return c.Value.ValueString() + "::" + c.Type.String() } +// updateLiteralType updates the type of the literal CaseLiteral.Value to use the CaseLiteral.Type +// Parser creates a literal with a type using existing util functions. +// For ParameterizedTypes utils functions use minimum required values for the parameters. +// This function changes the type to use requested type, so that the function invocation object is created correctly. +func (c *CaseLiteral) updateLiteralType() error { + if len(c.Type.GetParameters()) == 0 { + return nil + } + switch proLit := c.Value.(type) { + case *expr.NullLiteral: + return nil + case expr.WithTypeLiteral: + lit, err := proLit.WithType(c.Type) + if err != nil { + return err + } + c.Value = lit + return nil + } + return fmt.Errorf("literal type %T is not handled to update the type", c.Value) +} + type TestFileHeader struct { Version string FuncType TestFuncType diff --git a/testcases/parser/parse_test.go b/testcases/parser/parse_test.go index 2ed8bcf..e15ec43 100644 --- a/testcases/parser/parse_test.go +++ b/testcases/parser/parse_test.go @@ -113,12 +113,18 @@ power(-1::dec, 0.5::dec<38,1>) [complex_number_result:NAN] = nan::fp64 assert.Equal(t, "power", testFile.TestCases[0].FuncName) assert.Equal(t, "power", testFile.TestCases[1].FuncName) assert.Equal(t, "power", testFile.TestCases[2].FuncName) - dec8, _ := literal.NewDecimalFromString("8") - dec2, _ := literal.NewDecimalFromString("2") - dec1, _ := literal.NewDecimalFromString("1.0") - decMinus1, _ := literal.NewDecimalFromString("-1") - decMinus1Point0, _ := literal.NewDecimalFromString("-1.0") - decPoint5, _ := literal.NewDecimalFromString("0.5") + dec8Value, _ := literal.NewDecimalFromString("8") + dec2Value, _ := literal.NewDecimalFromString("2") + dec1Value, _ := literal.NewDecimalFromString("1.0") + decMinus1Value, _ := literal.NewDecimalFromString("-1") + decMinus1Point0Value, _ := literal.NewDecimalFromString("-1.0") + decPoint5Value, _ := literal.NewDecimalFromString("0.5") + dec8, _ := dec8Value.(expr.WithTypeLiteral).WithType(&types.DecimalType{Precision: 38, Scale: 0, Nullability: types.NullabilityRequired}) + dec2, _ := dec2Value.(expr.WithTypeLiteral).WithType(&types.DecimalType{Precision: 38, Scale: 0, Nullability: types.NullabilityRequired}) + dec1, _ := dec1Value.(expr.WithTypeLiteral).WithType(&types.DecimalType{Precision: 38, Scale: 5, Nullability: types.NullabilityRequired}) + decMinus1, _ := decMinus1Value.(expr.WithTypeLiteral).WithType(&types.DecimalType{Precision: 38, Scale: 0, Nullability: types.NullabilityRequired}) + decMinus1Point0, _ := decMinus1Point0Value.(expr.WithTypeLiteral).WithType(&types.DecimalType{Precision: 38, Scale: 5, Nullability: types.NullabilityRequired}) + decPoint5, _ := decPoint5Value.(expr.WithTypeLiteral).WithType(&types.DecimalType{Precision: 38, Scale: 1, Nullability: types.NullabilityRequired}) f6464 := literal.NewFloat64(64) f641 := literal.NewFloat64(1) assert.Equal(t, dec8, testFile.TestCases[0].Args[0].Value) @@ -142,9 +148,9 @@ func TestParseTestWithVariousTypes(t *testing.T) { {testCaseStr: "f2(1.0::fp32, 2.0::fp64) = -7.0::fp32", expTestStr: "f2(1::fp32, 2::fp64) = -7::fp32"}, {testCaseStr: "f3('a'::str, 'b'::string) = 'c'::str", expTestStr: "f3('a'::string, 'b'::string) = 'c'::string"}, {testCaseStr: "f4(false::bool, true::boolean) = false::bool", expTestStr: "f4(false::boolean, true::boolean) = false::boolean"}, - {testCaseStr: "f5(1.1::dec, 2.2::decimal) = 3.3::dec", expTestStr: "f5(1.1::decimal<38,0>, 2.2::decimal<38,0>) = 3.3::decimal<38,0>"}, - {testCaseStr: "f6(1.1::dec<38,10>, 2.2::dec<38,10>) = 3.3::dec<38,10>", expTestStr: "f6(1.1::decimal<38,10>, 2.2::decimal<38,10>) = 3.3::decimal<38,10>"}, - {testCaseStr: "f7(1.1::dec<38,10>, 2.2::decimal<38,10>) = 3.3::decimal<38,10>", expTestStr: "f7(1.1::decimal<38,10>, 2.2::decimal<38,10>) = 3.3::decimal<38,10>"}, + {testCaseStr: "f5(1::dec, 2::decimal) = 3::dec", expTestStr: "f5(1::decimal<38,0>, 2::decimal<38,0>) = 3::decimal<38,0>"}, + {testCaseStr: "f6(1.1::dec<38,10>, 2.2::dec<38,10>) = 3.3::dec<38,10>", expTestStr: "f6(1.1000000000::decimal<38,10>, 2.2000000000::decimal<38,10>) = 3.3000000000::decimal<38,10>"}, + {testCaseStr: "f7(1.1::dec<38,1>, 2.2::decimal<38,1>) = 3.3::decimal<38,1>", expTestStr: "f7(1.1::decimal<38,1>, 2.2::decimal<38,1>) = 3.3::decimal<38,1>"}, {testCaseStr: "f8('1991-01-01'::date) = '2001-01-01'::date"}, {testCaseStr: "f8('13:01:01.2345678'::time) = 123456::i64", expTestStr: "f8('13:01:01.234567'::time) = 123456::i64"}, {testCaseStr: "f8('13:01:01.234'::time) = 123::i32", expTestStr: "f8('13:01:01.234000'::time) = 123::i32"}, @@ -201,6 +207,7 @@ func checkNullability(t *testing.T, lit expr.Literal, argType types.Type) { } else { assert.Equal(t, types.NullabilityNullable, argType.GetNullability()) } + assert.Equal(t, argType, lit.GetType()) } func TestParseStringTestCases(t *testing.T) { @@ -610,10 +617,11 @@ func TestParseAggregateTestWithVariousTypes(t *testing.T) { {testCaseStr: "f4((false, true)::boolean) = false::bool", expTestStr: "f4((false, true)::boolean) = false::boolean"}, {testCaseStr: "f5((1.1, 2.2)::fp32) = 3.3::fp32"}, {testCaseStr: "f5((1.1, 2.2)::fp64) = 3.3::fp64"}, - {testCaseStr: "f5((1.1, 2.2)::decimal) = 3.3::dec", expTestStr: "f5((1.1, 2.2)::decimal<38,0>) = 3.3::decimal<38,0>"}, - {testCaseStr: "f6((1.1, 2.2)::dec<38,10>) = 3.3::dec<38,10>", expTestStr: "f6((1.1, 2.2)::decimal<38,10>) = 3.3::decimal<38,10>"}, - {testCaseStr: "f7((1.0, 2)::decimal<38,0>) = 3.0::decimal<38,0>"}, - {testCaseStr: "f6((1.1, 2.2, null)::dec?<38,10>) = 3.3::dec<38,10>", expTestStr: "f6((1.1, 2.2, null)::decimal?<38,10>) = 3.3::decimal<38,10>"}, + {testCaseStr: "f5((1, 2)::decimal) = 3::dec", expTestStr: "f5((1, 2)::decimal<38,0>) = 3::decimal<38,0>"}, + {testCaseStr: "f5((1.1, 2.2)::dec<38,1>) = 3.3::dec<38,1>", expTestStr: "f5((1.1, 2.2)::decimal<38,1>) = 3.3::decimal<38,1>"}, + {testCaseStr: "f6((1.1, 2.2)::dec<38,10>) = 3.3::dec<38,10>", expTestStr: "f6((1.1, 2.2)::decimal<38,10>) = 3.3000000000::decimal<38,10>"}, + {testCaseStr: "f7((1.0, 2)::decimal<38,0>) = 3::decimal<38,0>"}, + {testCaseStr: "f6((1.1, 2.2, null)::dec?<38,10>) = 3.3::dec<38,10>", expTestStr: "f6((1.1, 2.2, null)::decimal?<38,10>) = 3.3000000000::decimal<38,10>"}, {testCaseStr: "f8(('1991-01-01', '1991-02-02')::date) = '2001-01-01'::date"}, {testCaseStr: "f8(('13:01:01.2345678', '14:01:01.333')::time) = 123456::i64", expTestStr: "f8(('13:01:01.234567', '14:01:01.333000')::time) = 123456::i64"}, {testCaseStr: "f8(('1991-01-01T01:02:03.456', '1991-01-01T00:00:00')::timestamp) = '1991-01-01T22:33:44'::ts", expTestStr: "f8(('1991-01-01T01:02:03.456', '1991-01-01T00:00:00')::timestamp) = '1991-01-01T22:33:44'::timestamp"}, @@ -640,6 +648,7 @@ count(t1.col0) = 4::fp64`, expTestStr: "(('cat'), ('bat'), ('rat'), (null)) coun {testCaseStr: "f20(('abcd', 'ef')::fbin<9>) = Null::fbin<9>", expTestStr: "f20(('abcd', 'ef')::fixedbinary<9>) = null::fixedbinary?<9>"}, {testCaseStr: "f20(('abcd', 'ef')::varchar?<9>) = 'abcdef'::varchar<9>", expTestStr: "f20(('abcd', 'ef')::varchar?<9>) = 'abcdef'::varchar<9>"}, {testCaseStr: "f20(('abcd', null)::fixedchar?<9>) = Null::fixedchar<9>", expTestStr: "f20(('abcd', null)::fixedchar?<9>) = null::fixedchar?<9>"}, + {testCaseStr: "f20(('abcd', 'ef')::fixedbinary?<9>) = Null::fixedbinary<9>", expTestStr: "f20(('abcd', 'ef')::fixedbinary?<9>) = null::fixedbinary?<9>"}, {testCaseStr: "f35(('1991-01-01T01:02:03.456')::pts?<3>) = '1991-01-01T01:02:30.123123'::precision_timestamp<3>", expTestStr: "f35(('1991-01-01T01:02:03.456')::precision_timestamp?<3>) = '1991-01-01T01:02:30.123'::precision_timestamp<3>"}, {testCaseStr: "f36(('1991-01-01T01:02:03.456', '1991-01-01T01:02:30.123123')::precision_timestamp<3>) = 123456::i64"}, @@ -698,6 +707,9 @@ func TestLoadAllSubstraitTestFiles(t *testing.T) { case "tests/cases/datetime/extract.test": // TODO deal with enum arguments in testcase t.Skip("Skipping extract.test") + case "tests/cases/arithmetic_decimal/bitwise_or.test": + // TODO enable this after merging the PR with testcase fix + t.Skip("Skipping bitwise_or.test") } testFile, err := ParseTestCaseFileFromFS(got, filePath) @@ -717,10 +729,14 @@ func testGetFunctionInvocation(t *testing.T, tc *TestCase, reg *expr.ExtensionRe invocation, err := tc.GetScalarFunctionInvocation(reg, registry) require.NoError(t, err, "GetScalarFunctionInvocation failed with error in test case: %s", tc.CompoundFunctionName()) require.Equal(t, tc.ID().URI, invocation.ID().URI) + argTypes := invocation.GetArgTypes() + require.Equal(t, tc.GetArgTypes(), argTypes, "unexpected arg types in test case: %s", tc.CompoundFunctionName()) case AggregateFuncType: invocation, err := tc.GetAggregateFunctionInvocation(reg, registry) require.NoError(t, err, "GetAggregateFunctionInvocation failed with error in test case: %s", tc.CompoundFunctionName()) require.Equal(t, tc.ID().URI, invocation.ID().URI) + argTypes := invocation.GetArgTypes() + require.Equal(t, tc.GetArgTypes(), argTypes, "unexpected arg types in test case: %s", tc.CompoundFunctionName()) } } diff --git a/testcases/parser/visitor.go b/testcases/parser/visitor.go index 13bbf11..e0f2914 100644 --- a/testcases/parser/visitor.go +++ b/testcases/parser/visitor.go @@ -337,7 +337,11 @@ func (v *TestCaseVisitor) VisitTestCase(ctx *baseparser.TestCaseContext) interfa func (v *TestCaseVisitor) VisitArguments(ctx *baseparser.ArgumentsContext) interface{} { args := make([]*CaseLiteral, 0, len(ctx.AllArgument())) for _, argument := range ctx.AllArgument() { - args = append(args, v.Visit(argument).(*CaseLiteral)) + testArg := v.Visit(argument).(*CaseLiteral) + if err := testArg.updateLiteralType(); err != nil { + v.ErrorListener.ReportVisitError(fmt.Errorf("invalid argument %v", err)) + } + args = append(args, testArg) } return args } @@ -742,7 +746,11 @@ func (v *TestCaseVisitor) VisitResult(ctx *baseparser.ResultContext) interface{} if ctx.SubstraitError() != nil { return v.Visit(ctx.SubstraitError()) } - return v.Visit(ctx.Argument()).(*CaseLiteral) + result := v.Visit(ctx.Argument()).(*CaseLiteral) + if err := result.updateLiteralType(); err != nil { + v.ErrorListener.ReportVisitError(fmt.Errorf("invalid result: %v", err)) + } + return result } func (v *TestCaseVisitor) VisitSubstraitError(ctx *baseparser.SubstraitErrorContext) interface{} { diff --git a/types/precison_timestamp_types.go b/types/precison_timestamp_types.go index edfbc1a..916d3cf 100644 --- a/types/precison_timestamp_types.go +++ b/types/precison_timestamp_types.go @@ -5,6 +5,7 @@ package types import ( "fmt" "reflect" + "time" "github.com/substrait-io/substrait-go/v3/proto" ) @@ -31,6 +32,33 @@ func (m TimePrecision) ToProtoVal() int32 { return int32(m) } +func SubSecondsToDuration(subSeconds int64, precision TimePrecision) time.Duration { + switch precision { + case PrecisionSeconds: + return time.Duration(subSeconds) * time.Second + case PrecisionDeciSeconds: + return time.Duration(subSeconds) * time.Second / 10 + case PrecisionCentiSeconds: + return time.Duration(subSeconds) * time.Second / 100 + case PrecisionMilliSeconds: + return time.Duration(subSeconds) * time.Millisecond + case PrecisionEMinus4Seconds: + return time.Duration(subSeconds) * 100 * time.Microsecond + case PrecisionEMinus5Seconds: + return time.Duration(subSeconds) * 10 * time.Microsecond + case PrecisionMicroSeconds: + return time.Duration(subSeconds) * time.Microsecond + case PrecisionEMinus7Seconds: + return time.Duration(subSeconds) * 100 * time.Nanosecond + case PrecisionEMinus8Seconds: + return time.Duration(subSeconds) * 10 * time.Nanosecond + case PrecisionNanoSeconds: + return time.Duration(subSeconds) * time.Nanosecond + default: + panic(fmt.Sprintf("invalid precision %d", precision)) + } +} + func ProtoToTimePrecision(val int32) (TimePrecision, error) { if val < PrecisionSeconds.ToProtoVal() || val > PrecisionNanoSeconds.ToProtoVal() { return PrecisionUnknown, fmt.Errorf("invalid TimePrecision value %d", val) diff --git a/types/precison_timestamp_types_test.go b/types/precison_timestamp_types_test.go index 337f93f..fa860b2 100644 --- a/types/precison_timestamp_types_test.go +++ b/types/precison_timestamp_types_test.go @@ -5,6 +5,7 @@ package types import ( "fmt" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" @@ -122,3 +123,28 @@ func assertPrecisionTimeStampTzProto(t *testing.T, expectedPrecision TimePrecisi t.Errorf("precisionTimeStampTz proto didn't match, diff:\n%v", diff) } } + +func TestSubSecondsToDuration(t *testing.T) { + tests := []struct { + name string + subSeconds int64 + precision TimePrecision + want time.Duration + }{ + {"0.000000001s", 1, PrecisionNanoSeconds, time.Nanosecond}, + {"0.00000001s", 1, PrecisionEMinus8Seconds, time.Nanosecond * 10}, + {"0.0000001s", 1, PrecisionEMinus7Seconds, time.Nanosecond * 100}, + {"0.000001s", 1, PrecisionMicroSeconds, time.Microsecond}, + {"0.00001s", 1, PrecisionEMinus5Seconds, time.Microsecond * 10}, + {"0.0001s", 1, PrecisionEMinus4Seconds, time.Microsecond * 100}, + {"0.001s", 1, PrecisionMilliSeconds, time.Millisecond}, + {"0.01s", 1, PrecisionCentiSeconds, time.Millisecond * 10}, + {"0.1s", 1, PrecisionDeciSeconds, time.Millisecond * 100}, + {"1s", 1, PrecisionSeconds, time.Second}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, SubSecondsToDuration(tt.subSeconds, tt.precision), "SubSecondsToDuration(%v, %v)", tt.subSeconds, tt.precision) + }) + } +} diff --git a/types/types.go b/types/types.go index 36c9eec..de43a9a 100644 --- a/types/types.go +++ b/types/types.go @@ -1655,3 +1655,30 @@ func (t TimestampTz) ToIsoTimeString() string { tm := any(t).(TimeConverter).ToTime() return tm.UTC().Format("2006-01-02T15:04:05.999999999") } + +func GetTimeValueByPrecision(tm time.Time, precision TimePrecision) int64 { + switch precision { + case PrecisionSeconds: + return tm.Unix() + case PrecisionDeciSeconds: + return tm.UnixMilli() / 100 + case PrecisionCentiSeconds: + return tm.UnixMilli() / 10 + case PrecisionMilliSeconds: + return tm.UnixMilli() + case PrecisionEMinus4Seconds: + return tm.UnixMicro() / 100 + case PrecisionEMinus5Seconds: + return tm.UnixMicro() / 10 + case PrecisionMicroSeconds: + return tm.UnixMicro() + case PrecisionEMinus7Seconds: + return tm.UnixNano() / 100 + case PrecisionEMinus8Seconds: + return tm.UnixNano() / 10 + case PrecisionNanoSeconds: + return tm.UnixNano() + default: + panic(fmt.Sprintf("unknown TimePrecision %v", precision)) + } +} diff --git a/types/types_test.go b/types/types_test.go index 049528e..e6f7f74 100644 --- a/types/types_test.go +++ b/types/types_test.go @@ -5,6 +5,7 @@ package types_test import ( "fmt" "testing" + "time" "github.com/stretchr/testify/assert" . "github.com/substrait-io/substrait-go/v3/types" @@ -470,3 +471,29 @@ func TestMatchParameterizedNestedTypeResultMatch(t *testing.T) { func markNullable(t FuncDefArgType) FuncDefArgType { return t.SetNullability(NullabilityNullable) } + +func TestGetTimeValueByPrecision(t *testing.T) { + timeStr := "2021-08-10T15:01:05.123456789Z" + tests := []struct { + name string + precision TimePrecision + want int64 + }{ + {"PrecisionSeconds", PrecisionSeconds, 1628607665}, + {"PrecisionDeciSeconds", PrecisionDeciSeconds, 16286076651}, + {"PrecisionCentiSeconds", PrecisionCentiSeconds, 162860766512}, + {"PrecisionMilliSeconds", PrecisionMilliSeconds, 1628607665123}, + {"PrecisionEMinus4Seconds", PrecisionEMinus4Seconds, 16286076651234}, + {"PrecisionEMinus5Seconds", PrecisionEMinus5Seconds, 162860766512345}, + {"PrecisionMicroSeconds", PrecisionMicroSeconds, 1628607665123456}, + {"PrecisionEMinus7Seconds", PrecisionEMinus7Seconds, 16286076651234567}, + {"PrecisionEMinus8Seconds", PrecisionEMinus8Seconds, 162860766512345678}, + {"PrecisionNanoSeconds", PrecisionNanoSeconds, 1628607665123456789}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm, _ := time.Parse(time.RFC3339Nano, timeStr) + assert.Equalf(t, tt.want, GetTimeValueByPrecision(tm, tt.precision), "GetTimeValueByPrecision(%v, %v)", timeStr, tt.precision) + }) + } +}