Skip to content

Commit

Permalink
fix(Query): [Breaking] Return error for illegal math operations. (#7631)
Browse files Browse the repository at this point in the history
* Return error if operand for math functions are invalid
  • Loading branch information
rohanprasad authored Mar 31, 2021
1 parent 2e39f9c commit 0f59807
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 7 deletions.
66 changes: 59 additions & 7 deletions query/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,14 @@ func applyAdd(a, b, c *types.Val) error {
vBase := getValType(a)
switch vBase {
case INT:
c.Value = a.Value.(int64) + b.Value.(int64)
aVal, bVal := a.Value.(int64), b.Value.(int64)

if (aVal > 0 && bVal > math.MaxInt64-aVal) ||
(aVal < 0 && bVal < math.MinInt64-aVal) {
return ErrorIntOverflow
}

c.Value = aVal + bVal

case FLOAT:
c.Value = a.Value.(float64) + b.Value.(float64)
Expand All @@ -131,7 +138,14 @@ func applySub(a, b, c *types.Val) error {
vBase := getValType(a)
switch vBase {
case INT:
c.Value = a.Value.(int64) - b.Value.(int64)
aVal, bVal := a.Value.(int64), b.Value.(int64)

if (bVal < 0 && aVal > math.MaxInt64+bVal) ||
(bVal > 0 && aVal < math.MinInt64+bVal) {
return ErrorIntOverflow
}

c.Value = aVal - bVal

case FLOAT:
c.Value = a.Value.(float64) - b.Value.(float64)
Expand All @@ -146,7 +160,14 @@ func applyMul(a, b, c *types.Val) error {
vBase := getValType(a)
switch vBase {
case INT:
c.Value = a.Value.(int64) * b.Value.(int64)
aVal, bVal := a.Value.(int64), b.Value.(int64)
c.Value = aVal * bVal

if aVal == 0 || bVal == 0 {
return nil
} else if c.Value.(int64)/bVal != aVal {
return ErrorIntOverflow
}

case FLOAT:
c.Value = a.Value.(float64) * b.Value.(float64)
Expand All @@ -162,13 +183,13 @@ func applyDiv(a, b, c *types.Val) error {
switch vBase {
case INT:
if b.Value.(int64) == 0 {
return errors.Errorf("Division by zero")
return ErrorDivisionByZero
}
c.Value = a.Value.(int64) / b.Value.(int64)

case FLOAT:
if b.Value.(float64) == 0 {
return errors.Errorf("Division by zero")
return ErrorDivisionByZero
}
c.Value = a.Value.(float64) / b.Value.(float64)

Expand All @@ -183,13 +204,13 @@ func applyMod(a, b, c *types.Val) error {
switch vBase {
case INT:
if b.Value.(int64) == 0 {
return errors.Errorf("Module by zero")
return ErrorDivisionByZero
}
c.Value = a.Value.(int64) % b.Value.(int64)

case FLOAT:
if b.Value.(float64) == 0 {
return errors.Errorf("Module by zero")
return ErrorDivisionByZero
}
c.Value = math.Mod(a.Value.(float64), b.Value.(float64))

Expand All @@ -207,6 +228,11 @@ func applyPow(a, b, c *types.Val) error {
c.Tid = types.FloatID

case FLOAT:
// Fractional power of -ve numbers should not be returned.
if a.Value.(float64) < 0 &&
math.Abs(math.Ceil(b.Value.(float64))-b.Value.(float64)) > 0 {
return ErrorFractionalPower
}
c.Value = math.Pow(a.Value.(float64), b.Value.(float64))

case DEFAULT:
Expand All @@ -219,10 +245,20 @@ func applyLog(a, b, c *types.Val) error {
vBase := getValType(a)
switch vBase {
case INT:
if a.Value.(int64) < 0 || b.Value.(int64) < 0 {
return ErrorNegativeLog
} else if b.Value.(int64) == 1 {
return ErrorDivisionByZero
}
c.Value = math.Log(float64(a.Value.(int64))) / math.Log(float64(b.Value.(int64)))
c.Tid = types.FloatID

case FLOAT:
if a.Value.(float64) < 0 || b.Value.(float64) < 0 {
return ErrorNegativeLog
} else if b.Value.(float64) == 1 {
return ErrorDivisionByZero
}
c.Value = math.Log(a.Value.(float64)) / math.Log(b.Value.(float64))

case DEFAULT:
Expand Down Expand Up @@ -261,10 +297,16 @@ func applyLn(a, res *types.Val) error {
vBase := getValType(a)
switch vBase {
case INT:
if a.Value.(int64) < 0 {
return ErrorNegativeLog
}
res.Value = math.Log(float64(a.Value.(int64)))
res.Tid = types.FloatID

case FLOAT:
if a.Value.(float64) < 0 {
return ErrorNegativeLog
}
res.Value = math.Log(a.Value.(float64))

case DEFAULT:
Expand Down Expand Up @@ -293,6 +335,10 @@ func applyNeg(a, res *types.Val) error {
vBase := getValType(a)
switch vBase {
case INT:
// -ve of math.MinInt64 is evaluated as itself (due to overflow)
if a.Value.(int64) == math.MinInt64 {
return ErrorIntOverflow
}
res.Value = -a.Value.(int64)

case FLOAT:
Expand All @@ -308,10 +354,16 @@ func applySqrt(a, res *types.Val) error {
vBase := getValType(a)
switch vBase {
case INT:
if a.Value.(int64) < 0 {
return ErrorNegativeRoot
}
res.Value = math.Sqrt(float64(a.Value.(int64)))
res.Tid = types.FloatID

case FLOAT:
if a.Value.(float64) < 0 {
return ErrorNegativeRoot
}
res.Value = math.Sqrt(a.Value.(float64))

case DEFAULT:
Expand Down
8 changes: 8 additions & 0 deletions query/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ type mathTree struct {
Child []*mathTree
}

var (
ErrorIntOverflow = errors.New("Integer overflow")
ErrorDivisionByZero = errors.New("Division by zero")
ErrorFractionalPower = errors.New("Fractional power of negative number")
ErrorNegativeLog = errors.New("Log of negative number")
ErrorNegativeRoot = errors.New("Root of negative number")
)

// processBinary handles the binary operands like
// +, -, *, /, %, max, min, logbase
func processBinary(mNode *mathTree) error {
Expand Down
202 changes: 202 additions & 0 deletions query/math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package query

import (
"math"
"testing"

"github.com/dgraph-io/dgraph/types"
Expand Down Expand Up @@ -227,6 +228,154 @@ func TestProcessBinary(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, tc.out, tc.in.Const)
}

errorTests := []struct {
name string
in *mathTree
err error
}{
{in: &mathTree{
Fn: "+",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(9223372036854775800)}},
{Const: types.Val{Tid: types.IntID, Value: int64(10)}},
}},
err: ErrorIntOverflow,
name: "Addition integer overflow",
},
{in: &mathTree{
Fn: "+",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(-10)}},
{Const: types.Val{Tid: types.IntID, Value: int64(-9223372036854775800)}},
}},
err: ErrorIntOverflow,
name: "Addition integer underflow",
},
{in: &mathTree{
Fn: "-",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(9223372036854775800)}},
{Const: types.Val{Tid: types.IntID, Value: int64(-10)}},
}},
err: ErrorIntOverflow,
name: "Subtraction integer overflow",
},
{in: &mathTree{
Fn: "-",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(-10)}},
{Const: types.Val{Tid: types.IntID, Value: int64(9223372036854775800)}},
}},
err: ErrorIntOverflow,
name: "Subtraction integer underflow",
},
{in: &mathTree{
Fn: "*",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(9223372036854775)}},
{Const: types.Val{Tid: types.IntID, Value: int64(10000)}},
}},
err: ErrorIntOverflow,
name: "Multiplication integer overflow",
},
{in: &mathTree{
Fn: "*",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(-10000)}},
{Const: types.Val{Tid: types.IntID, Value: int64(9223372036854775)}},
}},
err: ErrorIntOverflow,
name: "Multiplication integer underflow",
},
{in: &mathTree{
Fn: "/",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(23)}},
{Const: types.Val{Tid: types.IntID, Value: int64(0)}},
}},
err: ErrorDivisionByZero,
name: "Division int zero",
},
{in: &mathTree{
Fn: "/",
Child: []*mathTree{
{Const: types.Val{Tid: types.FloatID, Value: float64(23)}},
{Const: types.Val{Tid: types.FloatID, Value: float64(0)}},
}},
err: ErrorDivisionByZero,
name: "Division float zero",
},
{in: &mathTree{
Fn: "%",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(23)}},
{Const: types.Val{Tid: types.IntID, Value: int64(0)}},
}},
err: ErrorDivisionByZero,
name: "Modulo int zero",
},
{in: &mathTree{
Fn: "%",
Child: []*mathTree{
{Const: types.Val{Tid: types.FloatID, Value: float64(23)}},
{Const: types.Val{Tid: types.FloatID, Value: float64(0)}},
}},
err: ErrorDivisionByZero,
name: "Modulo float zero",
},
{in: &mathTree{
Fn: "pow",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(-2)}},
{Const: types.Val{Tid: types.FloatID, Value: float64(1.7)}},
}},
err: ErrorFractionalPower,
name: "Fractional negative power",
},
{in: &mathTree{
Fn: "logbase",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(-2)}},
{Const: types.Val{Tid: types.IntID, Value: int64(2)}},
}},
err: ErrorNegativeLog,
name: "Log negative integer numerator",
},
{in: &mathTree{
Fn: "logbase",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(2)}},
{Const: types.Val{Tid: types.IntID, Value: int64(-2)}},
}},
err: ErrorNegativeLog,
name: "Log negative integer denominator",
},
{in: &mathTree{
Fn: "logbase",
Child: []*mathTree{
{Const: types.Val{Tid: types.FloatID, Value: float64(-2)}},
{Const: types.Val{Tid: types.FloatID, Value: float64(2)}},
}},
err: ErrorNegativeLog,
name: "Log negative float numerator",
},
{in: &mathTree{
Fn: "logbase",
Child: []*mathTree{
{Const: types.Val{Tid: types.FloatID, Value: float64(2)}},
{Const: types.Val{Tid: types.FloatID, Value: float64(-2)}},
}},
err: ErrorNegativeLog,
name: "Log negative float denominator",
},
}

