diff --git a/contract/p/gnoswap/gnsmath/_helper_test.gno b/contract/p/gnoswap/gnsmath/_helper_test.gno deleted file mode 100644 index fd43cfb30..000000000 --- a/contract/p/gnoswap/gnsmath/_helper_test.gno +++ /dev/null @@ -1,39 +0,0 @@ -package gnsmath - -import ( - "testing" -) - -func shouldEQ(t *testing.T, got, expected interface{}) { - if got != expected { - t.Errorf("got %v, expected %v", got, expected) - } -} - -func shouldNEQ(t *testing.T, got, expected interface{}) { - if got == expected { - t.Errorf("got %v, didn't expected %v", got, expected) - } -} - -func shouldPanic(t *testing.T, f func()) { - defer func() { - if r := recover(); r == nil { - t.Errorf("expected panic") - } - }() - f() -} - -func shouldPanicWithMsg(t *testing.T, f func(), msg string) { - defer func() { - if r := recover(); r == nil { - t.Errorf("The code did not panic") - } else { - if r != msg { - t.Errorf("excepted panic(%v), got(%v)", msg, r) - } - } - }() - f() -} diff --git a/contract/p/gnoswap/gnsmath/bit_math.gno b/contract/p/gnoswap/gnsmath/bit_math.gno index b679cd03a..14e9bcdea 100644 --- a/contract/p/gnoswap/gnsmath/bit_math.gno +++ b/contract/p/gnoswap/gnsmath/bit_math.gno @@ -4,29 +4,44 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) +var ( + msbShifts = []bitShift{ + {u256.MustFromDecimal(Q128), 128}, // 2^128 + {u256.MustFromDecimal(Q64), 64}, // 2^64 + {u256.NewUint(0x100000000), 32}, // 2^32 + {u256.NewUint(0x10000), 16}, // 2^16 + {u256.NewUint(0x100), 8}, // 2^8 + {u256.NewUint(0x10), 4}, // 2^4 + {u256.NewUint(0x4), 2}, // 2^2 + {u256.NewUint(0x2), 1}, // 2^1 + } + + lsbShifts = []bitShift{ + {u256.MustFromDecimal(MAX_UINT128), 128}, + {u256.MustFromDecimal(MAX_UINT64), 64}, + {u256.MustFromDecimal(MAX_UINT32), 32}, + {u256.MustFromDecimal(MAX_UINT16), 16}, + {u256.MustFromDecimal(MAX_UINT8), 8}, + {u256.NewUint(0xf), 4}, + {u256.NewUint(0x3), 2}, + {u256.NewUint(0x1), 1}, + } +) + type bitShift struct { bitPattern *u256.Uint shift uint } +// BitMathMostSignificantBit finds the highest set bit (0-based) in x. +// If x == 0, it panics. func BitMathMostSignificantBit(x *u256.Uint) uint8 { if x.IsZero() { - panic("BitMathMostSignificantBit: x should not be zero") - } - - shifts := []bitShift{ - {u256.MustFromDecimal(Q128), 128}, // 2^128 - {u256.MustFromDecimal(Q64), 64}, // 2^64 - {u256.NewUint(0x100000000), 32}, - {u256.NewUint(0x10000), 16}, - {u256.NewUint(0x100), 8}, - {u256.NewUint(0x10), 4}, - {u256.NewUint(0x4), 2}, - {u256.NewUint(0x2), 1}, + panic(errMSBZeroInput) } r := uint8(0) - for _, s := range shifts { + for _, s := range msbShifts { if x.Gte(s.bitPattern) { x = new(u256.Uint).Rsh(x, s.shift) r += uint8(s.shift) @@ -36,24 +51,15 @@ func BitMathMostSignificantBit(x *u256.Uint) uint8 { return r } +// BitMathLeastSignificantBit finds the lowest set bit (0-based) in x. +// If x == 0, it panics. func BitMathLeastSignificantBit(x *u256.Uint) uint8 { if x.IsZero() { - panic("BitMathLeastSignificantBit: x should not be zero") - } - - shifts := []bitShift{ - {u256.MustFromDecimal(MAX_UINT128), 128}, - {u256.MustFromDecimal(MAX_UINT64), 64}, - {u256.MustFromDecimal(MAX_UINT32), 32}, - {u256.MustFromDecimal(MAX_UINT16), 16}, - {u256.MustFromDecimal(MAX_UINT8), 8}, - {u256.NewUint(0xf), 4}, - {u256.NewUint(0x3), 2}, - {u256.NewUint(0x1), 1}, + panic(errLSBZeroInput) } r := uint8(255) - for _, s := range shifts { + for _, s := range lsbShifts { if new(u256.Uint).And(x, s.bitPattern).Gt(u256.Zero()) { r -= uint8(s.shift) } else { diff --git a/contract/p/gnoswap/gnsmath/bit_math_test.gno b/contract/p/gnoswap/gnsmath/bit_math_test.gno index 6f1fb71dc..ada8b9c3d 100644 --- a/contract/p/gnoswap/gnsmath/bit_math_test.gno +++ b/contract/p/gnoswap/gnsmath/bit_math_test.gno @@ -3,69 +3,65 @@ package gnsmath import ( "testing" + "gno.land/p/demo/uassert" + u256 "gno.land/p/gnoswap/uint256" ) func TestBitMathMostSignificantBit(t *testing.T) { t.Run("0", func(t *testing.T) { - shouldPanic( - t, - func() { - BitMathMostSignificantBit(u256.Zero()) - }, - ) + uassert.PanicsWithMessage(t, errMSBZeroInput.Error(), func() { + BitMathMostSignificantBit(u256.Zero()) + }) }) t.Run("1", func(t *testing.T) { - shouldEQ(t, BitMathMostSignificantBit(u256.One()), uint8(0)) + uassert.Equal(t, BitMathMostSignificantBit(u256.One()), uint8(0)) }) t.Run("2", func(t *testing.T) { - shouldEQ(t, BitMathMostSignificantBit(u256.NewUint(2)), uint8(1)) + uassert.Equal(t, BitMathMostSignificantBit(u256.NewUint(2)), uint8(1)) }) t.Run("all powers of 2", func(t *testing.T) { for i := 0; i < 256; i++ { num := u256.Zero() num.Lsh(u256.One(), uint(i)) - shouldEQ(t, BitMathMostSignificantBit(num), uint8(i)) + uassert.Equal(t, BitMathMostSignificantBit(num), uint8(i)) } }) t.Run("uint256(-1)", func(t *testing.T) { // BigNumber.from(2).pow(256).sub(1)) - shouldEQ(t, BitMathMostSignificantBit(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")), uint8(255)) + uassert.Equal(t, BitMathMostSignificantBit(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")), uint8(255)) }) } func TestBitMathLeastSignificantBit(t *testing.T) { t.Run("0", func(t *testing.T) { - shouldPanic( - t, - func() { - BitMathLeastSignificantBit(u256.Zero()) - }, - ) + uassert.PanicsWithMessage(t, errLSBZeroInput.Error(), func() { + BitMathLeastSignificantBit(u256.Zero()) + }) }) t.Run("1", func(t *testing.T) { - shouldEQ(t, BitMathLeastSignificantBit(u256.One()), uint8(0)) + uassert.Equal(t, BitMathLeastSignificantBit(u256.One()), uint8(0)) }) t.Run("2", func(t *testing.T) { - shouldEQ(t, BitMathLeastSignificantBit(u256.NewUint(2)), uint8(1)) + uassert.Equal(t, BitMathLeastSignificantBit(u256.NewUint(2)), uint8(1)) }) t.Run("all powers of 2", func(t *testing.T) { for i := 0; i < 256; i++ { num := u256.Zero() num.Lsh(u256.One(), uint(i)) - shouldEQ(t, BitMathLeastSignificantBit(num), uint8(i)) + uassert.Equal(t, BitMathLeastSignificantBit(num), uint8(i)) } }) t.Run("uint256(-1)", func(t *testing.T) { // BigNumber.from(2).pow(256).sub(1)) - shouldEQ(t, BitMathLeastSignificantBit(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")), uint8(0)) + uassert.Equal(t, BitMathLeastSignificantBit(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")), uint8(0)) }) } diff --git a/contract/p/gnoswap/gnsmath/erros.gno b/contract/p/gnoswap/gnsmath/erros.gno new file mode 100644 index 000000000..feae571eb --- /dev/null +++ b/contract/p/gnoswap/gnsmath/erros.gno @@ -0,0 +1,19 @@ +package gnsmath + +import ( + "errors" +) + +var ( + errInvalidPoolSqrtPrice = errors.New("invalid pool sqrt price calculation: product/amount != sqrtPX96 or numerator1 <= product") + errNextSqrtPriceOverflow = errors.New("nextSqrtPrice overflows uint160") + errSqrtPriceQuotientOverflow = errors.New("GetNextSqrtPriceFromAmount1RoundingDown sqrtPx96 + quotient overflow uint160") + errSqrtPriceExceedsQuotient = errors.New("sqrt price exceeds calculated quotient") + errSqrtPriceZero = errors.New("sqrtPX96 should not be zero") + errLiquidityZero = errors.New("liquidity should not be zero") + errSqrtRatioAX96NotPositive = errors.New("sqrtRatioAX96 must be greater than zero") + errAmount0DeltaOverflow = errors.New("SqrtPriceMathGetAmount0DeltaStr: overflow") + errAmount1DeltaOverflow = errors.New("SqrtPriceMathGetAmount1DeltaStr: overflow") + errMSBZeroInput = errors.New("input for MSB calculation should not be zero") + errLSBZeroInput = errors.New("input for LSB calculation should not be zero") +) diff --git a/contract/p/gnoswap/gnsmath/sqrt_price_math.gno b/contract/p/gnoswap/gnsmath/sqrt_price_math.gno index 5126bf611..ea9c5845d 100644 --- a/contract/p/gnoswap/gnsmath/sqrt_price_math.gno +++ b/contract/p/gnoswap/gnsmath/sqrt_price_math.gno @@ -5,11 +5,60 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) +const ( + Q96_RESOLUTION uint = 96 + Q160_RESOLUTION uint = 160 +) + var ( - Q96_RESOLUTION = uint(96) - Q160_RESOLUTION = uint(160) + q96 = u256.MustFromDecimal(Q96) + max160 = u256.MustFromDecimal(MAX_UINT160) ) +// getNextPriceAmount0Add calculates the next sqrt price when we are adding token0. +// Preserves the rounding-up logic. No in-place mutation of input arguments. +func getNextPriceAmount0Add( + sqrtPX96, liquidity, amount *u256.Uint, +) *u256.Uint { + numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION) + product := new(u256.Uint).Mul(amount, sqrtPX96) + + // overflow check + if new(u256.Uint).Div(product, amount).Eq(sqrtPX96) { + denominator := new(u256.Uint).Add(numerator1, product) + if denominator.Gte(numerator1) { + return u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator) + } + } + + divValue := new(u256.Uint).Div(numerator1, sqrtPX96) + addValue := new(u256.Uint).Add(divValue, amount) + + return u256.DivRoundingUp(numerator1, addValue) +} + +// getNextPriceAmount0Remove calculates the next sqrt price when we are removing token0. +// Preserves the rounding-up logic. No in-place mutation of input arguments. +func getNextPriceAmount0Remove( + sqrtPX96, liquidity, amount *u256.Uint, +) *u256.Uint { + numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION) + product := new(u256.Uint).Mul(amount, sqrtPX96) + + // Conditions must hold: product/amount == sqrtPX96 and numerator1 > product + if !new(u256.Uint).Div(product, amount).Eq(sqrtPX96) || !numerator1.Gt(product) { + panic(errInvalidPoolSqrtPrice) + } + + denominator := new(u256.Uint).Sub(numerator1, product) + nextSqrtPrice := u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator) + if nextSqrtPrice.Gt(max160) { + panic(errNextSqrtPriceOverflow) + } + + return nextSqrtPrice +} + // sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp calculates the next square root price // based on the amount of token0 added or removed from the pool. // NOTE: Always rounds up, because in the exact output case (increasing price) we need to move the price at least @@ -38,42 +87,66 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( amount *u256.Uint, add bool, ) *u256.Uint { - // we short circuit amount == 0 because the result is otherwise not guaranteed to equal the input price + // Shortcut: if no amount, return original price if amount.IsZero() { return sqrtPX96 } - numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION) - product := new(u256.Uint).Mul(amount, sqrtPX96) - if add { - if new(u256.Uint).Div(product, amount).Eq(sqrtPX96) { - denominator := new(u256.Uint).Add(numerator1, product) + return getNextPriceAmount0Add(sqrtPX96, liquidity, amount) + } - if denominator.Gte(numerator1) { - return u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator) - } - } + return getNextPriceAmount0Remove(sqrtPX96, liquidity, amount) +} - divValue := new(u256.Uint).Div(numerator1, sqrtPX96) - addValue := new(u256.Uint).Add(divValue, amount) - return u256.DivRoundingUp(numerator1, addValue) +// getNextPriceAmount1Add calculates the next sqrt price when adding token1. +// Preserves rounding-down logic for the final result. +func getNextPriceAmount1Add( + sqrtPX96, liquidity, amount *u256.Uint, +) *u256.Uint { + quotient := u256.Zero() + + if amount.Lte(max160) { + value := u256.Zero().Lsh(amount, Q96_RESOLUTION) + quotient = quotient.Div(value, liquidity) } else { - cond1 := new(u256.Uint).Div(product, amount).Eq(sqrtPX96) - cond2 := numerator1.Gt(product) + quotient = u256.MulDiv(amount, q96, liquidity) + } - if !(cond1 && cond2) { - panic("invalid pool sqrt price calculation: product/amount != sqrtPX96 or numerator1 <= product") - } + res := new(u256.Uint).Add(sqrtPX96, quotient) + if res.Gt(max160) { + panic(errSqrtPriceQuotientOverflow) + } - denominator := new(u256.Uint).Sub(numerator1, product) - nextSqrtPrice := u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator) - max160 := u256.MustFromDecimal(MAX_UINT160) - if nextSqrtPrice.Gt(max160) { - panic("nextSqrtPrice overflows uint160") - } - return nextSqrtPrice + return res +} + +// getNextPriceAmount1Remove calculates the next sqrt price when removing token1. +// Preserves rounding-down logic for the final result. +func getNextPriceAmount1Remove( + sqrtPX96, liquidity, amount *u256.Uint, +) *u256.Uint { + quotient := u256.Zero() + + if amount.Lte(max160) { + value := new(u256.Uint).Lsh(amount, Q96_RESOLUTION) + quotient = u256.DivRoundingUp(value, liquidity) + } else { + quotient = u256.MulDivRoundingUp(amount, q96, liquidity) + } + + if !sqrtPX96.Gt(quotient) { + panic(errSqrtPriceExceedsQuotient) } + + res := new(u256.Uint).Sub(sqrtPX96, quotient) + if res.Gt(max160) { + // Force mask into uint160 + mask := new(u256.Uint).Lsh(u256.One(), Q160_RESOLUTION) + mask = mask.Sub(mask, u256.One()) + res = res.And(res, mask) + } + return res } // sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown calculates the next square root price @@ -98,49 +171,13 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( // - The function uses high-precision math (MulDiv and DivRoundingUp) to handle division and prevent precision loss. // - The function validates input conditions and panics if the state is invalid. func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( - sqrtPX96 *u256.Uint, // uint160 - liquidity *u256.Uint, // uint1288 - amount *u256.Uint, // uint256 + sqrtPX96, liquidity, amount *u256.Uint, add bool, -) *u256.Uint { // uint160 - quotient := u256.Zero() - max160 := u256.MustFromDecimal(MAX_UINT160) - - // if we're adding (subtracting), rounding down requires rounding the quotient down (up) - // in both cases, avoid a mulDiv for most inputs +) *u256.Uint { if add { - if amount.Lte(u256.MustFromDecimal(MAX_UINT160)) { - value := new(u256.Uint).Lsh(amount, Q96_RESOLUTION) - quotient = new(u256.Uint).Div(value, liquidity) - } else { - quotient = u256.MulDiv(amount, u256.MustFromDecimal(Q96), liquidity) - } - - res := new(u256.Uint).Add(sqrtPX96, quotient) - if res.Gt(max160) { - panic("GetNextSqrtPriceFromAmount1RoundingDown sqrtPx96 + quotient overflow uint160") - } - return res - } else { - if amount.Lte(u256.MustFromDecimal(MAX_UINT160)) { - value := new(u256.Uint).Lsh(amount, Q96_RESOLUTION) - quotient = u256.DivRoundingUp(value, liquidity) - } else { - quotient = u256.MulDivRoundingUp(amount, u256.MustFromDecimal(Q96), liquidity) - } - - if !(sqrtPX96.Gt(quotient)) { - panic("sqrt price exceeds calculated quotient") - } - - res := new(u256.Uint).Sub(sqrtPX96, quotient) - if res.Gt(max160) { - mask := new(u256.Uint).Lsh(u256.One(), Q160_RESOLUTION) - mask = mask.Sub(mask, u256.One()) - res = res.And(res, mask) - } - return res + return getNextPriceAmount1Add(sqrtPX96, liquidity, amount) } + return getNextPriceAmount1Remove(sqrtPX96, liquidity, amount) } // sqrtPriceMathGetNextSqrtPriceFromInput calculates the next square root price @@ -160,24 +197,22 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( // Returns: // - The price after adding amountIn, depending on zeroForOne func sqrtPriceMathGetNextSqrtPriceFromInput( - sqrtPX96 *u256.Uint, - liquidity *u256.Uint, - amountIn *u256.Uint, + sqrtPX96, liquidity, amountIn *u256.Uint, zeroForOne bool, ) *u256.Uint { if sqrtPX96.IsZero() { - panic("sqrtPX96 should not be zero") + panic(errSqrtPriceZero) } if liquidity.IsZero() { - panic("liquidity should not be zero") + panic(errLiquidityZero) } if zeroForOne { return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountIn, true) - } else { - return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true) } + + return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true) } // sqrtPriceMathGetNextSqrtPriceFromOutput calculates the next square root price @@ -207,24 +242,22 @@ func sqrtPriceMathGetNextSqrtPriceFromInput( // - `sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown` for Token0 -> Token1. // - `sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp` for Token1 -> Token0. func sqrtPriceMathGetNextSqrtPriceFromOutput( - sqrtPX96 *u256.Uint, - liquidity *u256.Uint, - amountOut *u256.Uint, + sqrtPX96, liquidity, amountOut *u256.Uint, zeroForOne bool, ) *u256.Uint { if sqrtPX96.IsZero() { - panic("sqrtPX96 should not be zero") + panic(errSqrtPriceZero) } if liquidity.IsZero() { - panic("liquidity should not be zero") + panic(errLiquidityZero) } if zeroForOne { return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountOut, false) - } else { - return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false) } + + return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false) } // sqrtPriceMathGetAmount0DeltaHelper calculates the absolute difference between the amounts of token0 in two @@ -247,9 +280,7 @@ func sqrtPriceMathGetNextSqrtPriceFromOutput( // - The result is calculated using high-precision fixed-point arithmetic. // - Rounding is applied based on the roundUp parameter. func sqrtPriceMathGetAmount0DeltaHelper( - sqrtRatioAX96 *u256.Uint, - sqrtRatioBX96 *u256.Uint, - liquidity *u256.Uint, + sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint, roundUp bool, ) *u256.Uint { if sqrtRatioAX96.Gt(sqrtRatioBX96) { @@ -260,16 +291,16 @@ func sqrtPriceMathGetAmount0DeltaHelper( numerator2 := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) if !(sqrtRatioAX96.Gt(u256.Zero())) { - panic("sqrtRatioAX96 must be greater than zero") + panic(errSqrtRatioAX96NotPositive) } if roundUp { value := u256.MulDivRoundingUp(numerator1, numerator2, sqrtRatioBX96) return u256.DivRoundingUp(value, sqrtRatioAX96) - } else { - value := u256.MulDiv(numerator1, numerator2, sqrtRatioBX96) - return new(u256.Uint).Div(value, sqrtRatioAX96) } + + value := u256.MulDiv(numerator1, numerator2, sqrtRatioBX96) + return new(u256.Uint).Div(value, sqrtRatioAX96) } // sqrtPriceMathGetAmount1DeltaHelper calculates the absolute difference between the amounts of token1 in two @@ -291,9 +322,7 @@ func sqrtPriceMathGetAmount0DeltaHelper( // - Rounding is applied based on the roundUp parameter. // - The function swaps sqrtRatioAX96 and sqrtRatioBX96 if sqrtRatioAX96 > sqrtRatioBX96. func sqrtPriceMathGetAmount1DeltaHelper( - sqrtRatioAX96 *u256.Uint, - sqrtRatioBX96 *u256.Uint, - liquidity *u256.Uint, + sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint, roundUp bool, ) *u256.Uint { if sqrtRatioAX96.Gt(sqrtRatioBX96) { @@ -302,10 +331,10 @@ func sqrtPriceMathGetAmount1DeltaHelper( diff := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) if roundUp { - return u256.MulDivRoundingUp(liquidity, diff, u256.MustFromDecimal(Q96)) - } else { - return u256.MulDiv(liquidity, diff, u256.MustFromDecimal(Q96)) + return u256.MulDivRoundingUp(liquidity, diff, q96) } + + return u256.MulDiv(liquidity, diff, q96) } // SqrtPriceMathGetAmount0DeltaStr calculates the difference in the amount of token0 @@ -332,26 +361,26 @@ func sqrtPriceMathGetAmount1DeltaHelper( // - For negative liquidity, rounding is always down. // - For positive liquidity, rounding is always up. func SqrtPriceMathGetAmount0DeltaStr( - sqrtRatioAX96 *u256.Uint, - sqrtRatioBX96 *u256.Uint, + sqrtRatioAX96, sqrtRatioBX96 *u256.Uint, liquidity *i256.Int, ) string { if liquidity.IsNeg() { u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false) if u.Gt(u256.MustFromDecimal(MAX_INT256)) { // if u > (2**255 - 1), cannot cast to int256 - panic("SqrtPriceMathGetAmount0DeltaStr: overflow") + panic(errAmount0DeltaOverflow) } i := i256.FromUint256(u) return i256.Zero().Neg(i).ToString() - } else { - u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) - if u.Gt(u256.MustFromDecimal(MAX_INT256)) { - // if u > (2**255 - 1), cannot cast to int256 - panic("SqrtPriceMathGetAmount0DeltaStr: overflow") - } - return i256.FromUint256(u).ToString() } + + u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic(errAmount0DeltaOverflow) + } + + return i256.FromUint256(u).ToString() } // SqrtPriceMathGetAmount1DeltaStr calculates the difference in the amount of token1 @@ -375,24 +404,24 @@ func SqrtPriceMathGetAmount0DeltaStr( // - For negative liquidity, rounding is always down. // - For positive liquidity, rounding is always up. func SqrtPriceMathGetAmount1DeltaStr( - sqrtRatioAX96 *u256.Uint, - sqrtRatioBX96 *u256.Uint, + sqrtRatioAX96, sqrtRatioBX96 *u256.Uint, liquidity *i256.Int, ) string { if liquidity.IsNeg() { u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false) if u.Gt(u256.MustFromDecimal(MAX_INT256)) { // if u > (2**255 - 1), cannot cast to int256 - panic("SqrtPriceMathGetAmount1DeltaStr: overflow") + panic(errAmount1DeltaOverflow) } i := i256.FromUint256(u) return i256.Zero().Neg(i).ToString() - } else { - u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) - if u.Gt(u256.MustFromDecimal(MAX_INT256)) { - // if u > (2**255 - 1), cannot cast to int256 - panic("SqrtPriceMathGetAmount1DeltaStr: overflow") - } - return i256.FromUint256(u).ToString() } + + u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic(errAmount1DeltaOverflow) + } + + return i256.FromUint256(u).ToString() } diff --git a/contract/p/gnoswap/gnsmath/sqrt_price_math_test.gno b/contract/p/gnoswap/gnsmath/sqrt_price_math_test.gno index 942aa302c..114304187 100644 --- a/contract/p/gnoswap/gnsmath/sqrt_price_math_test.gno +++ b/contract/p/gnoswap/gnsmath/sqrt_price_math_test.gno @@ -90,38 +90,26 @@ func TestSqrtPriceMathGetAmount0DeltaStr(t *testing.T) { }) t.Run("panic overflow when getting amount0 with positive liquidity", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Error("Expected panic for overflow amount0") - } else { - uassert.Equal(t, "SqrtPriceMathGetAmount0DeltaStr: overflow", r) - } - }() - // Inputs to trigger panic sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value sqrtRatioBX96 := u256.MustFromDecimal("340282366920938463463374607431768211455") // very high value(2^128-1) liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) - SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + uassert.PanicsWithMessage(t, errAmount0DeltaOverflow.Error(), func() { + SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) }) t.Run("panic overflow when getting amount0 with negative liquidity", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Error("Expected panic for overflow amount0") - } else { - uassert.Equal(t, "SqrtPriceMathGetAmount0DeltaStr: overflow", r) - } - }() - // Inputs to trigger panic sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value sqrtRatioBX96 := u256.MustFromDecimal("340282366920938463463374607431768211455") // very high value(2^128-1) liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) liquidity = liquidity.Neg(liquidity) // Make liquidity negative - SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + uassert.PanicsWithMessage(t, errAmount0DeltaOverflow.Error(), func() { + SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) }) } @@ -151,38 +139,26 @@ func TestSqrtPriceMathGetAmount1DeltaStr(t *testing.T) { }) t.Run("panic overflow when getting amount1 with positive liquidity", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Error("Expected panic for overflow amount1") - } else { - uassert.Equal(t, "SqrtPriceMathGetAmount1DeltaStr: overflow", r) - } - }() - // Inputs to trigger panic sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950335") // slightly below Q96 liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) - SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + uassert.PanicsWithMessage(t, errAmount1DeltaOverflow.Error(), func() { + SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) }) t.Run("panic overflow when getting amount1 with negative liquidity", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Error("Expected panic for overflow amount1") - } else { - uassert.Equal(t, "SqrtPriceMathGetAmount1DeltaStr: overflow", r) - } - }() - // Inputs to trigger panic sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950335") // slightly below Q96 liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) liquidity = liquidity.Neg(liquidity) // Make liquidity negative - SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + uassert.PanicsWithMessage(t, errAmount1DeltaOverflow.Error(), func() { + SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) }) } diff --git a/contract/p/gnoswap/gnsmath/swap_math.gno b/contract/p/gnoswap/gnsmath/swap_math.gno index 4334bd558..6217f6dea 100644 --- a/contract/p/gnoswap/gnsmath/swap_math.gno +++ b/contract/p/gnoswap/gnsmath/swap_math.gno @@ -28,15 +28,17 @@ func SwapMathComputeSwapStepStr( amountRemaining *i256.Int, feePips uint64, ) (string, string, string, string) { - if sqrtRatioCurrentX96 == nil || sqrtRatioTargetX96 == nil || liquidity == nil || amountRemaining == nil { + if sqrtRatioCurrentX96 == nil || sqrtRatioTargetX96 == nil || + liquidity == nil || amountRemaining == nil { panic("SwapMathComputeSwapStepStr: invalid input") } + // zeroForOne determines swap direction based on the relationship of current vs. target zeroForOne := sqrtRatioCurrentX96.Gte(sqrtRatioTargetX96) // POSTIVIE == EXACT_IN => Estimated AmountOut // NEGATIVE == EXACT_OUT => Estimated AmountIn - exactIn := !(amountRemaining.IsNeg()) // amountRemaining >= 0 + exactIn := !amountRemaining.IsNeg() sqrtRatioNextX96 := u256.Zero() amountIn := u256.Zero() @@ -44,74 +46,48 @@ func SwapMathComputeSwapStepStr( feeAmount := u256.Zero() if exactIn { - amountRemainingLessFee := u256.MulDiv(amountRemaining.Abs(), u256.NewUint(1000000-feePips), u256.NewUint(1000000)) - if zeroForOne { - amountIn = sqrtPriceMathGetAmount0DeltaHelper( - sqrtRatioTargetX96.Clone(), - sqrtRatioCurrentX96.Clone(), - liquidity.Clone(), - true) - } else { - amountIn = sqrtPriceMathGetAmount1DeltaHelper( - sqrtRatioCurrentX96.Clone(), - sqrtRatioTargetX96.Clone(), - liquidity.Clone(), - true) - } - - if amountRemainingLessFee.Gte(amountIn) { - sqrtRatioNextX96 = sqrtRatioTargetX96.Clone() - } else { - sqrtRatioNextX96 = sqrtPriceMathGetNextSqrtPriceFromInput( - sqrtRatioCurrentX96.Clone(), - liquidity.Clone(), - amountRemainingLessFee.Clone(), - zeroForOne, - ) - } + // Handle EXACT_IN scenario as a separate function + sqrtRatioNextX96, amountIn = handleExactIn( + zeroForOne, + sqrtRatioCurrentX96.Clone(), + sqrtRatioTargetX96.Clone(), + liquidity.Clone(), + amountRemaining.Abs(), // use absolute value here + feePips, + ) } else { - if zeroForOne { - amountOut = sqrtPriceMathGetAmount1DeltaHelper( - sqrtRatioTargetX96.Clone(), - sqrtRatioCurrentX96.Clone(), - liquidity.Clone(), - false) - } else { - amountOut = sqrtPriceMathGetAmount0DeltaHelper( - sqrtRatioCurrentX96.Clone(), - sqrtRatioTargetX96.Clone(), - liquidity.Clone(), - false) - } - - if amountRemaining.Abs().Gte(amountOut) { - sqrtRatioNextX96 = sqrtRatioTargetX96.Clone() - } else { - sqrtRatioNextX96 = sqrtPriceMathGetNextSqrtPriceFromOutput( - sqrtRatioCurrentX96.Clone(), - liquidity.Clone(), - amountRemaining.Abs(), - zeroForOne, - ) - } + // Handle EXACT_OUT scenario as a separate function + sqrtRatioNextX96, amountOut = handleExactOut( + zeroForOne, + sqrtRatioCurrentX96.Clone(), + sqrtRatioTargetX96.Clone(), + liquidity.Clone(), + amountRemaining.Abs(), + ) } + // isMax checks if we've hit the boundary price (target) isMax := sqrtRatioTargetX96.Eq(sqrtRatioNextX96) + // Calculate final amountIn, amountOut if needed if zeroForOne { + // If isMax && exactIn, we already have the correct amountIn if !(isMax && exactIn) { amountIn = sqrtPriceMathGetAmount0DeltaHelper( sqrtRatioNextX96.Clone(), sqrtRatioCurrentX96.Clone(), liquidity.Clone(), - true) + true, + ) } + // If isMax && !exactIn, we already have the correct amountOut if !(isMax && !exactIn) { amountOut = sqrtPriceMathGetAmount1DeltaHelper( sqrtRatioNextX96.Clone(), sqrtRatioCurrentX96.Clone(), liquidity.Clone(), - false) + false, + ) } } else { if !(isMax && exactIn) { @@ -119,25 +95,120 @@ func SwapMathComputeSwapStepStr( sqrtRatioCurrentX96.Clone(), sqrtRatioNextX96.Clone(), liquidity.Clone(), - true) + true, + ) } if !(isMax && !exactIn) { amountOut = sqrtPriceMathGetAmount0DeltaHelper( sqrtRatioCurrentX96.Clone(), sqrtRatioNextX96.Clone(), liquidity.Clone(), - false) + false, + ) } } + // If we're in EXACT_OUT mode but overcalculated 'amountOut' if !exactIn && amountOut.Gt(amountRemaining.Abs()) { amountOut = amountRemaining.Abs() } - if exactIn && !(sqrtRatioNextX96.Eq(sqrtRatioTargetX96)) { + + // Fee logic + // If exactIn and we haven't hit the target, the difference is the fee + // Else, compute fee from feePips + if exactIn && !sqrtRatioNextX96.Eq(sqrtRatioTargetX96) { feeAmount = new(u256.Uint).Sub(amountRemaining.Abs(), amountIn) } else { - feeAmount = u256.MulDivRoundingUp(amountIn, u256.NewUint(feePips), new(u256.Uint).Sub(u256.NewUint(1000000), u256.NewUint(feePips))) + feeAmount = u256.MulDivRoundingUp( + amountIn, + u256.NewUint(feePips), + new(u256.Uint).Sub(u256.NewUint(1000000), u256.NewUint(feePips)), + ) } return sqrtRatioNextX96.ToString(), amountIn.ToString(), amountOut.ToString(), feeAmount.ToString() } + +// handleExactIn processes the EXACT_IN scenario and returns the next sqrt ratio and a provisional amountIn. +func handleExactIn( + zeroForOne bool, + sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity *u256.Uint, + amountRemainingAbs *u256.Uint, + feePips uint64, +) (*u256.Uint, *u256.Uint) { + // Compute amountRemainingLessFee + amountRemainingLessFee := u256.MulDiv( + amountRemainingAbs, + u256.NewUint(1000000-feePips), + u256.NewUint(1000000), + ) + + // If we fully move to the target, how much would 'amountIn' be? + var amountIn *u256.Uint + if zeroForOne { + amountIn = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioTargetX96, // from lower to current + sqrtRatioCurrentX96, + liquidity, + true, + ) + } else { + amountIn = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioCurrentX96, + sqrtRatioTargetX96, + liquidity, + true, + ) + } + + if amountRemainingLessFee.Gte(amountIn) { + return sqrtRatioTargetX96, amountIn + } + + // We don't reach target price; use partial move + nextSqrt := sqrtPriceMathGetNextSqrtPriceFromInput( + sqrtRatioCurrentX96, + liquidity, + amountRemainingLessFee, + zeroForOne, + ) + return nextSqrt, amountIn +} + +// handleExactOut processes the EXACT_OUT scenario and returns the next sqrt ratio and a provisional amountOut. +// This function checks how much output we get if we fully move to the target, +// then decides if we can reach that target or need a partial move. +func handleExactOut( + zeroForOne bool, + sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity *u256.Uint, + amountRemainingAbs *u256.Uint, +) (*u256.Uint, *u256.Uint) { + var amountOut *u256.Uint + if zeroForOne { + amountOut = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioTargetX96, + sqrtRatioCurrentX96, + liquidity, + false, + ) + } else { + amountOut = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioCurrentX96, + sqrtRatioTargetX96, + liquidity, + false, + ) + } + + if amountRemainingAbs.Gte(amountOut) { + return sqrtRatioTargetX96, amountOut + } + + nextSqrt := sqrtPriceMathGetNextSqrtPriceFromOutput( + sqrtRatioCurrentX96, + liquidity, + amountRemainingAbs, + zeroForOne, + ) + return nextSqrt, amountOut +}