for _, tc := range errorTests {
t.Logf("Test %s", tc.name)
err := processBinary(tc.in)
require.EqualError(t, err, tc.err.Error())
}
}

func TestProcessUnary(t *testing.T) {
Expand Down Expand Up @@ -283,6 +432,59 @@ func TestProcessUnary(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, tc.out, tc.in.Const)
}

errorTests := []struct {
name string
in *mathTree
err error
}{
{in: &mathTree{
Fn: "ln",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(-2)}},
}},
err: ErrorNegativeLog,
name: "Negative int ln",
},
{in: &mathTree{
Fn: "ln",
Child: []*mathTree{
{Const: types.Val{Tid: types.FloatID, Value: float64(-2)}},
}},
err: ErrorNegativeLog,
name: "Negative float ln",
},
{in: &mathTree{
Fn: "u-",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(math.MinInt64)}},
}},
err: ErrorIntOverflow,
name: "Negation int overflow",
},
{in: &mathTree{
Fn: "sqrt",
Child: []*mathTree{
{Const: types.Val{Tid: types.IntID, Value: int64(-2)}},
}},
err: ErrorNegativeRoot,
name: "Negative int sqrt",
},
{in: &mathTree{
Fn: "sqrt",
Child: []*mathTree{
{Const: types.Val{Tid: types.FloatID, Value: float64(-2)}},
}},
err: ErrorNegativeRoot,
name: "Negative float sqrt",
},
}

for _, tc := range errorTests {
t.Logf("Test %s", tc.name)
err := processUnary(tc.in)
require.EqualError(t, err, tc.err.Error())
}
}

func TestProcessBinaryBoolean(t *testing.T) {
Expand Down

0 comments on commit 0f59807

Please sign in to comment.