diff --git a/README.md b/README.md deleted file mode 100644 index f55bb4278..000000000 --- a/README.md +++ /dev/null @@ -1 +0,0 @@ -# Gnoswap Contract diff --git a/contract/p/gnoswap/consts/consts.gno b/contract/p/gnoswap/consts/consts.gno new file mode 100644 index 000000000..a4426e547 --- /dev/null +++ b/contract/p/gnoswap/consts/consts.gno @@ -0,0 +1,137 @@ +package consts + +import ( + "std" +) + +// GNOSWAP SERVICE +const ( + ADMIN std.Address = "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d" + DEV_OPS std.Address = "g1mjvd83nnjee3z2g7683er55me9f09688pd4mj9" + TOKEN_REGISTER std.Address = "g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5" + + TOKEN_REGISTER_NAMESPACE string = "gno.land/r/g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5" + + BLOCK_GENERATION_INTERVAL int64 = 2 // seconds +) + +// WRAP & UNWRAP +const ( + GNOT string = "gnot" + UGNOT string = "ugnot" + WRAPPED_WUGNOT string = "gno.land/r/demo/wugnot" + + // defined in https://github.com/gnolang/gno/blob/81a88a2976ba9f2f9127ebbe7fb7d1e1f7fa4bd4/examples/gno.land/r/demo/wugnot/wugnot.gno#L19 + UGNOT_MIN_DEPOSIT_TO_WRAP uint64 = 1000 +) + +// CONTRACT PATH & ADDRESS +const ( + POOL_PATH string = "gno.land/r/gnoswap/v1/pool" + POOL_ADDR std.Address = "g148tjamj80yyrm309z7rk690an22thd2l3z8ank" + + POSITION_PATH string = "gno.land/r/gnoswap/v1/position" + POSITION_ADDR std.Address = "g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5" + + ROUTER_PATH string = "gno.land/r/gnoswap/v1/router" + ROUTER_ADDR std.Address = "g1lm2l7tf49h3mykesct7rhfml30yx8dw5xrval7" + + STAKER_PATH string = "gno.land/r/gnoswap/v1/staker" + STAKER_ADDR std.Address = "g1cceshmzzlmrh7rr3z30j2t5mrvsq9yccysw9nu" + + GNS_PATH string = "gno.land/r/gnoswap/v1/gns" + GNS_ADDR std.Address = "g1jgqwaa2le3yr63d533fj785qkjspumzv22ys5m" + + GNFT_PATH string = "gno.land/r/gnoswap/v1/gnft" + GNFT_ADDR std.Address = "g1wxv2rdfn53qc84nt3nn646f9yh3nly8lm7j89t" + + WUGNOT_PATH string = "gno.land/r/demo/wugnot" + WUGNOT_ADDR std.Address = "g1pf6dv9fjk3rn0m4jjcne306ga4he3mzmupfjl6" + + EMISSION_PATH string = "gno.land/r/gnoswap/v1/emission" + EMISSION_ADDR std.Address = "g10xg6559w9e93zfttlhvdmaaa0er3zewcr7nh20" + + PROTOCOL_FEE_PATH string = "gno.land/r/gnoswap/v1/protocol_fee" + PROTOCOL_FEE_ADDR std.Address = "g1f7wpek7q67tkns27sw495u5yuu3a5wwjxw5l6l" + + COMMUNITY_POOL_PATH string = "gno.land/r/gnoswap/v1/community_pool" + COMMUNITY_POOL_ADDR std.Address = "g100fnnlz5eh87p5hvwt8pf279lxaelm8k8md049" + + GOV_XGNS_PATH string = "gno.land/r/gnoswap/v1/gov/xgns" + GOV_XGNS_ADDR std.Address = "g1wwh55uwzlz2zzr2qcvvxf83qhcvmx2t8779l9r" + + GOV_STAKER_PATH string = "gno.land/r/gnoswap/v1/gov/staker" + GOV_STAKER_ADDR std.Address = "g17e3ykyqk9jmqe2y9wxe9zhep3p7cw56davjqwa" + + GOV_GOVERNANCE_PATH string = "gno.land/r/gnoswap/v1/gov/governance" + GOV_GOVERNANCE_ADDR std.Address = "g17s8w2ve7k85fwfnrk59lmlhthkjdted8whvqxd" + + COMMON_PATH string = "gno.land/r/gnoswap/v1/common" + COMMON_ADDR std.Address = "g14ytarn5u7h3xywygt8hzhs3m23frljz72ta9xk" + + LAUNCHPAD_PATH string = "gno.land/r/gnoswap/v1/launchpad" + LAUNCHPAD_ADDR std.Address = "g122mau2lp2rc0scs8d27pkkuys4w54mdy2tuer3" + + INIT_REGISTER_PATH string = "gno.land/r/g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5/v2/register_gnodev" +) + +// NUMBER +const ( + // calculated by https://mathiasbynens.be/demo/integer-range + MAX_UINT8 string = "255" + UINT8_MAX uint8 = 255 + + MAX_UINT16 string = "65535" + UINT16_MAX uint16 = 65535 + + MAX_UINT32 string = "4294967295" + UINT32_MAX uint32 = 4294967295 + + MAX_UINT64 string = "18446744073709551615" + UINT64_MAX uint64 = 18446744073709551615 + + MAX_UINT128 string = "340282366920938463463374607431768211455" + MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" + MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" + + MAX_INT128 string = "170141183460469231731687303715884105727" + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" + + // Tick Related + MIN_TICK int32 = -887272 + MAX_TICK int32 = 887272 + + MIN_SQRT_RATIO string = "4295128739" // same as TickMathGetSqrtRatioAtTick(MIN_TICK) + MAX_SQRT_RATIO string = "1461446703485210103287273052203988822378723970342" // same as TickMathGetSqrtRatioAtTick(MAX_TICK) + + MIN_PRICE string = "4295128740" // MIN_SQRT_RATIO + 1 + MAX_PRICE string = "1461446703485210103287273052203988822378723970341" // MAX_SQRT_RATIO - 1 + + // ETC + Q64 string = "18446744073709551616" // 2 ** 64 + Q96 string = "79228162514264337593543950336" // 2 ** 96 + Q128 string = "340282366920938463463374607431768211456" // 2 ** 128 + + Q96_RESOLUTION uint = 96 + Q128_RESOLUTION uint = 128 + Q160_RESOLUTION uint = 160 +) + +// TIMESTAMP & DAY +const ( + SECOND_IN_MILLISECOND = 1000 + + // in seconds + TIMESTAMP_MINUTE = 60 + TIMESTAMP_HOUR = 3600 + TIMESTAMP_DAY = 86400 + TIMESTAMP_YEAR = 31536000 + + DAY_PER_YEAR = 365 +) + +// ETCs +const ( + // REF: https://github.com/gnolang/gno/pull/2401#discussion_r1648064219 + ZERO_ADDRESS std.Address = "g100000000000000000000000000000000dnmcnx" +) diff --git a/contract/p/gnoswap/consts/gno.mod b/contract/p/gnoswap/consts/gno.mod new file mode 100644 index 000000000..9d7b1f9bb --- /dev/null +++ b/contract/p/gnoswap/consts/gno.mod @@ -0,0 +1 @@ +module gno.land/p/gnoswap/consts diff --git a/contract/p/gnoswap/gnsmath/_helper_test.gno b/contract/p/gnoswap/gnsmath/_helper_test.gno new file mode 100644 index 000000000..fd43cfb30 --- /dev/null +++ b/contract/p/gnoswap/gnsmath/_helper_test.gno @@ -0,0 +1,39 @@ +package gnsmath + +import ( + "testing" +) + +func shouldEQ(t *testing.T, got, expected interface{}) { + if got != expected { + t.Errorf("got %v, expected %v", got, expected) + } +} + +func shouldNEQ(t *testing.T, got, expected interface{}) { + if got == expected { + t.Errorf("got %v, didn't expected %v", got, expected) + } +} + +func shouldPanic(t *testing.T, f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic") + } + }() + f() +} + +func shouldPanicWithMsg(t *testing.T, f func(), msg string) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } else { + if r != msg { + t.Errorf("excepted panic(%v), got(%v)", msg, r) + } + } + }() + f() +} diff --git a/contract/p/gnoswap/gnsmath/bit_math.gno b/contract/p/gnoswap/gnsmath/bit_math.gno new file mode 100644 index 000000000..b679cd03a --- /dev/null +++ b/contract/p/gnoswap/gnsmath/bit_math.gno @@ -0,0 +1,65 @@ +package gnsmath + +import ( + u256 "gno.land/p/gnoswap/uint256" +) + +type bitShift struct { + bitPattern *u256.Uint + shift uint +} + +func BitMathMostSignificantBit(x *u256.Uint) uint8 { + if x.IsZero() { + panic("BitMathMostSignificantBit: x should not be zero") + } + + shifts := []bitShift{ + {u256.MustFromDecimal(Q128), 128}, // 2^128 + {u256.MustFromDecimal(Q64), 64}, // 2^64 + {u256.NewUint(0x100000000), 32}, + {u256.NewUint(0x10000), 16}, + {u256.NewUint(0x100), 8}, + {u256.NewUint(0x10), 4}, + {u256.NewUint(0x4), 2}, + {u256.NewUint(0x2), 1}, + } + + r := uint8(0) + for _, s := range shifts { + if x.Gte(s.bitPattern) { + x = new(u256.Uint).Rsh(x, s.shift) + r += uint8(s.shift) + } + } + + return r +} + +func BitMathLeastSignificantBit(x *u256.Uint) uint8 { + if x.IsZero() { + panic("BitMathLeastSignificantBit: x should not be zero") + } + + shifts := []bitShift{ + {u256.MustFromDecimal(MAX_UINT128), 128}, + {u256.MustFromDecimal(MAX_UINT64), 64}, + {u256.MustFromDecimal(MAX_UINT32), 32}, + {u256.MustFromDecimal(MAX_UINT16), 16}, + {u256.MustFromDecimal(MAX_UINT8), 8}, + {u256.NewUint(0xf), 4}, + {u256.NewUint(0x3), 2}, + {u256.NewUint(0x1), 1}, + } + + r := uint8(255) + for _, s := range shifts { + if new(u256.Uint).And(x, s.bitPattern).Gt(u256.Zero()) { + r -= uint8(s.shift) + } else { + x = new(u256.Uint).Rsh(x, s.shift) + } + } + + return r +} diff --git a/contract/p/gnoswap/gnsmath/bit_math_test.gno b/contract/p/gnoswap/gnsmath/bit_math_test.gno new file mode 100644 index 000000000..6f1fb71dc --- /dev/null +++ b/contract/p/gnoswap/gnsmath/bit_math_test.gno @@ -0,0 +1,71 @@ +package gnsmath + +import ( + "testing" + + u256 "gno.land/p/gnoswap/uint256" +) + +func TestBitMathMostSignificantBit(t *testing.T) { + t.Run("0", func(t *testing.T) { + shouldPanic( + t, + func() { + BitMathMostSignificantBit(u256.Zero()) + }, + ) + }) + + t.Run("1", func(t *testing.T) { + shouldEQ(t, BitMathMostSignificantBit(u256.One()), uint8(0)) + }) + + t.Run("2", func(t *testing.T) { + shouldEQ(t, BitMathMostSignificantBit(u256.NewUint(2)), uint8(1)) + }) + + t.Run("all powers of 2", func(t *testing.T) { + for i := 0; i < 256; i++ { + num := u256.Zero() + num.Lsh(u256.One(), uint(i)) + shouldEQ(t, BitMathMostSignificantBit(num), uint8(i)) + } + }) + + t.Run("uint256(-1)", func(t *testing.T) { + // BigNumber.from(2).pow(256).sub(1)) + shouldEQ(t, BitMathMostSignificantBit(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")), uint8(255)) + }) +} + +func TestBitMathLeastSignificantBit(t *testing.T) { + t.Run("0", func(t *testing.T) { + shouldPanic( + t, + func() { + BitMathLeastSignificantBit(u256.Zero()) + }, + ) + }) + + t.Run("1", func(t *testing.T) { + shouldEQ(t, BitMathLeastSignificantBit(u256.One()), uint8(0)) + }) + + t.Run("2", func(t *testing.T) { + shouldEQ(t, BitMathLeastSignificantBit(u256.NewUint(2)), uint8(1)) + }) + + t.Run("all powers of 2", func(t *testing.T) { + for i := 0; i < 256; i++ { + num := u256.Zero() + num.Lsh(u256.One(), uint(i)) + shouldEQ(t, BitMathLeastSignificantBit(num), uint8(i)) + } + }) + + t.Run("uint256(-1)", func(t *testing.T) { + // BigNumber.from(2).pow(256).sub(1)) + shouldEQ(t, BitMathLeastSignificantBit(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")), uint8(0)) + }) +} diff --git a/contract/p/gnoswap/gnsmath/consts.gno b/contract/p/gnoswap/gnsmath/consts.gno new file mode 100644 index 000000000..11609bed2 --- /dev/null +++ b/contract/p/gnoswap/gnsmath/consts.gno @@ -0,0 +1,16 @@ +package gnsmath + +const ( + MAX_UINT8 string = "255" + MAX_UINT16 string = "65535" + MAX_UINT32 string = "4294967295" + MAX_UINT64 string = "18446744073709551615" + MAX_UINT128 string = "340282366920938463463374607431768211455" + MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" + MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819967" + + Q64 string = "18446744073709551616" // 2 ** 64 + Q96 string = "79228162514264337593543950336" // 2 ** 96 + Q128 string = "340282366920938463463374607431768211456" // 2 ** 128 +) diff --git a/contract/p/gnoswap/gnsmath/gno.mod b/contract/p/gnoswap/gnsmath/gno.mod new file mode 100644 index 000000000..1949733a3 --- /dev/null +++ b/contract/p/gnoswap/gnsmath/gno.mod @@ -0,0 +1 @@ +module gno.land/p/gnoswap/gnsmath diff --git a/contract/p/gnoswap/gnsmath/sqrt_price_math.gno b/contract/p/gnoswap/gnsmath/sqrt_price_math.gno new file mode 100644 index 000000000..5126bf611 --- /dev/null +++ b/contract/p/gnoswap/gnsmath/sqrt_price_math.gno @@ -0,0 +1,398 @@ +package gnsmath + +import ( + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +var ( + Q96_RESOLUTION = uint(96) + Q160_RESOLUTION = uint(160) +) + +// sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp calculates the next square root price +// based on the amount of token0 added or removed from the pool. +// NOTE: Always rounds up, because in the exact output case (increasing price) we need to move the price at least +// far enough to get the desired output amount, and in the exact input case (decreasing price) we need to move the +// price less in order to not send too much output. +// The most precise formula for this is liquidity * sqrtPX96 / (liquidity +- amount * sqrtPX96), +// if this is impossible because of overflow, we calculate liquidity / (liquidity / sqrtPX96 +- amount). +// +// Parameters: +// - sqrtPX96: The current square root price as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - amount: The amount of token0 to be added or removed from the pool (uint256). +// - add: A boolean indicating whether the amount of token0 is being added (true) or removed (false). +// +// Returns: +// - The price after adding or removing amount, depending on add +// +// Notes: +// - When `add` is true, the function calculates the new square root price after adding `amount` of token0. +// - When `add` is false, the function calculates the new square root price after removing `amount` of token0. +// - The function uses high-precision math (MulDivRoundingUp, DivRoundingUp) to handle division rounding issues. +// - The function validates input conditions and panics if the state is invalid. +func sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( + sqrtPX96 *u256.Uint, + liquidity *u256.Uint, + amount *u256.Uint, + add bool, +) *u256.Uint { + // we short circuit amount == 0 because the result is otherwise not guaranteed to equal the input price + if amount.IsZero() { + return sqrtPX96 + } + + numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION) + product := new(u256.Uint).Mul(amount, sqrtPX96) + + if add { + if new(u256.Uint).Div(product, amount).Eq(sqrtPX96) { + denominator := new(u256.Uint).Add(numerator1, product) + + if denominator.Gte(numerator1) { + return u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator) + } + } + + divValue := new(u256.Uint).Div(numerator1, sqrtPX96) + addValue := new(u256.Uint).Add(divValue, amount) + return u256.DivRoundingUp(numerator1, addValue) + } else { + cond1 := new(u256.Uint).Div(product, amount).Eq(sqrtPX96) + cond2 := numerator1.Gt(product) + + if !(cond1 && cond2) { + panic("invalid pool sqrt price calculation: product/amount != sqrtPX96 or numerator1 <= product") + } + + denominator := new(u256.Uint).Sub(numerator1, product) + nextSqrtPrice := u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator) + max160 := u256.MustFromDecimal(MAX_UINT160) + if nextSqrtPrice.Gt(max160) { + panic("nextSqrtPrice overflows uint160") + } + return nextSqrtPrice + } +} + +// sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown calculates the next square root price +// based on the amount of token1 added or removed from the pool, with rounding down. +// NOTE: Always rounds down, because in the exact output case (decreasing price) we need to move the price at least +// far enough to get the desired output amount, and in the exact input case (increasing price) we need to move the +// price less in order to not send too much output. +// The formula we compute is within <1 wei of the lossless version: sqrtPX96 +- amount / liquidity +// +// Parameters: +// - sqrtPX96: The current square root price as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - amount: The amount of token1 to be added or removed from the pool (uint256). +// - add: A boolean indicating whether the amount of token1 is being added (true) or removed (false). +// +// Returns: +// - The next square root price as a Q96 fixed-point number (uint160). +// +// Notes: +// - When `add` is true, the function calculates the new square root price after adding `amount` of token1. +// - When `add` is false, the function calculates the new square root price after removing `amount` of token1. +// - The function uses high-precision math (MulDiv and DivRoundingUp) to handle division and prevent precision loss. +// - The function validates input conditions and panics if the state is invalid. +func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( + sqrtPX96 *u256.Uint, // uint160 + liquidity *u256.Uint, // uint1288 + amount *u256.Uint, // uint256 + add bool, +) *u256.Uint { // uint160 + quotient := u256.Zero() + max160 := u256.MustFromDecimal(MAX_UINT160) + + // if we're adding (subtracting), rounding down requires rounding the quotient down (up) + // in both cases, avoid a mulDiv for most inputs + if add { + if amount.Lte(u256.MustFromDecimal(MAX_UINT160)) { + value := new(u256.Uint).Lsh(amount, Q96_RESOLUTION) + quotient = new(u256.Uint).Div(value, liquidity) + } else { + quotient = u256.MulDiv(amount, u256.MustFromDecimal(Q96), liquidity) + } + + res := new(u256.Uint).Add(sqrtPX96, quotient) + if res.Gt(max160) { + panic("GetNextSqrtPriceFromAmount1RoundingDown sqrtPx96 + quotient overflow uint160") + } + return res + } else { + if amount.Lte(u256.MustFromDecimal(MAX_UINT160)) { + value := new(u256.Uint).Lsh(amount, Q96_RESOLUTION) + quotient = u256.DivRoundingUp(value, liquidity) + } else { + quotient = u256.MulDivRoundingUp(amount, u256.MustFromDecimal(Q96), liquidity) + } + + if !(sqrtPX96.Gt(quotient)) { + panic("sqrt price exceeds calculated quotient") + } + + res := new(u256.Uint).Sub(sqrtPX96, quotient) + if res.Gt(max160) { + mask := new(u256.Uint).Lsh(u256.One(), Q160_RESOLUTION) + mask = mask.Sub(mask, u256.One()) + res = res.And(res, mask) + } + return res + } +} + +// sqrtPriceMathGetNextSqrtPriceFromInput calculates the next square root price +// based on the amount of token0 or token1 added to the pool. +// NOTE: Always rounds up, because in the exact output case (increasing price) we need to move the price at least +// far enough to get the desired output amount, and in the exact input case (decreasing price) we need to move the +// price less in order to not send too much output. +// The most precise formula for this is liquidity * sqrtPX96 / (liquidity +- amount * sqrtPX96), +// if this is impossible because of overflow, we calculate liquidity / (liquidity / sqrtPX96 +- amount). +// +// Parameters: +// - sqrtPX96: The current square root price as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - amountIn: The amount of token0 or token1 to be added to the pool (uint256). +// - zeroForOne: A boolean indicating whether the amount is being added to token0 (true) or token1 (false). +// +// Returns: +// - The price after adding amountIn, depending on zeroForOne +func sqrtPriceMathGetNextSqrtPriceFromInput( + sqrtPX96 *u256.Uint, + liquidity *u256.Uint, + amountIn *u256.Uint, + zeroForOne bool, +) *u256.Uint { + if sqrtPX96.IsZero() { + panic("sqrtPX96 should not be zero") + } + + if liquidity.IsZero() { + panic("liquidity should not be zero") + } + + if zeroForOne { + return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountIn, true) + } else { + return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true) + } +} + +// sqrtPriceMathGetNextSqrtPriceFromOutput calculates the next square root price +// based on the amount of token0 or token1 removed from the pool. +// +// NOTE: +// - For zeroForOne == true (Token0 -> Token1): The calculation uses rounding down. +// - For zeroForOne == false (Token1 -> Token0): The calculation uses rounding up. +// +// The most precise formula for this is: +// - liquidity * sqrtPX96 / (liquidity ± amount * sqrtPX96) +// If overflow occurs, it falls back to: +// - liquidity / (liquidity / sqrtPX96 ± amount) +// +// Parameters: +// - sqrtPX96: The current square root price as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - amountOut: The amount of token0 or token1 to be removed from the pool (uint256). +// - zeroForOne: A boolean indicating whether the amount is being removed from token0 (true) or token1 (false). +// +// Returns: +// - The price after removing amountOut, depending on zeroForOne. +// +// Notes: +// - Rounding direction depends on the swap direction (zeroForOne). +// - Relies on helper functions: +// - `sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown` for Token0 -> Token1. +// - `sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp` for Token1 -> Token0. +func sqrtPriceMathGetNextSqrtPriceFromOutput( + sqrtPX96 *u256.Uint, + liquidity *u256.Uint, + amountOut *u256.Uint, + zeroForOne bool, +) *u256.Uint { + if sqrtPX96.IsZero() { + panic("sqrtPX96 should not be zero") + } + + if liquidity.IsZero() { + panic("liquidity should not be zero") + } + + if zeroForOne { + return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountOut, false) + } else { + return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false) + } +} + +// sqrtPriceMathGetAmount0DeltaHelper calculates the absolute difference between the amounts of token0 in two +// liquidity ranges defined by the square root prices sqrtRatioAX96 and sqrtRatioBX96. The difference is +// calculated relative to the range [sqrtRatioAX96, sqrtRatioBX96]. +// +// If sqrtRatioAX96 > sqrtRatioBX96, their values are swapped to ensure sqrtRatioAX96 is the lower bound. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the range as a Q96 fixed-point number (uint160). +// - sqrtRatioBX96: The upper bound of the range as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - roundUp: A boolean indicating whether the result should be rounded up (true) or down (false). +// +// Returns: +// - The absolute difference between the amounts of token0 in the two ranges as a uint256. +// +// Notes: +// - If sqrtRatioAX96 is zero or negative, the function panics. +// - The result is calculated using high-precision fixed-point arithmetic. +// - Rounding is applied based on the roundUp parameter. +func sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioAX96 *u256.Uint, + sqrtRatioBX96 *u256.Uint, + liquidity *u256.Uint, + roundUp bool, +) *u256.Uint { + if sqrtRatioAX96.Gt(sqrtRatioBX96) { + sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 + } + + numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION) + numerator2 := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) + + if !(sqrtRatioAX96.Gt(u256.Zero())) { + panic("sqrtRatioAX96 must be greater than zero") + } + + if roundUp { + value := u256.MulDivRoundingUp(numerator1, numerator2, sqrtRatioBX96) + return u256.DivRoundingUp(value, sqrtRatioAX96) + } else { + value := u256.MulDiv(numerator1, numerator2, sqrtRatioBX96) + return new(u256.Uint).Div(value, sqrtRatioAX96) + } +} + +// sqrtPriceMathGetAmount1DeltaHelper calculates the absolute difference between the amounts of token1 in two +// liquidity ranges defined by the square root prices sqrtRatioAX96 and sqrtRatioBX96. The difference is +// calculated relative to the range [sqrtRatioAX96, sqrtRatioBX96]. +// +// If sqrtRatioAX96 > sqrtRatioBX96, their values are swapped to ensure sqrtRatioAX96 is the lower bound. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the range as a Q96 fixed-point number (uint160). +// - sqrtRatioBX96: The upper bound of the range as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - roundUp: A boolean indicating whether the result should be rounded up (true) or down (false). +// +// Returns: +// - The absolute difference between the amounts of token1 in the two ranges as a uint256. +// +// Notes: +// - Rounding is applied based on the roundUp parameter. +// - The function swaps sqrtRatioAX96 and sqrtRatioBX96 if sqrtRatioAX96 > sqrtRatioBX96. +func sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioAX96 *u256.Uint, + sqrtRatioBX96 *u256.Uint, + liquidity *u256.Uint, + roundUp bool, +) *u256.Uint { + if sqrtRatioAX96.Gt(sqrtRatioBX96) { + sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 + } + + diff := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) + if roundUp { + return u256.MulDivRoundingUp(liquidity, diff, u256.MustFromDecimal(Q96)) + } else { + return u256.MulDiv(liquidity, diff, u256.MustFromDecimal(Q96)) + } +} + +// SqrtPriceMathGetAmount0DeltaStr calculates the difference in the amount of token0 +// within a specified liquidity range defined by two square root prices (sqrtRatioAX96 and sqrtRatioBX96). +// This function returns the result as a string representation of an int256 value. +// +// If the liquidity is negative, the result is also negative. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the range as a Q96 fixed-point number (uint160). +// - sqrtRatioBX96: The upper bound of the range as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a signed Q128 fixed-point number (int128). +// +// Returns: +// - A string representation of the int256 value representing the difference in token0 amounts +// within the specified range. The value is negative if the liquidity is negative. +// +// Notes: +// - This function relies on the helper function `sqrtPriceMathGetAmount0DeltaHelper` to perform the core calculation. +// - The helper function calculates the absolute difference between token0 amounts within the range. +// - If the computed result exceeds the maximum allowable value for int256 (2**255 - 1), the function will panic +// with an appropriate overflow error. +// - The rounding behavior of the result is controlled by the `roundUp` parameter passed to the helper function: +// - For negative liquidity, rounding is always down. +// - For positive liquidity, rounding is always up. +func SqrtPriceMathGetAmount0DeltaStr( + sqrtRatioAX96 *u256.Uint, + sqrtRatioBX96 *u256.Uint, + liquidity *i256.Int, +) string { + if liquidity.IsNeg() { + u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic("SqrtPriceMathGetAmount0DeltaStr: overflow") + } + i := i256.FromUint256(u) + return i256.Zero().Neg(i).ToString() + } else { + u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic("SqrtPriceMathGetAmount0DeltaStr: overflow") + } + return i256.FromUint256(u).ToString() + } +} + +// SqrtPriceMathGetAmount1DeltaStr calculates the difference in the amount of token1 +// within a specified liquidity range defined by two square root prices (sqrtRatioAX96 and sqrtRatioBX96). +// This function returns the result as a string representation of an int256 value. +// +// If the liquidity is negative, the result is also negative. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the range as a Q96 fixed-point number (uint160). +// - sqrtRatioBX96: The upper bound of the range as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a signed Q128 fixed-point number (int128). +// +// Returns: +// - A string representation of the int256 value representing the difference in token1 amounts +// within the specified range. The value is negative if the liquidity is negative. +// +// Notes: +// - This function relies on the helper function `sqrtPriceMathGetAmount1DeltaHelper` to perform the core calculation. +// - The rounding behavior of the result is controlled by the `roundUp` parameter passed to the helper function: +// - For negative liquidity, rounding is always down. +// - For positive liquidity, rounding is always up. +func SqrtPriceMathGetAmount1DeltaStr( + sqrtRatioAX96 *u256.Uint, + sqrtRatioBX96 *u256.Uint, + liquidity *i256.Int, +) string { + if liquidity.IsNeg() { + u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic("SqrtPriceMathGetAmount1DeltaStr: overflow") + } + i := i256.FromUint256(u) + return i256.Zero().Neg(i).ToString() + } else { + u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic("SqrtPriceMathGetAmount1DeltaStr: overflow") + } + return i256.FromUint256(u).ToString() + } +} diff --git a/contract/p/gnoswap/gnsmath/sqrt_price_math_test.gno b/contract/p/gnoswap/gnsmath/sqrt_price_math_test.gno new file mode 100644 index 000000000..a40e76f95 --- /dev/null +++ b/contract/p/gnoswap/gnsmath/sqrt_price_math_test.gno @@ -0,0 +1,643 @@ +package gnsmath + +import ( + "testing" + + "gno.land/p/demo/uassert" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestSqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(t *testing.T) { + t.Run("zero amount returns same price", func(t *testing.T) { + sqrtPX96 := u256.MustFromDecimal("1000000") + liquidity := u256.MustFromDecimal("2000000") + amount := u256.Zero() + + result := sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( + sqrtPX96, + liquidity, + amount, + true, + ) + + if !result.Eq(sqrtPX96) { + t.Errorf("Expected %s, got %s", sqrtPX96.ToString(), result.ToString()) + } + }) + + t.Run("remove token0", func(t *testing.T) { + sqrtPX96 := u256.MustFromDecimal("1000000") + liquidity := u256.MustFromDecimal("2000000") + amount := u256.MustFromDecimal("500000") + + result := sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( + sqrtPX96, + liquidity, + amount, + false, + ) + + if result.Lte(sqrtPX96) { + t.Error("Price should increase when removing token0") + } + }) +} + +func TestSqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(t *testing.T) { + t.Run("add token1 small amount", func(t *testing.T) { + sqrtPX96 := u256.MustFromDecimal("1000000") + liquidity := u256.MustFromDecimal("2000000") + amount := u256.MustFromDecimal("100000") + + result := sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( + sqrtPX96, + liquidity, + amount, + true, + ) + + if result.Lte(sqrtPX96) { + t.Error("Price should increase when adding token1") + } + }) +} + +func TestSqrtPriceMathGetAmount0DeltaStr(t *testing.T) { + t.Run("positive liquidity", func(t *testing.T) { + ratioA := u256.MustFromDecimal("1000000") + ratioB := u256.MustFromDecimal("2000000") + liquidity := i256.FromUint256(u256.MustFromDecimal("5000000")) + + result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity) + + if result[0] == '-' { + t.Error("Result should be positive for positive liquidity") + } + }) + + t.Run("negative liquidity", func(t *testing.T) { + ratioA := u256.MustFromDecimal("1000000") + ratioB := u256.MustFromDecimal("2000000") + liquidity := i256.Zero().Neg(i256.FromUint256(u256.MustFromDecimal("5000000"))) + + result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity) + + if result[0] != '-' { + t.Error("Result should be negative for negative liquidity") + } + }) + + t.Run("panic overflow when getting amount0 with positive liquidity", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for overflow amount0") + } else { + uassert.Equal(t, "SqrtPriceMathGetAmount0DeltaStr: overflow", r) + } + }() + + // Inputs to trigger panic + sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value + sqrtRatioBX96 := u256.MustFromDecimal("340282366920938463463374607431768211455") // very high value(2^128-1) + liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) + + SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) + + t.Run("panic overflow when getting amount0 with negative liquidity", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for overflow amount0") + } else { + uassert.Equal(t, "SqrtPriceMathGetAmount0DeltaStr: overflow", r) + } + }() + + // Inputs to trigger panic + sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value + sqrtRatioBX96 := u256.MustFromDecimal("340282366920938463463374607431768211455") // very high value(2^128-1) + liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) + liquidity = liquidity.Neg(liquidity) // Make liquidity negative + + SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) +} + +func TestSqrtPriceMathGetAmount1DeltaStr(t *testing.T) { + t.Run("positive liquidity", func(t *testing.T) { + ratioA := u256.MustFromDecimal("1000000") + ratioB := u256.MustFromDecimal("2000000") + liquidity := i256.FromUint256(u256.MustFromDecimal("5000000")) + + result := SqrtPriceMathGetAmount1DeltaStr(ratioA, ratioB, liquidity) + + if result[0] == '-' { + t.Error("Result should be positive for positive liquidity") + } + }) + + t.Run("negative liquidity", func(t *testing.T) { + ratioA := u256.MustFromDecimal("1000000") + ratioB := u256.MustFromDecimal("2000000") + liquidity := i256.Zero().Neg(i256.FromUint256(u256.MustFromDecimal("5000000"))) + + result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity) + + if result[0] != '-' { + t.Error("Result should be negative for negative liquidity") + } + }) + + t.Run("panic overflow when getting amount1 with positive liquidity", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for overflow amount1") + } else { + uassert.Equal(t, "SqrtPriceMathGetAmount1DeltaStr: overflow", r) + } + }() + + // Inputs to trigger panic + sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value + sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950335") // slightly below Q96 + liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) + + SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) + + t.Run("panic overflow when getting amount1 with negative liquidity", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for overflow amount1") + } else { + uassert.Equal(t, "SqrtPriceMathGetAmount1DeltaStr: overflow", r) + } + }() + + // Inputs to trigger panic + sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value + sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950335") // slightly below Q96 + liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) + liquidity = liquidity.Neg(liquidity) // Make liquidity negative + + SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) +} + +func TestSqrtPriceMathGetNextSqrtPriceFromInput(t *testing.T) { + tests := []struct { + name string + sqrtPriceX96 *u256.Uint + liquidity *u256.Uint + amountIn *u256.Uint + zeroForOne bool + shouldPanic bool + panicMsg string + expectedSqrtPriceX96 string + }{ + { + name: "fails if price is zero", + sqrtPriceX96: u256.Zero(), + liquidity: u256.Zero(), + amountIn: u256.MustFromDecimal("100000000000000000"), + zeroForOne: false, + shouldPanic: true, + panicMsg: "sqrtPX96 should not be zero", + }, + { + name: "fails if liquidity is zero", + sqrtPriceX96: u256.One(), + liquidity: u256.Zero(), + amountIn: u256.MustFromDecimal("100000000000000000"), + zeroForOne: true, + shouldPanic: true, + panicMsg: "liquidity should not be zero", + }, + { + name: "fails if input amount overflows the price", + sqrtPriceX96: u256.MustFromDecimal("1461501637330902918203684832716283019655932542975"), // 2^160 - 1 + liquidity: u256.MustFromDecimal("1024"), + amountIn: u256.MustFromDecimal("1024"), + zeroForOne: false, + shouldPanic: true, + panicMsg: "sqrtPx96 + quotient overflow uint160", + }, + { + name: "any input amount cannot underflow the price", + sqrtPriceX96: u256.MustFromDecimal("1"), + liquidity: u256.MustFromDecimal("1"), + amountIn: u256.MustFromDecimal("57896044618658097711785492504343953926634992332820282019728792003956564819968"), // 2^255 + zeroForOne: true, + expectedSqrtPriceX96: "1", + }, + { + name: "returns input price if amount in is zero and zeroForOne = true", + sqrtPriceX96: u256.MustFromDecimal("79228162514264337593543950336"), + liquidity: u256.MustFromDecimal("100000000000000000"), + amountIn: u256.Zero(), + zeroForOne: true, + expectedSqrtPriceX96: "79228162514264337593543950336", + }, + { + name: "returns input price if amount in is zero and zeroForOne = false", + sqrtPriceX96: u256.MustFromDecimal("79228162514264337593543950336"), + liquidity: u256.MustFromDecimal("100000000000000000"), + amountIn: u256.Zero(), + zeroForOne: false, + expectedSqrtPriceX96: "79228162514264337593543950336", + }, + { + name: "returns the minimum price for max inputs", + sqrtPriceX96: u256.MustFromDecimal("1461501637330902918203684832716283019655932542975"), // 2^160 - 1 + liquidity: u256.MustFromDecimal("340282366920938463463374607431768211455"), // 2^128 - 1 + amountIn: u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039439137263839420088320"), + zeroForOne: true, + expectedSqrtPriceX96: "1", + }, + { + name: "input amount of 0.1 token1", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + amountIn: u256.MustFromDecimal("100000000000000000"), + zeroForOne: false, + expectedSqrtPriceX96: "87150978765690771352898345369", + }, + { + name: "input amount of 0.1 token0", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + amountIn: u256.MustFromDecimal("100000000000000000"), + zeroForOne: true, + expectedSqrtPriceX96: "72025602285694852357767227579", + }, + { + name: "amountIn > type(uint96).max and zeroForOne = true", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("10000000000000000000"), + amountIn: u256.MustFromDecimal("1267650600228229401496703205376"), // 2^128 - 1 + zeroForOne: true, + expectedSqrtPriceX96: "624999999995069620", + }, + { + name: "can return 1 with enough amountIn and zeroForOne = true", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.One(), + amountIn: u256.MustFromDecimal("57896044618658097711785492504343953926634992332820282019728792003956564819967"), + zeroForOne: true, + expectedSqrtPriceX96: "1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + sqrtPriceMathGetNextSqrtPriceFromInput(tt.sqrtPriceX96, tt.liquidity, tt.amountIn, tt.zeroForOne) + }) + } else { + actual := sqrtPriceMathGetNextSqrtPriceFromInput(tt.sqrtPriceX96, tt.liquidity, tt.amountIn, tt.zeroForOne) + uassert.Equal(t, tt.expectedSqrtPriceX96, actual.ToString()) + } + }) + } +} + +func TestSqrtPriceMathGetNextSqrtPriceFromInput2(t *testing.T) { + t.Run("zero sqrtPX96 should panic", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for zero sqrtPX96") + } + }() + + sqrtPriceMathGetNextSqrtPriceFromInput( + u256.Zero(), + u256.MustFromDecimal("1000000"), + u256.MustFromDecimal("500000"), + true, + ) + }) + + t.Run("zero liquidity should panic", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for zero liquidity") + } + }() + + sqrtPriceMathGetNextSqrtPriceFromInput( + u256.MustFromDecimal("1000000"), + u256.Zero(), + u256.MustFromDecimal("500000"), + true, + ) + }) +} + +func TestSqrtPriceMathGetNextSqrtPriceFromOutput(t *testing.T) { + tests := []struct { + name string + sqrtPriceX96 *u256.Uint + liquidity *u256.Uint + amountOut *u256.Uint + zeroForOne bool + shouldPanic bool + expectedSqrtPriceX96 string + }{ + { + name: "fails if price is zero", + sqrtPriceX96: u256.Zero(), + liquidity: u256.Zero(), + amountOut: u256.MustFromDecimal("100000000000000000"), + zeroForOne: false, + shouldPanic: true, + }, + { + name: "fails if liquidity is zero", + sqrtPriceX96: u256.One(), + liquidity: u256.Zero(), + amountOut: u256.MustFromDecimal("100000000000000000"), + zeroForOne: true, + shouldPanic: true, + }, + { + name: "fails if output amount is exactly the virtual reserves of token0", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(4), + zeroForOne: false, + shouldPanic: true, + }, + { + name: "fails if output amount is greater than virtual reserves of token0", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(5), + zeroForOne: false, + shouldPanic: true, + }, + { + name: "fails if output amount is greater than virtual reserves of token1", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(262145), + zeroForOne: true, + shouldPanic: true, + }, + { + name: "fails if output amount is exactly the virtual reserves of token1", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(262144), + zeroForOne: true, + shouldPanic: true, + }, + { + name: "succeeds if output amount is just less than the virtual reserves of token1", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(262143), + zeroForOne: true, + expectedSqrtPriceX96: "77371252455336267181195264", + }, + { + name: "puzzling echidna test", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(4), + zeroForOne: false, + shouldPanic: true, + }, + { + name: "returns input price if amount in is zero and zeroForOne = true", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("100000000000000000"), + amountOut: u256.Zero(), + zeroForOne: true, + expectedSqrtPriceX96: encodePriceSqrt("1", "1").ToString(), + }, + { + name: "returns input price if amount in is zero and zeroForOne = false", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("100000000000000000"), + amountOut: u256.Zero(), + zeroForOne: false, + expectedSqrtPriceX96: encodePriceSqrt("1", "1").ToString(), + }, + { + name: "output amount of 0.1 token1, zeroForOne = false", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + amountOut: u256.MustFromDecimal("100000000000000000"), + zeroForOne: false, + expectedSqrtPriceX96: "88031291682515930659493278152", + }, + { + name: "output amount of 0.1 token1, zeroForOne = true", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + amountOut: u256.MustFromDecimal("100000000000000000"), + zeroForOne: true, + expectedSqrtPriceX96: "71305346262837903834189555302", + }, + { + name: "reverts if amountOut is impossible in zero for one direction", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.NewUint(1), + amountOut: u256.MustFromDecimal(MAX_UINT256), + zeroForOne: true, + shouldPanic: true, + }, + { + name: "reverts if amountOut is impossible in one for zero direction", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.NewUint(1), + amountOut: u256.MustFromDecimal(MAX_UINT256), + zeroForOne: false, + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for %s", tt.name) + } + }() + sqrtPriceMathGetNextSqrtPriceFromOutput(tt.sqrtPriceX96, tt.liquidity, tt.amountOut, tt.zeroForOne) + + } else { + actual := sqrtPriceMathGetNextSqrtPriceFromOutput(tt.sqrtPriceX96, tt.liquidity, tt.amountOut, tt.zeroForOne) + uassert.Equal(t, tt.expectedSqrtPriceX96, actual.ToString()) + } + }) + } +} + +func TestSqrtPriceMathGetAmount0DeltaHelper(t *testing.T) { + tests := []struct { + name string + sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint + roundUp bool + expectedAmount0Delta string + }{ + { + name: "returns 0 if liquidity is 0", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("2", "1"), + liquidity: u256.Zero(), + roundUp: true, + expectedAmount0Delta: "0", + }, + { + name: "returns 0 if prices are equal", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("1", "1"), + liquidity: u256.Zero(), + roundUp: true, + expectedAmount0Delta: "0", + }, + { + name: "returns 0.1 amount1 for price of 1 to 1.21, roundUp = true", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("121", "100"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: true, + expectedAmount0Delta: "90909090909090910", + }, + { + name: "returns 0.1 amount1 for price of 1 to 1.21, roundUp = false", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("121", "100"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: false, + expectedAmount0Delta: "90909090909090909", + }, + { + name: "works for prices that overflow, roundUp = true", + sqrtRatioAX96: u256.MustFromDecimal("43556142965880123323311949751266331066368"), + sqrtRatioBX96: u256.MustFromDecimal("22300745198530623141535718272648361505980416"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: true, + expectedAmount0Delta: "1815437", + }, + { + name: "works for prices that overflow, roundUp = false", + sqrtRatioAX96: u256.MustFromDecimal("43556142965880123323311949751266331066368"), + sqrtRatioBX96: u256.MustFromDecimal("22300745198530623141535718272648361505980416"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: false, + expectedAmount0Delta: "1815436", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := sqrtPriceMathGetAmount0DeltaHelper(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity, tt.roundUp) + uassert.Equal(t, tt.expectedAmount0Delta, actual.ToString()) + }) + } +} + +func TestSqrtPriceMathGetAmount1DeltaHelper(t *testing.T) { + tests := []struct { + name string + sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint + roundUp bool + expectedAmount1Delta string + }{ + { + name: "returns 0 if liquidity is 0", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("2", "1"), + liquidity: u256.Zero(), + roundUp: true, + expectedAmount1Delta: "0", + }, + { + name: "returns 0 if prices are equal", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("1", "1"), + liquidity: u256.Zero(), + roundUp: true, + expectedAmount1Delta: "0", + }, + { + name: "returns 0.1 amount1 for price of 1 to 1.21, roundUp = true", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("121", "100"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: true, + expectedAmount1Delta: "100000000000000000", + }, + { + name: "returns 0.1 amount1 for price of 1 to 1.21, roundUp = false", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("121", "100"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: false, + expectedAmount1Delta: "99999999999999999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := sqrtPriceMathGetAmount1DeltaHelper(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity, tt.roundUp) + uassert.Equal(t, tt.expectedAmount1Delta, actual.ToString()) + }) + } +} + +func TestSwapComputation_SqrtP_SqrtQ_Mul_Overflow(t *testing.T) { + sqrtP := u256.MustFromDecimal("1025574284609383690408304870162715216695788925244") + liquidity := u256.MustFromDecimal("50015962439936049619261659728067971248") + amountIn := u256.MustFromDecimal("406") + zeroForOne := true + + sqrtQ := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtP, liquidity, amountIn, zeroForOne) + uassert.Equal(t, "1025574284609383582644711336373707553698163132913", sqrtQ.ToString()) + + amount0Delta := sqrtPriceMathGetAmount0DeltaHelper(sqrtQ, sqrtP, liquidity, true) + uassert.Equal(t, "406", amount0Delta.ToString()) +} + +// encodePriceSqrt calculates the sqrt((reserve1 << 192) / reserve0) +func encodePriceSqrt(reserve1, reserve0 string) *u256.Uint { + reserve1Uint := u256.MustFromDecimal(reserve1) + reserve0Uint := u256.MustFromDecimal(reserve0) + + if reserve0Uint.IsZero() { + panic("division by zero") + } + + // numerator = reserve1 * (2^192) + two192 := new(u256.Uint).Lsh(u256.NewUint(1), 192) + numerator := new(u256.Uint).Mul(reserve1Uint, two192) + + // ratioX192 = numerator / reserve0 + ratioX192 := new(u256.Uint).Div(numerator, reserve0Uint) + + // Return sqrt(ratioX192) + return sqrt(ratioX192) +} + +// sqrt computes the integer square root of a u256.Uint +func sqrt(x *u256.Uint) *u256.Uint { + if x.IsZero() { + return u256.NewUint(0) + } + + z := new(u256.Uint).Set(x) + y := new(u256.Uint).Rsh(z, 1) // Initial guess is x / 2 + + for y.Cmp(z) < 0 { + z.Set(y) + temp := new(u256.Uint).Div(x, z) + y.Add(z, temp).Rsh(y, 1) + } + return z +} diff --git a/contract/p/gnoswap/gnsmath/swap_math.gno b/contract/p/gnoswap/gnsmath/swap_math.gno new file mode 100644 index 000000000..4334bd558 --- /dev/null +++ b/contract/p/gnoswap/gnsmath/swap_math.gno @@ -0,0 +1,143 @@ +package gnsmath + +import ( + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +// SwapMathComputeSwapStepStr computes the next sqrt price, amount in, amount out, and fee amount +// Computes the result of swapping some amount in, or amount out, given the parameters of the swap +// The fee, plus the amount in, will never exceed the amount remaining if the swap's `amountSpecified` is positive +// +// input: +// - sqrtRatioCurrentX96: the current sqrt price of the pool +// - sqrtRatioTargetX96: The price that cannot be exceeded, from which the direction of the swap is inferred +// - liquidity: The usable liquidity of the pool +// - amountRemaining: How much input or output amount is remaining to be swapped in/out +// - feePips: The fee taken from the input amount, expressed in hundredths of a bip +// +// output: +// - sqrtRatioNextX96: The price after swapping the amount in/out, not to exceed the price target +// - amountIn: The amount to be swapped in, of either token0 or token1, based on the direction of the swap +// - amountOut: The amount to be received, of either token0 or token1, based on the direction of the swap +// - feeAmount: The amount of input that will be taken as a fee +func SwapMathComputeSwapStepStr( + sqrtRatioCurrentX96 *u256.Uint, + sqrtRatioTargetX96 *u256.Uint, + liquidity *u256.Uint, + amountRemaining *i256.Int, + feePips uint64, +) (string, string, string, string) { + if sqrtRatioCurrentX96 == nil || sqrtRatioTargetX96 == nil || liquidity == nil || amountRemaining == nil { + panic("SwapMathComputeSwapStepStr: invalid input") + } + + zeroForOne := sqrtRatioCurrentX96.Gte(sqrtRatioTargetX96) + + // POSTIVIE == EXACT_IN => Estimated AmountOut + // NEGATIVE == EXACT_OUT => Estimated AmountIn + exactIn := !(amountRemaining.IsNeg()) // amountRemaining >= 0 + + sqrtRatioNextX96 := u256.Zero() + amountIn := u256.Zero() + amountOut := u256.Zero() + feeAmount := u256.Zero() + + if exactIn { + amountRemainingLessFee := u256.MulDiv(amountRemaining.Abs(), u256.NewUint(1000000-feePips), u256.NewUint(1000000)) + if zeroForOne { + amountIn = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioTargetX96.Clone(), + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + true) + } else { + amountIn = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioCurrentX96.Clone(), + sqrtRatioTargetX96.Clone(), + liquidity.Clone(), + true) + } + + if amountRemainingLessFee.Gte(amountIn) { + sqrtRatioNextX96 = sqrtRatioTargetX96.Clone() + } else { + sqrtRatioNextX96 = sqrtPriceMathGetNextSqrtPriceFromInput( + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + amountRemainingLessFee.Clone(), + zeroForOne, + ) + } + } else { + if zeroForOne { + amountOut = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioTargetX96.Clone(), + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + false) + } else { + amountOut = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioCurrentX96.Clone(), + sqrtRatioTargetX96.Clone(), + liquidity.Clone(), + false) + } + + if amountRemaining.Abs().Gte(amountOut) { + sqrtRatioNextX96 = sqrtRatioTargetX96.Clone() + } else { + sqrtRatioNextX96 = sqrtPriceMathGetNextSqrtPriceFromOutput( + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + amountRemaining.Abs(), + zeroForOne, + ) + } + } + + isMax := sqrtRatioTargetX96.Eq(sqrtRatioNextX96) + + if zeroForOne { + if !(isMax && exactIn) { + amountIn = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioNextX96.Clone(), + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + true) + } + if !(isMax && !exactIn) { + amountOut = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioNextX96.Clone(), + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + false) + } + } else { + if !(isMax && exactIn) { + amountIn = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioCurrentX96.Clone(), + sqrtRatioNextX96.Clone(), + liquidity.Clone(), + true) + } + if !(isMax && !exactIn) { + amountOut = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioCurrentX96.Clone(), + sqrtRatioNextX96.Clone(), + liquidity.Clone(), + false) + } + } + + if !exactIn && amountOut.Gt(amountRemaining.Abs()) { + amountOut = amountRemaining.Abs() + } + if exactIn && !(sqrtRatioNextX96.Eq(sqrtRatioTargetX96)) { + feeAmount = new(u256.Uint).Sub(amountRemaining.Abs(), amountIn) + } else { + feeAmount = u256.MulDivRoundingUp(amountIn, u256.NewUint(feePips), new(u256.Uint).Sub(u256.NewUint(1000000), u256.NewUint(feePips))) + } + + return sqrtRatioNextX96.ToString(), amountIn.ToString(), amountOut.ToString(), feeAmount.ToString() +} diff --git a/contract/p/gnoswap/gnsmath/swap_math_test.gno b/contract/p/gnoswap/gnsmath/swap_math_test.gno new file mode 100644 index 000000000..edcafcf43 --- /dev/null +++ b/contract/p/gnoswap/gnsmath/swap_math_test.gno @@ -0,0 +1,273 @@ +package gnsmath + +import ( + "testing" + + "gno.land/p/demo/uassert" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestSwapMathComputeSwapStepStr(t *testing.T) { + tests := []struct { + name string + currentX96, targetX96 *u256.Uint + liquidity *u256.Uint + amountRemaining *i256.Int + feePips uint64 + sqrtNextX96 *u256.Uint + chkSqrtNextX96 func(sqrtRatioNextX96, priceTarget *u256.Uint) + amountIn, amountOut, feeAmount string + }{ + { + name: "exact amount in that gets capped at price target in one for zero", + currentX96: encodePriceSqrt(t, "1", "1"), + targetX96: encodePriceSqrt(t, "101", "100"), + liquidity: u256.MustFromDecimal("2000000000000000000"), + amountRemaining: i256.MustFromDecimal("1000000000000000000"), + feePips: 600, + sqrtNextX96: encodePriceSqrt(t, "101", "100"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "9975124224178055", + amountOut: "9925619580021728", + feeAmount: "5988667735148", + }, + { + name: "exact amount out that gets capped at price target in one for zero", + currentX96: encodePriceSqrt(t, "1", "1"), + targetX96: encodePriceSqrt(t, "101", "100"), + liquidity: u256.MustFromDecimal("2000000000000000000"), + amountRemaining: i256.MustFromDecimal("-1000000000000000000"), + feePips: 600, + sqrtNextX96: encodePriceSqrt(t, "101", "100"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "9975124224178055", + amountOut: "9925619580021728", + feeAmount: "5988667735148", + }, + { + name: "exact amount in that is fully spent in one for zero", + currentX96: encodePriceSqrt(t, "1", "1"), + targetX96: encodePriceSqrt(t, "1000", "100"), + liquidity: u256.MustFromDecimal("2000000000000000000"), + amountRemaining: i256.MustFromDecimal("1000000000000000000"), + sqrtNextX96: encodePriceSqrt(t, "1000", "100"), + feePips: 600, + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Lte(priceTarget)) + }, + amountIn: "999400000000000000", + amountOut: "666399946655997866", + feeAmount: "600000000000000", + }, + { + name: "exact amount out that is fully received in one for zero", + currentX96: encodePriceSqrt(t, "1", "1"), + targetX96: encodePriceSqrt(t, "1000", "100"), + liquidity: u256.MustFromDecimal("2000000000000000000"), + amountRemaining: i256.MustFromDecimal("-1000000000000000000"), + feePips: 600, + sqrtNextX96: encodePriceSqrt(t, "1000", "100"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Lt(priceTarget)) + }, + amountIn: "2000000000000000000", + amountOut: "1000000000000000000", + feeAmount: "1200720432259356", + }, + { + name: "amount out is capped at the desired amount out", + currentX96: u256.MustFromDecimal("417332158212080721273783715441582"), + targetX96: u256.MustFromDecimal("1452870262520218020823638996"), + liquidity: u256.MustFromDecimal("159344665391607089467575320103"), + amountRemaining: i256.MustFromDecimal("-1"), + feePips: 1, + sqrtNextX96: u256.MustFromDecimal("417332158212080721273783715441581"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "1", + amountOut: "1", + feeAmount: "1", + }, + { + name: "target price of 1 uses partial input amount", + currentX96: u256.MustFromDecimal("2"), + targetX96: u256.MustFromDecimal("1"), + liquidity: u256.MustFromDecimal("1"), + amountRemaining: i256.MustFromDecimal("3915081100057732413702495386755767"), + feePips: 1, + sqrtNextX96: u256.MustFromDecimal("1"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "39614081257132168796771975168", + amountOut: "0", + feeAmount: "39614120871253040049813", + }, + { + name: "entire input amount taken as fee", + currentX96: u256.MustFromDecimal("2413"), + targetX96: u256.MustFromDecimal("79887613182836312"), + liquidity: u256.MustFromDecimal("1985041575832132834610021537970"), + amountRemaining: i256.MustFromDecimal("10"), + feePips: 1872, + sqrtNextX96: u256.MustFromDecimal("2413"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "0", + amountOut: "0", + feeAmount: "10", + }, + { + name: "handles intermediate insufficient liquidity in zero for one exact output case", + currentX96: u256.MustFromDecimal("20282409603651670423947251286016"), + targetX96: u256.MustFromDecimal("22310650564016837466341976414617"), + liquidity: u256.MustFromDecimal("1024"), + amountRemaining: i256.MustFromDecimal("-4"), + feePips: 3000, + sqrtNextX96: u256.MustFromDecimal("22310650564016837466341976414617"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "26215", + amountOut: "0", + feeAmount: "79", + }, + { + name: "handles intermediate insufficient liquidity in one for zero exact output case", + currentX96: u256.MustFromDecimal("20282409603651670423947251286016"), + targetX96: u256.MustFromDecimal("18254168643286503381552526157414"), + liquidity: u256.MustFromDecimal("1024"), + amountRemaining: i256.MustFromDecimal("-263000"), + feePips: 3000, + sqrtNextX96: u256.MustFromDecimal("18254168643286503381552526157414"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "1", + amountOut: "26214", + feeAmount: "1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sqrtRatioNextX96, amountIn, amountOut, feeAmount := SwapMathComputeSwapStepStr(test.currentX96, test.targetX96, test.liquidity, test.amountRemaining, test.feePips) + test.chkSqrtNextX96(u256.MustFromDecimal(sqrtRatioNextX96), test.sqrtNextX96) + uassert.Equal(t, amountIn, test.amountIn) + uassert.Equal(t, amountOut, test.amountOut) + uassert.Equal(t, feeAmount, test.feeAmount) + }) + } +} + +func TestSwapMathComputeSwapStepStrFail(t *testing.T) { + tests := []struct { + name string + currentX96, targetX96 *u256.Uint + liquidity *u256.Uint + amountRemaining *i256.Int + feePips uint64 + sqrtNextX96 *u256.Uint + chkSqrtNextX96 func(sqrtRatioNextX96, priceTarget *u256.Uint) + amountIn, amountOut, feeAmount string + shouldPanic bool + expectedMessage string + }{ + { + name: "input parameter is nil", + currentX96: nil, + targetX96: nil, + liquidity: nil, + amountRemaining: nil, + feePips: 600, + sqrtNextX96: encodePriceSqrt(t, "101", "100"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "9975124224178055", + amountOut: "9925619580021728", + feeAmount: "5988667735148", + shouldPanic: true, + expectedMessage: "SwapMathComputeSwapStepStr: invalid input", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if test.shouldPanic { + if errMsg, ok := r.(string); ok { + uassert.Equal(t, test.expectedMessage, errMsg) + } else { + t.Errorf("expected a panic with message, got: %v", r) + } + } else { + t.Errorf("unexpected panic: %v", r) + } + } else { + if test.shouldPanic { + t.Errorf("expected a panic, but none occurred") + } + } + }() + + SwapMathComputeSwapStepStr( + test.currentX96, + test.targetX96, + test.liquidity, + test.amountRemaining, + test.feePips) + }) + } +} + +// encodePriceSqrt calculates the sqrt((reserve1 << 192) / reserve0) +func encodePriceSqrt(t *testing.T, reserve1, reserve0 string) *u256.Uint { + t.Helper() + + reserve1Uint := u256.MustFromDecimal(reserve1) + reserve0Uint := u256.MustFromDecimal(reserve0) + + if reserve0Uint.IsZero() { + panic("division by zero") + } + + // numerator = reserve1 * (2^192) + two192 := new(u256.Uint).Lsh(u256.NewUint(1), 192) + numerator := new(u256.Uint).Mul(reserve1Uint, two192) + + // ratioX192 = numerator / reserve0 + ratioX192 := new(u256.Uint).Div(numerator, reserve0Uint) + + // Return sqrt(ratioX192) + return sqrt(t, ratioX192) +} + +// sqrt computes the integer square root of a u256.Uint +func sqrt(t *testing.T, x *u256.Uint) *u256.Uint { + t.Helper() + + if x.IsZero() { + return u256.NewUint(0) + } + + z := new(u256.Uint).Set(x) + y := new(u256.Uint).Rsh(z, 1) // Initial guess is x / 2 + + temp := new(u256.Uint) + for y.Cmp(z) < 0 { + z.Set(y) + temp.Div(x, z) + y.Add(z, temp).Rsh(y, 1) + } + return z +} diff --git a/contract/p/gnoswap/int256/absolute.gno b/contract/p/gnoswap/int256/absolute.gno new file mode 100644 index 000000000..fbc1f9913 --- /dev/null +++ b/contract/p/gnoswap/int256/absolute.gno @@ -0,0 +1,20 @@ +package int256 + +import ( + "gno.land/p/gnoswap/uint256" +) + +// Abs returns |z| +func (z *Int) Abs() *uint256.Uint { + return z.abs.Clone() +} + +// AbsGt returns true if |z| > x, where x is a uint256 +func (z *Int) AbsGt(x *uint256.Uint) bool { + return z.abs.Gt(x) +} + +// AbsLt returns true if |z| < x, where x is a uint256 +func (z *Int) AbsLt(x *uint256.Uint) bool { + return z.abs.Lt(x) +} diff --git a/contract/p/gnoswap/int256/absolute_test.gno b/contract/p/gnoswap/int256/absolute_test.gno new file mode 100644 index 000000000..05d801310 --- /dev/null +++ b/contract/p/gnoswap/int256/absolute_test.gno @@ -0,0 +1,105 @@ +package int256 + +import ( + "testing" + + "gno.land/p/gnoswap/uint256" +) + +func TestAbs(t *testing.T) { + tests := []struct { + x, want string + }{ + {"0", "0"}, + {"1", "1"}, + {"-1", "1"}, + {"-2", "2"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "115792089237316195423570985008687907853269984665640564039457584007913129639935"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + got := x.Abs() + + if got.ToString() != tc.want { + t.Errorf("Abs(%s) = %v, want %v", tc.x, got.ToString(), tc.want) + } + } +} + +func TestAbsGt(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "0", "false"}, + {"1", "0", "true"}, + {"-1", "0", "true"}, + {"-1", "1", "false"}, + {"-2", "1", "true"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "0", "true"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "1", "true"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "115792089237316195423570985008687907853269984665640564039457584007913129639935", "false"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := uint256.FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.AbsGt(y) + + if got != (tc.want == "true") { + t.Errorf("AbsGt(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestAbsLt(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "0", "false"}, + {"1", "0", "false"}, + {"-1", "0", "false"}, + {"-1", "1", "false"}, + {"-2", "1", "false"}, + {"-5", "10", "true"}, + {"31330", "31337", "true"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "0", "false"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "1", "false"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "115792089237316195423570985008687907853269984665640564039457584007913129639935", "false"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := uint256.FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.AbsLt(y) + + if got != (tc.want == "true") { + t.Errorf("AbsLt(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} diff --git a/contract/p/gnoswap/int256/arithmetic.gno b/contract/p/gnoswap/int256/arithmetic.gno new file mode 100644 index 000000000..5ffb66166 --- /dev/null +++ b/contract/p/gnoswap/int256/arithmetic.gno @@ -0,0 +1,203 @@ +package int256 + +import "gno.land/p/gnoswap/uint256" + +func (z *Int) Add(x, y *Int) *Int { + z.initiateAbs() + + if x.neg == y.neg { + // If both numbers have the same sign, add their absolute values + z.abs.Add(x.abs, y.abs) + z.neg = x.neg + } else { + // If signs are different, subtract the smaller absolute value from the larger + if x.abs.Cmp(y.abs) >= 0 { + z.abs.Sub(x.abs, y.abs) + z.neg = x.neg + } else { + z.abs.Sub(y.abs, x.abs) + z.neg = y.neg + } + } + + // Ensure zero is always positive + if z.abs.IsZero() { + z.neg = false + } + + return z +} + +// AddUint256 set z to the sum x + y, where y is a uint256, and returns z +func (z *Int) AddUint256(x *Int, y *uint256.Uint) *Int { + if x.neg { + if x.abs.Gt(y) { + z.abs.Sub(x.abs, y) + z.neg = true + } else { + z.abs.Sub(y, x.abs) + z.neg = false + } + } else { + z.abs.Add(x.abs, y) + z.neg = false + } + return z +} + +// Sets z to the sum x + y, where z and x are uint256s and y is an int256. +func AddDelta(z, x *uint256.Uint, y *Int) { + if y.neg { + z.Sub(x, y.abs) + } else { + z.Add(x, y.abs) + } +} + +// Sets z to the sum x + y, where z and x are uint256s and y is an int256. +func AddDeltaOverflow(z, x *uint256.Uint, y *Int) bool { + var overflow bool + if y.neg { + _, overflow = z.SubOverflow(x, y.abs) + } else { + _, overflow = z.AddOverflow(x, y.abs) + } + return overflow +} + +// Sub sets z to the difference x-y and returns z. +func (z *Int) Sub(x, y *Int) *Int { + z.initiateAbs() + + if x.neg != y.neg { + // If sign are different, add the absolute values + z.abs.Add(x.abs, y.abs) + z.neg = x.neg + } else { + // If signs are the same, subtract the smaller absolute value from the larger + if x.abs.Cmp(y.abs) >= 0 { + z.abs = z.abs.Sub(x.abs, y.abs) + z.neg = x.neg + } else { + z.abs.Sub(y.abs, x.abs) + z.neg = !x.neg + } + } + + // Ensure zero is always positive + if z.abs.IsZero() { + z.neg = false + } + return z +} + +// SubUint256 set z to the difference x - y, where y is a uint256, and returns z +func (z *Int) SubUint256(x *Int, y *uint256.Uint) *Int { + if x.neg { + z.abs.Add(x.abs, y) + z.neg = true + } else { + if x.abs.Lt(y) { + z.abs.Sub(y, x.abs) + z.neg = true + } else { + z.abs.Sub(x.abs, y) + z.neg = false + } + } + return z +} + +// Mul sets z to the product x*y and returns z. +func (z *Int) Mul(x, y *Int) *Int { + z.initiateAbs() + + z.abs = z.abs.Mul(x.abs, y.abs) + z.neg = x.neg != y.neg && !z.abs.IsZero() // 0 has no sign + return z +} + +// MulUint256 sets z to the product x*y, where y is a uint256, and returns z +func (z *Int) MulUint256(x *Int, y *uint256.Uint) *Int { + z.abs.Mul(x.abs, y) + if z.abs.IsZero() { + z.neg = false + } else { + z.neg = x.neg + } + return z +} + +// Div sets z to the quotient x/y for y != 0 and returns z. +func (z *Int) Div(x, y *Int) *Int { + z.initiateAbs() + + if y.abs.IsZero() { + panic("division by zero") + } + + z.abs.Div(x.abs, y.abs) + z.neg = (x.neg != y.neg) && !z.abs.IsZero() // 0 has no sign + + return z +} + +// DivUint256 sets z to the quotient x/y, where y is a uint256, and returns z +// If y == 0, z is set to 0 +func (z *Int) DivUint256(x *Int, y *uint256.Uint) *Int { + z.abs.Div(x.abs, y) + if z.abs.IsZero() { + z.neg = false + } else { + z.neg = x.neg + } + return z +} + +// Quo sets z to the quotient x/y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// OBS: differs from mempooler int256, we need to panic manually if y == 0 +// Quo implements truncated division (like Go); see QuoRem for more details. +func (z *Int) Quo(x, y *Int) *Int { + if y.IsZero() { + panic("division by zero") + } + + z.initiateAbs() + + z.abs = z.abs.Div(x.abs, y.abs) + z.neg = !(z.abs.IsZero()) && x.neg != y.neg // 0 has no sign + return z +} + +// Rem sets z to the remainder x%y for y != 0 and returns z. +// If y == 0, a division-by-zero run-time panic occurs. +// OBS: differs from mempooler int256, we need to panic manually if y == 0 +// Rem implements truncated modulus (like Go); see QuoRem for more details. +func (z *Int) Rem(x, y *Int) *Int { + if y.IsZero() { + panic("division by zero") + } + + z.initiateAbs() + + z.abs.Mod(x.abs, y.abs) + z.neg = z.abs.Sign() > 0 && x.neg // 0 has no sign + return z +} + +// Mod sets z to the modulus x%y for y != 0 and returns z. +// If y == 0, z is set to 0 (OBS: differs from the big.Int) +func (z *Int) Mod(x, y *Int) *Int { + if x.neg { + z.abs.Div(x.abs, y.abs) + z.abs.Add(z.abs, one) + z.abs.Mul(z.abs, y.abs) + z.abs.Sub(z.abs, x.abs) + z.abs.Mod(z.abs, y.abs) + } else { + z.abs.Mod(x.abs, y.abs) + } + z.neg = false + return z +} diff --git a/contract/p/gnoswap/int256/arithmetic_test.gno b/contract/p/gnoswap/int256/arithmetic_test.gno new file mode 100644 index 000000000..019c33872 --- /dev/null +++ b/contract/p/gnoswap/int256/arithmetic_test.gno @@ -0,0 +1,571 @@ +package int256 + +import ( + "testing" + + "gno.land/p/gnoswap/uint256" +) + +func TestAdd(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "1", "1"}, + {"1", "0", "1"}, + {"1", "1", "2"}, + {"1", "2", "3"}, + // NEGATIVE + {"-1", "1", "0"}, + {"1", "-1", "0"}, + {"3", "-3", "0"}, + {"-1", "-1", "-2"}, + {"-1", "-2", "-3"}, + {"-1", "3", "2"}, + {"3", "-1", "2"}, + // OVERFLOW + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "1", "0"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.Add(x, y) + + if got.Neq(want) { + t.Errorf("Add(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestAddUint256(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "1", "1"}, + {"1", "0", "1"}, + {"1", "1", "2"}, + {"1", "2", "3"}, + {"-1", "1", "0"}, + {"-1", "3", "2"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639934", "115792089237316195423570985008687907853269984665640564039457584007913129639935", "1"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "115792089237316195423570985008687907853269984665640564039457584007913129639934", "-1"}, + // OVERFLOW + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "115792089237316195423570985008687907853269984665640564039457584007913129639935", "0"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := uint256.FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.AddUint256(x, y) + + if got.Neq(want) { + t.Errorf("AddUint256(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestAddDelta(t *testing.T) { + tests := []struct { + z, x, y, want string + }{ + {"0", "0", "0", "0"}, + {"0", "0", "1", "1"}, + {"0", "1", "0", "1"}, + {"0", "1", "1", "2"}, + {"1", "2", "3", "5"}, + {"5", "10", "-3", "7"}, + // underflow + {"1", "2", "-3", "115792089237316195423570985008687907853269984665640564039457584007913129639935"}, + } + + for _, tc := range tests { + z, err := uint256.FromDecimal(tc.z) + if err != nil { + t.Error(err) + continue + } + + x, err := uint256.FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := uint256.FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + AddDelta(z, x, y) + + if z.Neq(want) { + t.Errorf("AddDelta(%s, %s, %s) = %v, want %v", tc.z, tc.x, tc.y, z.ToString(), want.ToString()) + } + } +} + +func TestAddDeltaOverflow(t *testing.T) { + tests := []struct { + z, x, y string + want bool + }{ + {"0", "0", "0", false}, + // underflow + {"1", "2", "-3", true}, + } + + for _, tc := range tests { + z, err := uint256.FromDecimal(tc.z) + if err != nil { + t.Error(err) + continue + } + + x, err := uint256.FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + result := AddDeltaOverflow(z, x, y) + if result != tc.want { + t.Errorf("AddDeltaOverflow(%s, %s, %s) = %v, want %v", tc.z, tc.x, tc.y, result, tc.want) + } + } +} + +func TestSub(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"1", "0", "1"}, + {"1", "1", "0"}, + {"-1", "1", "-2"}, + {"1", "-1", "2"}, + {"-1", "-1", "0"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "-115792089237316195423570985008687907853269984665640564039457584007913129639935", "0"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "0", "-115792089237316195423570985008687907853269984665640564039457584007913129639935"}, + {x: "-115792089237316195423570985008687907853269984665640564039457584007913129639935", y: "1", want: "0"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.Sub(x, y) + + if got.Neq(want) { + t.Errorf("Sub(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestSubUint256(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "1", "-1"}, + {"1", "0", "1"}, + {"1", "1", "0"}, + {"1", "2", "-1"}, + {"-1", "1", "-2"}, + {"-1", "3", "-4"}, + // underflow + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "1", "-0"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "2", "-1"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "3", "-2"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := uint256.FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.SubUint256(x, y) + + if got.Neq(want) { + t.Errorf("SubUint256(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestMul(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"5", "3", "15"}, + {"-5", "3", "-15"}, + {"5", "-3", "-15"}, + {"0", "3", "0"}, + {"3", "0", "0"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.Mul(x, y) + + if got.Neq(want) { + t.Errorf("Mul(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestMulUint256(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "1", "0"}, + {"1", "0", "0"}, + {"1", "1", "1"}, + {"1", "2", "2"}, + {"-1", "1", "-1"}, + {"-1", "3", "-3"}, + {"3", "4", "12"}, + {"-3", "4", "-12"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639934", "2", "-115792089237316195423570985008687907853269984665640564039457584007913129639932"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639934", "2", "115792089237316195423570985008687907853269984665640564039457584007913129639932"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := uint256.FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.MulUint256(x, y) + + if got.Neq(want) { + t.Errorf("MulUint256(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestDiv(t *testing.T) { + tests := []struct { + x, y, expected string + }{ + {"1", "1", "1"}, + {"0", "1", "0"}, + {"-1", "1", "-1"}, + {"1", "-1", "-1"}, + {"-1", "-1", "1"}, + {"-6", "3", "-2"}, + {"10", "-2", "-5"}, + {"-10", "3", "-3"}, + {"7", "3", "2"}, + {"-7", "3", "-2"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "2", "57896044618658097711785492504343953926634992332820282019728792003956564819967"}, // Max uint256 / 2 + } + + for _, tt := range tests { + t.Run(tt.x+"/"+tt.y, func(t *testing.T) { + x := MustFromDecimal(tt.x) + y := MustFromDecimal(tt.y) + result := Zero().Div(x, y) + if result.ToString() != tt.expected { + t.Errorf("Div(%s, %s) = %s, want %s", tt.x, tt.y, result.ToString(), tt.expected) + } + if result.abs.IsZero() && result.neg { + t.Errorf("Div(%s, %s) resulted in negative zero", tt.x, tt.y) + } + }) + } + + t.Run("Division by zero", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Div(1, 0) did not panic") + } + }() + x := MustFromDecimal("1") + y := MustFromDecimal("0") + Zero().Div(x, y) + }) +} + +func TestDivUint256(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "1", "0"}, + {"1", "0", "0"}, + {"1", "1", "1"}, + {"1", "2", "0"}, + {"-1", "1", "-1"}, + {"-1", "3", "0"}, + {"4", "3", "1"}, + {"25", "5", "5"}, + {"25", "4", "6"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639934", "2", "-57896044618658097711785492504343953926634992332820282019728792003956564819967"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639934", "2", "57896044618658097711785492504343953926634992332820282019728792003956564819967"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := uint256.FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.DivUint256(x, y) + + if got.Neq(want) { + t.Errorf("DivUint256(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestQuo(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "1", "0"}, + {"0", "-1", "0"}, + {"10", "1", "10"}, + {"10", "-1", "-10"}, + {"-10", "1", "-10"}, + {"-10", "-1", "10"}, + {"10", "-3", "-3"}, + {"10", "3", "3"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.Quo(x, y) + + if got.Neq(want) { + t.Errorf("Quo(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestRem(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "1", "0"}, + {"0", "-1", "0"}, + {"10", "1", "0"}, + {"10", "-1", "0"}, + {"-10", "1", "0"}, + {"-10", "-1", "0"}, + {"10", "3", "1"}, + {"10", "-3", "1"}, + {"-10", "3", "-1"}, + {"-10", "-3", "-1"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.Rem(x, y) + + if got.Neq(want) { + t.Errorf("Rem(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestMod(t *testing.T) { + tests := []struct { + x, y, want string + }{ + {"0", "1", "0"}, + {"0", "-1", "0"}, + {"10", "0", "0"}, + {"10", "1", "0"}, + {"10", "-1", "0"}, + {"-10", "0", "0"}, + {"-10", "1", "0"}, + {"-10", "-1", "0"}, + {"10", "3", "1"}, + {"10", "-3", "1"}, + {"-10", "3", "2"}, + {"-10", "-3", "2"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.Mod(x, y) + + if got.Neq(want) { + t.Errorf("Mod(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} diff --git a/contract/p/gnoswap/int256/bitwise.gno b/contract/p/gnoswap/int256/bitwise.gno new file mode 100644 index 000000000..b7c4120dd --- /dev/null +++ b/contract/p/gnoswap/int256/bitwise.gno @@ -0,0 +1,94 @@ +package int256 + +import ( + "gno.land/p/gnoswap/uint256" +) + +// Or sets z = x | y and returns z. +func (z *Int) Or(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1) + x1 := new(uint256.Uint).Sub(x.abs, one) + y1 := new(uint256.Uint).Sub(y.abs, one) + z.abs = z.abs.Add(z.abs.And(x1, y1), one) + z.neg = true // z cannot be zero if x and y are negative + return z + } + + // x | y == x | y + z.abs = z.abs.Or(x.abs, y.abs) + z.neg = false + return z + } + + // x.neg != y.neg + if x.neg { + x, y = y, x // | is symmetric + } + + // x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1) + y1 := new(uint256.Uint).Sub(y.abs, one) + z.abs = z.abs.Add(z.abs.AndNot(y1, x.abs), one) + z.neg = true // z cannot be zero if one of x or y is negative + + return z +} + +// And sets z = x & y and returns z. +func (z *Int) And(x, y *Int) *Int { + if x.neg == y.neg { + if x.neg { + // (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1) + x1 := new(uint256.Uint).Sub(x.abs, one) + y1 := new(uint256.Uint).Sub(y.abs, one) + z.abs = z.abs.Add(z.abs.Or(x1, y1), one) + z.neg = true // z cannot be zero if x and y are negative + return z + } + + // x & y == x & y + z.abs = z.abs.And(x.abs, y.abs) + z.neg = false + return z + } + + // x.neg != y.neg + // REF: https://cs.opensource.google/go/go/+/refs/tags/go1.22.1:src/math/big/int.go;l=1192-1202;drc=d57303e65f00b84b528ee682747dbe1fd3316d30 + if x.neg { + x, y = y, x // & is symmetric + } + + // x & (-y) == x & ^(y-1) == x &^ (y-1) + y1 := new(uint256.Uint).Sub(y.abs, uint256.One()) + z.abs = z.abs.AndNot(x.abs, y1) + z.neg = false + return z +} + +// Rsh sets z = x >> n and returns z. +// OBS: Different from original implementation it was using math.Big +func (z *Int) Rsh(x *Int, n uint) *Int { + if !x.neg { + z.abs.Rsh(x.abs, n) + z.neg = x.neg + return z + } + + // REF: https://cs.opensource.google/go/go/+/refs/tags/go1.22.1:src/math/big/int.go;l=1118-1126;drc=d57303e65f00b84b528ee682747dbe1fd3316d30 + t := NewInt(0).Sub(FromUint256(x.abs), NewInt(1)) + t = t.Rsh(t, n) + + _tmp := t.Add(t, NewInt(1)) + z.abs = _tmp.Abs() + z.neg = true + + return z +} + +// Lsh sets z = x << n and returns z. +func (z *Int) Lsh(x *Int, n uint) *Int { + z.abs.Lsh(x.abs, n) + z.neg = x.neg + return z +} diff --git a/contract/p/gnoswap/int256/bitwise_test.gno b/contract/p/gnoswap/int256/bitwise_test.gno new file mode 100644 index 000000000..904ee36cc --- /dev/null +++ b/contract/p/gnoswap/int256/bitwise_test.gno @@ -0,0 +1,201 @@ +package int256 + +import ( + "gno.land/p/gnoswap/uint256" +) + +import ( + "testing" +) + +func TestOr(t *testing.T) { + tests := []struct { + name string + x, y, want Int + }{ + { + name: "all zeroes", + x: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + }, + { + name: "all ones", + x: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + }, + { + name: "mixed", + x: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + }, + { + name: "one operand all ones", + x: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := New() + got.Or(&tc.x, &tc.y) + + if got.Neq(&tc.want) { + t.Errorf("Or(%v, %v) = %v, want %v", tc.x, tc.y, got, tc.want) + } + }) + } +} + +func TestAnd(t *testing.T) { + tests := []struct { + name string + x, y, want Int + }{ + { + name: "all zeroes", + x: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + }, + { + name: "all ones", + x: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + }, + { + name: "mixed", + x: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + }, + { + name: "mixed 2", + x: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + }, + { + name: "mixed 3", + x: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + }, + { + name: "one operand zero", + x: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{0, 0, 0, 0}}, neg: false}, + }, + { + name: "one operand all ones", + x: Int{abs: &uint256.Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, neg: false}, + y: Int{abs: &uint256.Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, neg: false}, + want: Int{abs: &uint256.Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, neg: false}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := New() + got.And(&tc.x, &tc.y) + + if got.Neq(&tc.want) { + t.Errorf("And(%v, %v) = %v, want %v", tc.x, tc.y, got, tc.want) + } + }) + } +} + +func TestRsh(t *testing.T) { + tests := []struct { + x string + n uint + want string + }{ + {"1024", 0, "1024"}, + {"1024", 1, "512"}, + {"1024", 2, "256"}, + {"1024", 10, "1"}, + {"1024", 11, "0"}, + {"18446744073709551615", 0, "18446744073709551615"}, + {"18446744073709551615", 1, "9223372036854775807"}, + {"18446744073709551615", 62, "3"}, + {"18446744073709551615", 63, "1"}, + {"18446744073709551615", 64, "0"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", 0, "115792089237316195423570985008687907853269984665640564039457584007913129639935"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", 1, "57896044618658097711785492504343953926634992332820282019728792003956564819967"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", 128, "340282366920938463463374607431768211455"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", 255, "1"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", 256, "0"}, + {"-1024", 0, "-1024"}, + {"-1024", 1, "-512"}, + {"-1024", 2, "-256"}, + {"-1024", 10, "-1"}, + {"-1024", 10, "-1"}, + {"-9223372036854775808", 0, "-9223372036854775808"}, + {"-9223372036854775808", 1, "-4611686018427387904"}, + {"-9223372036854775808", 62, "-2"}, + {"-9223372036854775808", 63, "-1"}, + {"-9223372036854775808", 64, "-1"}, + {"-57896044618658097711785492504343953926634992332820282019728792003956564819968", 0, "-57896044618658097711785492504343953926634992332820282019728792003956564819968"}, + {"-57896044618658097711785492504343953926634992332820282019728792003956564819968", 1, "-28948022309329048855892746252171976963317496166410141009864396001978282409984"}, + {"-57896044618658097711785492504343953926634992332820282019728792003956564819968", 253, "-4"}, + {"-57896044618658097711785492504343953926634992332820282019728792003956564819968", 254, "-2"}, + {"-57896044618658097711785492504343953926634992332820282019728792003956564819968", 255, "-1"}, + {"-57896044618658097711785492504343953926634992332820282019728792003956564819968", 256, "-1"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.Rsh(x, tc.n) + + if got.ToString() != tc.want { + t.Errorf("Rsh(%s, %d) = %v, want %v", tc.x, tc.n, got.ToString(), tc.want) + } + } +} + +func TestLsh(t *testing.T) { + tests := []struct { + x string + n uint + want string + }{ + {"1", 0, "1"}, + {"1", 1, "2"}, + {"1", 2, "4"}, + {"2", 0, "2"}, + {"2", 1, "4"}, + {"2", 2, "8"}, + {"-2", 0, "-2"}, + {"-4", 0, "-4"}, + {"-8", 0, "-8"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + got := New() + got.Lsh(x, tc.n) + + if got.ToString() != tc.want { + t.Errorf("Lsh(%s, %d) = %v, want %v", tc.x, tc.n, got.ToString(), tc.want) + } + } +} diff --git a/contract/p/gnoswap/int256/cmp.gno b/contract/p/gnoswap/int256/cmp.gno new file mode 100644 index 000000000..6371def6f --- /dev/null +++ b/contract/p/gnoswap/int256/cmp.gno @@ -0,0 +1,101 @@ +package int256 + +// Eq returns true if z == x +func (z *Int) Eq(x *Int) bool { + if z == nil || x == nil { + panic("int256: comparing with nil") + } + return (z.neg == x.neg) && z.abs.Eq(x.abs) +} + +// Neq returns true if z != x +func (z *Int) Neq(x *Int) bool { + if z == nil || x == nil { + panic("int256: comparing with nil") + } + return !z.Eq(x) +} + +// Cmp compares x and y and returns: +// +// -1 if x < y +// 0 if x == y +// +1 if x > y +func (z *Int) Cmp(x *Int) (r int) { + if z == nil || x == nil { + panic("int256: comparing with nil") + } + // x cmp y == x cmp y + // x cmp (-y) == x + // (-x) cmp y == y + // (-x) cmp (-y) == -(x cmp y) + switch { + case z == x: + // nothing to do + case z.neg == x.neg: + r = z.abs.Cmp(x.abs) + if z.neg { + r = -r + } + case z.neg: + r = -1 + default: + r = 1 + } + return +} + +// IsZero returns true if z == 0 +func (z *Int) IsZero() bool { + return z.abs.IsZero() +} + +// IsNeg returns true if z < 0 +func (z *Int) IsNeg() bool { + return z.neg +} + +// Lt returns true if z < x +func (z *Int) Lt(x *Int) bool { + if z == nil || x == nil { + panic("int256: comparing with nil") + } + if z.neg { + if x.neg { + return z.abs.Gt(x.abs) + } else { + return true + } + } else { + if x.neg { + return false + } else { + return z.abs.Lt(x.abs) + } + } +} + +// Gt returns true if z > x +func (z *Int) Gt(x *Int) bool { + if z == nil || x == nil { + panic("int256: comparing with nil") + } + if z.neg { + if x.neg { + return z.abs.Lt(x.abs) + } else { + return false + } + } else { + if x.neg { + return true + } else { + return z.abs.Gt(x.abs) + } + } +} + +// Clone creates a new Int identical to z +func (z *Int) Clone() *Int { + return &Int{z.abs.Clone(), z.neg} +} diff --git a/contract/p/gnoswap/int256/cmp_test.gno b/contract/p/gnoswap/int256/cmp_test.gno new file mode 100644 index 000000000..221c1eab9 --- /dev/null +++ b/contract/p/gnoswap/int256/cmp_test.gno @@ -0,0 +1,317 @@ +package int256 + +import ( + "testing" +) + +func TestEq(t *testing.T) { + tests := []struct { + x, y string + want bool + }{ + {"0", "0", true}, + {"0", "1", false}, + {"1", "0", false}, + {"-1", "0", false}, + {"0", "-1", false}, + {"1", "1", true}, + {"-1", "-1", true}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "-115792089237316195423570985008687907853269984665640564039457584007913129639935", false}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "-115792089237316195423570985008687907853269984665640564039457584007913129639935", true}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Eq(y) + if got != tc.want { + t.Errorf("Eq(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestNeq(t *testing.T) { + tests := []struct { + x, y string + want bool + }{ + {"0", "0", false}, + {"0", "1", true}, + {"1", "0", true}, + {"-1", "0", true}, + {"0", "-1", true}, + {"1", "1", false}, + {"-1", "-1", false}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "-115792089237316195423570985008687907853269984665640564039457584007913129639935", true}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935", "-115792089237316195423570985008687907853269984665640564039457584007913129639935", false}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Neq(y) + if got != tc.want { + t.Errorf("Neq(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestCmp(t *testing.T) { + tests := []struct { + x, y string + want int + }{ + {"0", "0", 0}, + {"0", "1", -1}, + {"1", "0", 1}, + {"-1", "0", -1}, + {"0", "-1", 1}, + {"1", "1", 0}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "-115792089237316195423570985008687907853269984665640564039457584007913129639935", 1}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Cmp(y) + if got != tc.want { + t.Errorf("Cmp(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestIsZero(t *testing.T) { + tests := []struct { + x string + want bool + }{ + {"0", true}, + {"-0", true}, + {"1", false}, + {"-1", false}, + {"10", false}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + got := x.IsZero() + if got != tc.want { + t.Errorf("IsZero(%s) = %v, want %v", tc.x, got, tc.want) + } + } +} + +func TestIsNeg(t *testing.T) { + tests := []struct { + x string + want bool + }{ + {"0", false}, + {"-0", true}, // TODO: should this be false? + {"1", false}, + {"-1", true}, + {"10", false}, + {"-10", true}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + got := x.IsNeg() + if got != tc.want { + t.Errorf("IsNeg(%s) = %v, want %v", tc.x, got, tc.want) + } + } +} + +func TestLt(t *testing.T) { + tests := []struct { + x, y string + want bool + }{ + {"0", "0", false}, + {"0", "1", true}, + {"1", "0", false}, + {"-1", "0", true}, + {"0", "-1", false}, + {"1", "1", false}, + {"-1", "-1", false}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "-115792089237316195423570985008687907853269984665640564039457584007913129639935", false}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Lt(y) + if got != tc.want { + t.Errorf("Lt(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestGt(t *testing.T) { + tests := []struct { + x, y string + want bool + }{ + {"0", "0", false}, + {"0", "1", false}, + {"1", "0", true}, + {"-1", "0", false}, + {"0", "-1", true}, + {"1", "1", false}, + {"-1", "-1", false}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "-115792089237316195423570985008687907853269984665640564039457584007913129639935", true}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Gt(y) + if got != tc.want { + t.Errorf("Gt(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestClone(t *testing.T) { + tests := []struct { + x string + }{ + {"0"}, + {"-0"}, + {"1"}, + {"-1"}, + {"10"}, + {"-10"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935"}, + {"-115792089237316195423570985008687907853269984665640564039457584007913129639935"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y := x.Clone() + + if x.Cmp(y) != 0 { + t.Errorf("Clone(%s) = %v, want %v", tc.x, y, x) + } + } +} + + +func TestNilChecks(t *testing.T) { + validInt := NewInt(123) + + tests := []struct { + name string + fn func() + wantPanic string + }{ + { + name: "Eq with nil", + fn: func() { validInt.Eq(nil) }, + wantPanic: "int256: comparing with nil", + }, + { + name: "Neq with nil", + fn: func() { validInt.Neq(nil) }, + wantPanic: "int256: comparing with nil", + }, + { + name: "Cmp with nil", + fn: func() { validInt.Cmp(nil) }, + wantPanic: "int256: comparing with nil", + }, + { + name: "Lt with nil", + fn: func() { validInt.Lt(nil) }, + wantPanic: "int256: comparing with nil", + }, + { + name: "Gt with nil", + fn: func() { validInt.Gt(nil) }, + wantPanic: "int256: comparing with nil", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Errorf("%s: expected panic but got none", tt.name) + return + } + if r.(string) != tt.wantPanic { + t.Errorf("%s: got panic %v, want %v", tt.name, r, tt.wantPanic) + } + }() + + tt.fn() + }) + } +} diff --git a/contract/p/gnoswap/int256/conversion.gno b/contract/p/gnoswap/int256/conversion.gno new file mode 100644 index 000000000..23b9b7791 --- /dev/null +++ b/contract/p/gnoswap/int256/conversion.gno @@ -0,0 +1,88 @@ +package int256 + +import ( + "gno.land/p/gnoswap/uint256" +) + +// SetInt64 sets z to x and returns z. +func (z *Int) SetInt64(x int64) *Int { + z.initiateAbs() + + neg := false + if x < 0 { + neg = true + x = -x + } + if z.abs == nil { + panic("int256_SetInt64()__abs is nil") + } + z.abs = z.abs.SetUint64(uint64(x)) + z.neg = neg + return z +} + +// SetUint64 sets z to x and returns z. +func (z *Int) SetUint64(x uint64) *Int { + z.initiateAbs() + + if z.abs == nil { + panic("int256_SetUint64()__abs is nil") + } + z.abs = z.abs.SetUint64(x) + z.neg = false + return z +} + +// Uint64 returns the lower 64-bits of z +func (z *Int) Uint64() uint64 { + return z.abs.Uint64() +} + +// Int64 returns the lower 64-bits of z +func (z *Int) Int64() int64 { + _abs := z.abs.Clone() + + if z.neg { + return -int64(_abs.Uint64()) + } + return int64(_abs.Uint64()) +} + +// Neg sets z to -x and returns z.) +func (z *Int) Neg(x *Int) *Int { + z.abs.Set(x.abs) + if z.abs.IsZero() { + z.neg = false + } else { + z.neg = !x.neg + } + return z +} + +// Set sets z to x and returns z. +func (z *Int) Set(x *Int) *Int { + z.abs.Set(x.abs) + z.neg = x.neg + return z +} + +// SetFromUint256 converts a uint256.Uint to Int and sets the value to z. +func (z *Int) SetUint256(x *uint256.Uint) *Int { + z.abs.Set(x) + z.neg = false + return z +} + +// OBS, differs from original mempooler int256 +// ToString returns the decimal representation of z. +func (z *Int) ToString() string { + if z == nil { + panic("int256: nil pointer to ToString()") + } + + t := z.abs.Dec() + if z.neg { + return "-" + t + } + return t +} diff --git a/contract/p/gnoswap/int256/conversion_test.gno b/contract/p/gnoswap/int256/conversion_test.gno new file mode 100644 index 000000000..4e87df6e0 --- /dev/null +++ b/contract/p/gnoswap/int256/conversion_test.gno @@ -0,0 +1,234 @@ +package int256 + +import ( + "testing" + + "gno.land/p/gnoswap/uint256" +) + +func TestSetInt64(t *testing.T) { + tests := []struct { + x int64 + want string + }{ + {0, "0"}, + {1, "1"}, + {-1, "-1"}, + {9223372036854775807, "9223372036854775807"}, + {-9223372036854775808, "-9223372036854775808"}, + } + + for _, tc := range tests { + var z Int + z.SetInt64(tc.x) + + got := z.ToString() + if got != tc.want { + t.Errorf("SetInt64(%d) = %s, want %s", tc.x, got, tc.want) + } + } +} + +func TestSetUint64(t *testing.T) { + tests := []struct { + x uint64 + want string + }{ + {0, "0"}, + {1, "1"}, + } + + for _, tc := range tests { + var z Int + z.SetUint64(tc.x) + + got := z.ToString() + if got != tc.want { + t.Errorf("SetUint64(%d) = %s, want %s", tc.x, got, tc.want) + } + } +} + +func TestUint64(t *testing.T) { + tests := []struct { + x string + want uint64 + }{ + {"0", 0}, + {"1", 1}, + {"9223372036854775807", 9223372036854775807}, + {"9223372036854775808", 9223372036854775808}, + {"18446744073709551615", 18446744073709551615}, + {"18446744073709551616", 0}, + {"18446744073709551617", 1}, + {"-1", 1}, + {"-18446744073709551615", 18446744073709551615}, + {"-18446744073709551616", 0}, + {"-18446744073709551617", 1}, + } + + for _, tc := range tests { + z := MustFromDecimal(tc.x) + + got := z.Uint64() + if got != tc.want { + t.Errorf("Uint64(%s) = %d, want %d", tc.x, got, tc.want) + } + } +} + +func TestInt64(t *testing.T) { + tests := []struct { + x string + want int64 + }{ + {"0", 0}, + {"1", 1}, + {"9223372036854775807", 9223372036854775807}, + {"18446744073709551616", 0}, + {"18446744073709551617", 1}, + {"-1", -1}, + {"-9223372036854775808", -9223372036854775808}, + } + + for _, tc := range tests { + z := MustFromDecimal(tc.x) + + got := z.Int64() + if got != tc.want { + t.Errorf("Uint64(%s) = %d, want %d", tc.x, got, tc.want) + } + } +} + +func TestNeg(t *testing.T) { + tests := []struct { + x string + want string + }{ + {"0", "0"}, + {"1", "-1"}, + {"-1", "1"}, + {"9223372036854775807", "-9223372036854775807"}, + {"-18446744073709551615", "18446744073709551615"}, + } + + for _, tc := range tests { + z := MustFromDecimal(tc.x) + z.Neg(z) + + got := z.ToString() + if got != tc.want { + t.Errorf("Neg(%s) = %s, want %s", tc.x, got, tc.want) + } + } +} + +func TestSet(t *testing.T) { + tests := []struct { + x string + want string + }{ + {"0", "0"}, + {"1", "1"}, + {"-1", "-1"}, + {"9223372036854775807", "9223372036854775807"}, + {"-18446744073709551615", "-18446744073709551615"}, + } + + for _, tc := range tests { + z := MustFromDecimal(tc.x) + z.Set(z) + + got := z.ToString() + if got != tc.want { + t.Errorf("Set(%s) = %s, want %s", tc.x, got, tc.want) + } + } +} + +func TestSetUint256(t *testing.T) { + tests := []struct { + x string + want string + }{ + {"0", "0"}, + {"1", "1"}, + {"9223372036854775807", "9223372036854775807"}, + {"18446744073709551615", "18446744073709551615"}, + } + + for _, tc := range tests { + got := New() + + z := uint256.MustFromDecimal(tc.x) + got.SetUint256(z) + + if got.ToString() != tc.want { + t.Errorf("SetUint256(%s) = %s, want %s", tc.x, got.ToString(), tc.want) + } + } +} + +func TestToString(t *testing.T) { + tests := []struct { + name string + setup func() *Int + expected string + }{ + { + name: "Zero from subtraction", + setup: func() *Int { + minusThree := MustFromDecimal("-3") + three := MustFromDecimal("3") + return Zero().Add(minusThree, three) + }, + expected: "0", + }, + { + name: "Zero from right shift", + setup: func() *Int { + return Zero().Rsh(One(), 1234) + }, + expected: "0", + }, + { + name: "Positive number", + setup: func() *Int { + return MustFromDecimal("42") + }, + expected: "42", + }, + { + name: "Negative number", + setup: func() *Int { + return MustFromDecimal("-42") + }, + expected: "-42", + }, + { + name: "Large positive number", + setup: func() *Int { + return MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") + }, + expected: "115792089237316195423570985008687907853269984665640564039457584007913129639935", + }, + { + name: "Large negative number", + setup: func() *Int { + return MustFromDecimal("-115792089237316195423570985008687907853269984665640564039457584007913129639935") + }, + expected: "-115792089237316195423570985008687907853269984665640564039457584007913129639935", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + z := tt.setup() + result := z.ToString() + if result != tt.expected { + t.Errorf("ToString() = %s, want %s", result, tt.expected) + } + }) + } +} diff --git a/contract/p/gnoswap/int256/gno.mod b/contract/p/gnoswap/int256/gno.mod new file mode 100644 index 000000000..8ba65d41b --- /dev/null +++ b/contract/p/gnoswap/int256/gno.mod @@ -0,0 +1 @@ +module gno.land/p/gnoswap/int256 diff --git a/contract/p/gnoswap/int256/int256.gno b/contract/p/gnoswap/int256/int256.gno new file mode 100644 index 000000000..58a3e0037 --- /dev/null +++ b/contract/p/gnoswap/int256/int256.gno @@ -0,0 +1,126 @@ +// This package provides a 256-bit signed integer type, Int, and associated functions. +package int256 + +import ( + "gno.land/p/gnoswap/uint256" +) + +var one = uint256.NewUint(1) + +type Int struct { + abs *uint256.Uint + neg bool +} + +// Zero returns a new Int set to 0. +func Zero() *Int { + return NewInt(0) +} + +// One returns a new Int set to 1. +func One() *Int { + return NewInt(1) +} + +// Sign returns: +// +// -1 if x < 0 +// 0 if x == 0 +// +1 if x > 0 +func (z *Int) Sign() int { + z.initiateAbs() + + if z.abs.IsZero() { + return 0 + } + if z.neg { + return -1 + } + return 1 +} + +// New returns a new Int set to 0. +func New() *Int { + return &Int{ + abs: new(uint256.Uint), + } +} + +// NewInt allocates and returns a new Int set to x. +func NewInt(x int64) *Int { + return New().SetInt64(x) +} + +// FromDecimal returns a new Int from a decimal string. +// Returns a new Int and an error if the string is not a valid decimal. +func FromDecimal(s string) (*Int, error) { + return new(Int).SetString(s) +} + +// MustFromDecimal returns a new Int from a decimal string. +// Panics if the string is not a valid decimal. +func MustFromDecimal(s string) *Int { + z, err := FromDecimal(s) + if err != nil { + panic(err) + } + return z +} + +// SetString sets s to the value of z and returns z and a boolean indicating success. +func (z *Int) SetString(s string) (*Int, error) { + neg := false + // Remove max one leading + + if len(s) > 0 && s[0] == '+' { + neg = false + s = s[1:] + } + + if len(s) > 0 && s[0] == '-' { + neg = true + s = s[1:] + } + var ( + abs *uint256.Uint + err error + ) + abs, err = uint256.FromDecimal(s) + if err != nil { + return nil, err + } + + return &Int{ + abs, + neg, + }, nil +} + +// FromUint256 is a convenience-constructor from uint256.Uint. +// Returns a new Int and whether overflow occurred. +// OBS: If u is `nil`, this method returns `nil, false` +func FromUint256(x *uint256.Uint) *Int { + if x == nil { + return nil + } + z := Zero() + + z.SetUint256(x) + return z +} + +// OBS, differs from original mempooler int256 +// NilToZero sets z to 0 and return it if it's nil, otherwise it returns z +func (z *Int) NilToZero() *Int { + if z == nil { + return NewInt(0) + } + return z +} + +// initiateAbs sets default value for `z` or `z.abs` value if is nil +// OBS: differs from mempooler int256. It checks not only `z.abs` but also `z` +func (z *Int) initiateAbs() { + if z == nil || z.abs == nil { + z.abs = new(uint256.Uint) + } +} diff --git a/contract/p/gnoswap/int256/int256_test.gno b/contract/p/gnoswap/int256/int256_test.gno new file mode 100644 index 000000000..974cc6ed8 --- /dev/null +++ b/contract/p/gnoswap/int256/int256_test.gno @@ -0,0 +1,25 @@ +// ported from github.com/mempooler/int256 +package int256 + +import ( + "testing" +) + +func TestSign(t *testing.T) { + tests := []struct { + x string + want int + }{ + {"0", 0}, + {"1", 1}, + {"-1", -1}, + } + + for _, tc := range tests { + z := MustFromDecimal(tc.x) + got := z.Sign() + if got != tc.want { + t.Errorf("Sign(%s) = %d, want %d", tc.x, got, tc.want) + } + } +} diff --git a/contract/p/gnoswap/uint256/_helper_test.gno b/contract/p/gnoswap/uint256/_helper_test.gno new file mode 100644 index 000000000..d867bb9c9 --- /dev/null +++ b/contract/p/gnoswap/uint256/_helper_test.gno @@ -0,0 +1,39 @@ +package uint256 + +import ( + "testing" +) + +func shouldEQ(t *testing.T, got, expected interface{}) { + if got != expected { + t.Errorf("got %v, expected %v", got, expected) + } +} + +func shouldNEQ(t *testing.T, got, expected interface{}) { + if got == expected { + t.Errorf("got %v, didn't expected %v", got, expected) + } +} + +func shouldPanic(t *testing.T, f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic") + } + }() + f() +} + +func shouldPanicWithMsg(t *testing.T, f func(), msg string) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } else { + if r != msg { + t.Errorf("excepted panic(%v), got(%v)", msg, r) + } + } + }() + f() +} diff --git a/contract/p/gnoswap/uint256/arithmetic.gno b/contract/p/gnoswap/uint256/arithmetic.gno new file mode 100644 index 000000000..a10436f8e --- /dev/null +++ b/contract/p/gnoswap/uint256/arithmetic.gno @@ -0,0 +1,476 @@ +// arithmetic provides arithmetic operations for Uint objects. +// This includes basic binary operations such as addition, subtraction, multiplication, division, and modulo operations +// as well as overflow checks, and negation. These functions are essential for numeric +// calculations using 256-bit unsigned integers. +package uint256 + +import ( + "math/bits" +) + +// Add sets z to the sum x+y +func (z *Uint) Add(x, y *Uint) *Uint { + var carry uint64 + z.arr[0], carry = bits.Add64(x.arr[0], y.arr[0], 0) + z.arr[1], carry = bits.Add64(x.arr[1], y.arr[1], carry) + z.arr[2], carry = bits.Add64(x.arr[2], y.arr[2], carry) + z.arr[3], _ = bits.Add64(x.arr[3], y.arr[3], carry) + return z +} + +// AddOverflow sets z to the sum x+y, and returns z and whether overflow occurred +func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) { + var carry uint64 + z.arr[0], carry = bits.Add64(x.arr[0], y.arr[0], 0) + z.arr[1], carry = bits.Add64(x.arr[1], y.arr[1], carry) + z.arr[2], carry = bits.Add64(x.arr[2], y.arr[2], carry) + z.arr[3], carry = bits.Add64(x.arr[3], y.arr[3], carry) + return z, carry != 0 +} + +// Sub sets z to the difference x-y +func (z *Uint) Sub(x, y *Uint) *Uint { + var carry uint64 + z.arr[0], carry = bits.Sub64(x.arr[0], y.arr[0], 0) + z.arr[1], carry = bits.Sub64(x.arr[1], y.arr[1], carry) + z.arr[2], carry = bits.Sub64(x.arr[2], y.arr[2], carry) + z.arr[3], _ = bits.Sub64(x.arr[3], y.arr[3], carry) + return z +} + +// SubOverflow sets z to the difference x-y and returns z and true if the operation underflowed +func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) { + var carry uint64 + z.arr[0], carry = bits.Sub64(x.arr[0], y.arr[0], 0) + z.arr[1], carry = bits.Sub64(x.arr[1], y.arr[1], carry) + z.arr[2], carry = bits.Sub64(x.arr[2], y.arr[2], carry) + z.arr[3], carry = bits.Sub64(x.arr[3], y.arr[3], carry) + return z, carry != 0 +} + +// Neg returns -x mod 2^256. +func (z *Uint) Neg(x *Uint) *Uint { + return z.Sub(new(Uint), x) +} + +// commented out for possible overflow +// Mul sets z to the product x*y +func (z *Uint) Mul(x, y *Uint) *Uint { + var ( + res Uint + carry uint64 + res1, res2, res3 uint64 + ) + + carry, res.arr[0] = bits.Mul64(x.arr[0], y.arr[0]) + carry, res1 = umulHop(carry, x.arr[1], y.arr[0]) + carry, res2 = umulHop(carry, x.arr[2], y.arr[0]) + res3 = x.arr[3]*y.arr[0] + carry + + carry, res.arr[1] = umulHop(res1, x.arr[0], y.arr[1]) + carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry) + res3 = res3 + x.arr[2]*y.arr[1] + carry + + carry, res.arr[2] = umulHop(res2, x.arr[0], y.arr[2]) + res3 = res3 + x.arr[1]*y.arr[2] + carry + + res.arr[3] = res3 + x.arr[0]*y.arr[3] + + return z.Set(&res) +} + +// MulOverflow sets z to the product x*y, and returns z and whether overflow occurred +func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) { + p := umul(x, y) + copy(z.arr[:], p[:4]) + return z, (p[4] | p[5] | p[6] | p[7]) != 0 +} + +// commented out for possible overflow +// Div sets z to the quotient x/y for returns z. +// If y == 0, z is set to 0 +func (z *Uint) Div(x, y *Uint) *Uint { + if y.IsZero() || y.Gt(x) { + return z.Clear() + } + if x.Eq(y) { + return z.SetOne() + } + // Shortcut some cases + if x.IsUint64() { + return z.SetUint64(x.Uint64() / y.Uint64()) + } + + // At this point, we know + // x/y ; x > y > 0 + + var quot Uint + udivrem(quot.arr[:], x.arr[:], y) + return z.Set(") +} + +// MulMod calculates the modulo-m multiplication of x and y and +// returns z. +// If m == 0, z is set to 0 (OBS: differs from the big.Int) +func (z *Uint) MulMod(x, y, m *Uint) *Uint { + if x.IsZero() || y.IsZero() || m.IsZero() { + return z.Clear() + } + p := umul(x, y) + + if m.arr[3] != 0 { + mu := Reciprocal(m) + r := reduce4(p, m, mu) + return z.Set(&r) + } + + var ( + pl Uint + ph Uint + ) + + pl = Uint{arr: [4]uint64{p[0], p[1], p[2], p[3]}} + ph = Uint{arr: [4]uint64{p[4], p[5], p[6], p[7]}} + + // If the multiplication is within 256 bits use Mod(). + if ph.IsZero() { + return z.Mod(&pl, m) + } + + var quot [8]uint64 + rem := udivrem(quot[:], p[:], m) + return z.Set(&rem) +} + +// Mod sets z to the modulus x%y for y != 0 and returns z. +// If y == 0, z is set to 0 (OBS: differs from the big.Uint) +func (z *Uint) Mod(x, y *Uint) *Uint { + if x.IsZero() || y.IsZero() { + return z.Clear() + } + switch x.Cmp(y) { + case -1: + // x < y + copy(z.arr[:], x.arr[:]) + return z + case 0: + // x == y + return z.Clear() // They are equal + } + + // At this point: + // x != 0 + // y != 0 + // x > y + + // Shortcut trivial case + if x.IsUint64() { + return z.SetUint64(x.Uint64() % y.Uint64()) + } + + var quot Uint + *z = udivrem(quot.arr[:], x.arr[:], y) + return z +} + +// DivMod sets z to the quotient x div y and m to the modulus x mod y and returns the pair (z, m) for y != 0. +// If y == 0, both z and m are set to 0 (OBS: differs from the big.Int) +func (z *Uint) DivMod(x, y, m *Uint) (*Uint, *Uint) { + if y.IsZero() { + return z.Clear(), m.Clear() + } + var quot Uint + *m = udivrem(quot.arr[:], x.arr[:], y) + *z = quot + return z, m +} + +// Exp sets z = base**exponent mod 2**256, and returns z. +func (z *Uint) Exp(base, exponent *Uint) *Uint { + res := Uint{arr: [4]uint64{1, 0, 0, 0}} + multiplier := *base + expBitLen := exponent.BitLen() + + curBit := 0 + word := exponent.arr[0] + for ; curBit < expBitLen && curBit < 64; curBit++ { + if word&1 == 1 { + res.Mul(&res, &multiplier) + } + multiplier.squared() + word >>= 1 + } + + word = exponent.arr[1] + for ; curBit < expBitLen && curBit < 128; curBit++ { + if word&1 == 1 { + res.Mul(&res, &multiplier) + } + multiplier.squared() + word >>= 1 + } + + word = exponent.arr[2] + for ; curBit < expBitLen && curBit < 192; curBit++ { + if word&1 == 1 { + res.Mul(&res, &multiplier) + } + multiplier.squared() + word >>= 1 + } + + word = exponent.arr[3] + for ; curBit < expBitLen && curBit < 256; curBit++ { + if word&1 == 1 { + res.Mul(&res, &multiplier) + } + multiplier.squared() + word >>= 1 + } + return z.Set(&res) +} + +func (z *Uint) squared() { + var ( + res Uint + carry0, carry1, carry2 uint64 + res1, res2 uint64 + ) + + carry0, res.arr[0] = bits.Mul64(z.arr[0], z.arr[0]) + carry0, res1 = umulHop(carry0, z.arr[0], z.arr[1]) + carry0, res2 = umulHop(carry0, z.arr[0], z.arr[2]) + + carry1, res.arr[1] = umulHop(res1, z.arr[0], z.arr[1]) + carry1, res2 = umulStep(res2, z.arr[1], z.arr[1], carry1) + + carry2, res.arr[2] = umulHop(res2, z.arr[0], z.arr[2]) + + res.arr[3] = 2*(z.arr[0]*z.arr[3]+z.arr[1]*z.arr[2]) + carry0 + carry1 + carry2 + + z.Set(&res) +} + +// udivrem divides u by d and produces both quotient and remainder. +// The quotient is stored in provided quot - len(u)-len(d)+1 words. +// It loosely follows the Knuth's division algorithm (sometimes referenced as "schoolbook" division) using 64-bit words. +// See Knuth, Volume 2, section 4.3.1, Algorithm D. +func udivrem(quot, u []uint64, d *Uint) (rem Uint) { + var dLen int + for i := len(d.arr) - 1; i >= 0; i-- { + if d.arr[i] != 0 { + dLen = i + 1 + break + } + } + + shift := uint(bits.LeadingZeros64(d.arr[dLen-1])) + + var dnStorage Uint + dn := dnStorage.arr[:dLen] + for i := dLen - 1; i > 0; i-- { + dn[i] = (d.arr[i] << shift) | (d.arr[i-1] >> (64 - shift)) + } + dn[0] = d.arr[0] << shift + + var uLen int + for i := len(u) - 1; i >= 0; i-- { + if u[i] != 0 { + uLen = i + 1 + break + } + } + + if uLen < dLen { + copy(rem.arr[:], u) + return rem + } + + var unStorage [9]uint64 + un := unStorage[:uLen+1] + un[uLen] = u[uLen-1] >> (64 - shift) + for i := uLen - 1; i > 0; i-- { + un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift)) + } + un[0] = u[0] << shift + + // TODO: Skip the highest word of numerator if not significant. + + if dLen == 1 { + r := udivremBy1(quot, un, dn[0]) + rem.SetUint64(r >> shift) + return rem + } + + udivremKnuth(quot, un, dn) + + for i := 0; i < dLen-1; i++ { + rem.arr[i] = (un[i] >> shift) | (un[i+1] << (64 - shift)) + } + rem.arr[dLen-1] = un[dLen-1] >> shift + + return rem +} + +// umul computes full 256 x 256 -> 512 multiplication. +func umul(x, y *Uint) [8]uint64 { + var ( + res [8]uint64 + carry, carry4, carry5, carry6 uint64 + res1, res2, res3, res4, res5 uint64 + ) + + carry, res[0] = bits.Mul64(x.arr[0], y.arr[0]) + carry, res1 = umulHop(carry, x.arr[1], y.arr[0]) + carry, res2 = umulHop(carry, x.arr[2], y.arr[0]) + carry4, res3 = umulHop(carry, x.arr[3], y.arr[0]) + + carry, res[1] = umulHop(res1, x.arr[0], y.arr[1]) + carry, res2 = umulStep(res2, x.arr[1], y.arr[1], carry) + carry, res3 = umulStep(res3, x.arr[2], y.arr[1], carry) + carry5, res4 = umulStep(carry4, x.arr[3], y.arr[1], carry) + + carry, res[2] = umulHop(res2, x.arr[0], y.arr[2]) + carry, res3 = umulStep(res3, x.arr[1], y.arr[2], carry) + carry, res4 = umulStep(res4, x.arr[2], y.arr[2], carry) + carry6, res5 = umulStep(carry5, x.arr[3], y.arr[2], carry) + + carry, res[3] = umulHop(res3, x.arr[0], y.arr[3]) + carry, res[4] = umulStep(res4, x.arr[1], y.arr[3], carry) + carry, res[5] = umulStep(res5, x.arr[2], y.arr[3], carry) + res[7], res[6] = umulStep(carry6, x.arr[3], y.arr[3], carry) + + return res +} + +// umulStep computes (hi * 2^64 + lo) = z + (x * y) + carry. +func umulStep(z, x, y, carry uint64) (hi, lo uint64) { + hi, lo = bits.Mul64(x, y) + lo, carry = bits.Add64(lo, carry, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, z, 0) + hi, _ = bits.Add64(hi, 0, carry) + return hi, lo +} + +// umulHop computes (hi * 2^64 + lo) = z + (x * y) +func umulHop(z, x, y uint64) (hi, lo uint64) { + hi, lo = bits.Mul64(x, y) + lo, carry := bits.Add64(lo, z, 0) + hi, _ = bits.Add64(hi, 0, carry) + return hi, lo +} + +// udivremBy1 divides u by single normalized word d and produces both quotient and remainder. +// The quotient is stored in provided quot. +func udivremBy1(quot, u []uint64, d uint64) (rem uint64) { + reciprocal := reciprocal2by1(d) + rem = u[len(u)-1] // Set the top word as remainder. + for j := len(u) - 2; j >= 0; j-- { + quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal) + } + return rem +} + +// udivremKnuth implements the division of u by normalized multiple word d from the Knuth's division algorithm. +// The quotient is stored in provided quot - len(u)-len(d) words. +// Updates u to contain the remainder - len(d) words. +func udivremKnuth(quot, u, d []uint64) { + dh := d[len(d)-1] + dl := d[len(d)-2] + reciprocal := reciprocal2by1(dh) + + for j := len(u) - len(d) - 1; j >= 0; j-- { + u2 := u[j+len(d)] + u1 := u[j+len(d)-1] + u0 := u[j+len(d)-2] + + var qhat, rhat uint64 + if u2 >= dh { // Division overflows. + qhat = ^uint64(0) + // TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case). + } else { + qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal) + ph, pl := bits.Mul64(qhat, dl) + if ph > rhat || (ph == rhat && pl > u0) { + qhat-- + // TODO: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case). + } + } + + // Multiply and subtract. + borrow := subMulTo(u[j:], d, qhat) + u[j+len(d)] = u2 - borrow + if u2 < borrow { // Too much subtracted, add back. + qhat-- + u[j+len(d)] += addTo(u[j:], d) + } + + quot[j] = qhat // Store quotient digit. + } +} + +// isBitSet returns true if bit n-th is set, where n = 0 is LSB. +// The n must be <= 255. +func (z *Uint) isBitSet(n uint) bool { + return (z.arr[n/64] & (1 << (n % 64))) != 0 +} + +func (z *Uint) IsOverflow() bool { + return z.isBitSet(255) +} + +// addTo computes x += y. +// Requires len(x) >= len(y). +func addTo(x, y []uint64) uint64 { + var carry uint64 + for i := 0; i < len(y); i++ { + x[i], carry = bits.Add64(x[i], y[i], carry) + } + return carry +} + +// subMulTo computes x -= y * multiplier. +// Requires len(x) >= len(y). +func subMulTo(x, y []uint64, multiplier uint64) uint64 { + var borrow uint64 + for i := 0; i < len(y); i++ { + s, carry1 := bits.Sub64(x[i], borrow, 0) + ph, pl := bits.Mul64(y[i], multiplier) + t, carry2 := bits.Sub64(s, pl, 0) + x[i] = t + borrow = ph + carry1 + carry2 + } + return borrow +} + +// reciprocal2by1 computes <^d, ^0> / d. +func reciprocal2by1(d uint64) uint64 { + reciprocal, _ := bits.Div64(^d, ^uint64(0), d) + return reciprocal +} + +// udivrem2by1 divides / d and produces both quotient and remainder. +// It uses the provided d's reciprocal. +// Implementation ported from https://github.com/chfast/intx and is based on +// "Improved division by invariant integers", Algorithm 4. +func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) { + qh, ql := bits.Mul64(reciprocal, uh) + ql, carry := bits.Add64(ql, ul, 0) + qh, _ = bits.Add64(qh, uh, carry) + qh++ + + r := ul - qh*d + + if r > ql { + qh-- + r += d + } + + if r >= d { + qh++ + r -= d + } + + return qh, r +} diff --git a/contract/p/gnoswap/uint256/arithmetic_test.gno b/contract/p/gnoswap/uint256/arithmetic_test.gno new file mode 100644 index 000000000..09d83eb18 --- /dev/null +++ b/contract/p/gnoswap/uint256/arithmetic_test.gno @@ -0,0 +1,369 @@ +package uint256 + +import "testing" + +type binOp2Test struct { + x, y, want string +} + +func TestIsOverflow(t *testing.T) { + tests := []struct { + name string + input *Uint + expected bool + }{ + { + name: "Number greater than max value", + input: &Uint{arr: [4]uint64{ + ^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0), + }}, + expected: true, + }, + { + name: "Max value", + input: &Uint{arr: [4]uint64{ + ^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0) >> 1, + }}, + expected: false, + }, + { + name: "0", + input: &Uint{arr: [4]uint64{0, 0, 0, 0}}, + expected: false, + }, + { + name: "Only 255th bit set", + input: &Uint{arr: [4]uint64{ + 0, 0, 0, uint64(1) << 63, + }}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.input.IsOverflow(); got != tt.expected { + t.Errorf("IsOverflow() = %v, expected %v", got, tt.expected) + } + }) + } +} + +func TestAdd(t *testing.T) { + tests := []binOp2Test{ + {"0", "1", "1"}, + {"1", "0", "1"}, + {"1", "1", "2"}, + {"1", "3", "4"}, + {"10", "10", "20"}, + {"18446744073709551615", "18446744073709551615", "36893488147419103230"}, // uint64 overflow + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Add(x, y) + + if got.Neq(want) { + t.Errorf("Add(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestSub(t *testing.T) { + tests := []binOp2Test{ + {"1", "0", "1"}, + {"1", "1", "0"}, + {"10", "10", "0"}, + {"31337", "1337", "30000"}, + {"2", "3", "115792089237316195423570985008687907853269984665640564039457584007913129639935"}, // underflow + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Sub(x, y) + + if got.Neq(want) { + t.Errorf("Sub(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestMul(t *testing.T) { + tests := []binOp2Test{ + {"1", "0", "0"}, + {"1", "1", "1"}, + {"10", "10", "100"}, + {"18446744073709551615", "2", "36893488147419103230"}, // uint64 overflow + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Mul(x, y) + + if got.Neq(want) { + t.Errorf("Mul(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestDiv(t *testing.T) { + tests := []binOp2Test{ + {"31337", "3", "10445"}, + {"31337", "0", "0"}, + {"0", "31337", "0"}, + {"1", "1", "1"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Div(x, y) + + if got.Neq(want) { + t.Errorf("Div(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestMod(t *testing.T) { + tests := []binOp2Test{ + {"31337", "3", "2"}, + {"31337", "0", "0"}, + {"0", "31337", "0"}, + {"2", "31337", "2"}, + {"1", "1", "0"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Mod(x, y) + + if got.Neq(want) { + t.Errorf("Mod(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestDivMod(t *testing.T) { + tests := []struct { + x string + y string + wantDiv string + wantMod string + }{ + {"1", "1", "1", "0"}, + {"10", "10", "1", "0"}, + {"100", "10", "10", "0"}, + {"31337", "3", "10445", "2"}, + {"31337", "0", "0", "0"}, + {"0", "31337", "0", "0"}, + {"2", "31337", "0", "2"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + wantDiv, err := FromDecimal(tc.wantDiv) + if err != nil { + t.Error(err) + continue + } + + wantMod, err := FromDecimal(tc.wantMod) + if err != nil { + t.Error(err) + continue + } + + gotDiv := new(Uint) + gotMod := new(Uint) + gotDiv.DivMod(x, y, gotMod) + + for i := range gotDiv.arr { + if gotDiv.arr[i] != wantDiv.arr[i] { + t.Errorf("DivMod(%s, %s) got Div %v, want Div %v", tc.x, tc.y, gotDiv, wantDiv) + break + } + } + for i := range gotMod.arr { + if gotMod.arr[i] != wantMod.arr[i] { + t.Errorf("DivMod(%s, %s) got Mod %v, want Mod %v", tc.x, tc.y, gotMod, wantMod) + break + } + } + } +} + +func TestNeg(t *testing.T) { + tests := []struct { + x string + want string + }{ + {"31337", "115792089237316195423570985008687907853269984665640564039457584007913129608599"}, + {"115792089237316195423570985008687907853269984665640564039457584007913129608599", "31337"}, + {"0", "0"}, + {"2", "115792089237316195423570985008687907853269984665640564039457584007913129639934"}, + {"1", "115792089237316195423570985008687907853269984665640564039457584007913129639935"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Neg(x) + + if got.Neq(want) { + t.Errorf("Neg(%s) = %v, want %v", tc.x, got.ToString(), want.ToString()) + } + } +} + +func TestExp(t *testing.T) { + tests := []binOp2Test{ + {"31337", "3", "30773171189753"}, + {"31337", "0", "1"}, + {"0", "31337", "0"}, + {"1", "1", "1"}, + {"2", "3", "8"}, + {"2", "64", "18446744073709551616"}, + {"2", "128", "340282366920938463463374607431768211456"}, + {"2", "255", "57896044618658097711785492504343953926634992332820282019728792003956564819968"}, + {"2", "256", "0"}, // overflow + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Exp(x, y) + + if got.Neq(want) { + t.Errorf("Exp(%s, %s) = %v, want %v", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} diff --git a/contract/p/gnoswap/uint256/bits_table.gno b/contract/p/gnoswap/uint256/bits_table.gno new file mode 100644 index 000000000..53dbea948 --- /dev/null +++ b/contract/p/gnoswap/uint256/bits_table.gno @@ -0,0 +1,79 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Code generated by go run make_tables.go. DO NOT EDIT. + +package uint256 + +const ntz8tab = "" + + "\x08\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x06\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x07\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x06\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + + "\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" + +const pop8tab = "" + + "\x00\x01\x01\x02\x01\x02\x02\x03\x01\x02\x02\x03\x02\x03\x03\x04" + + "\x01\x02\x02\x03\x02\x03\x03\x04\x02\x03\x03\x04\x03\x04\x04\x05" + + "\x01\x02\x02\x03\x02\x03\x03\x04\x02\x03\x03\x04\x03\x04\x04\x05" + + "\x02\x03\x03\x04\x03\x04\x04\x05\x03\x04\x04\x05\x04\x05\x05\x06" + + "\x01\x02\x02\x03\x02\x03\x03\x04\x02\x03\x03\x04\x03\x04\x04\x05" + + "\x02\x03\x03\x04\x03\x04\x04\x05\x03\x04\x04\x05\x04\x05\x05\x06" + + "\x02\x03\x03\x04\x03\x04\x04\x05\x03\x04\x04\x05\x04\x05\x05\x06" + + "\x03\x04\x04\x05\x04\x05\x05\x06\x04\x05\x05\x06\x05\x06\x06\x07" + + "\x01\x02\x02\x03\x02\x03\x03\x04\x02\x03\x03\x04\x03\x04\x04\x05" + + "\x02\x03\x03\x04\x03\x04\x04\x05\x03\x04\x04\x05\x04\x05\x05\x06" + + "\x02\x03\x03\x04\x03\x04\x04\x05\x03\x04\x04\x05\x04\x05\x05\x06" + + "\x03\x04\x04\x05\x04\x05\x05\x06\x04\x05\x05\x06\x05\x06\x06\x07" + + "\x02\x03\x03\x04\x03\x04\x04\x05\x03\x04\x04\x05\x04\x05\x05\x06" + + "\x03\x04\x04\x05\x04\x05\x05\x06\x04\x05\x05\x06\x05\x06\x06\x07" + + "\x03\x04\x04\x05\x04\x05\x05\x06\x04\x05\x05\x06\x05\x06\x06\x07" + + "\x04\x05\x05\x06\x05\x06\x06\x07\x05\x06\x06\x07\x06\x07\x07\x08" + +const rev8tab = "" + + "\x00\x80\x40\xc0\x20\xa0\x60\xe0\x10\x90\x50\xd0\x30\xb0\x70\xf0" + + "\x08\x88\x48\xc8\x28\xa8\x68\xe8\x18\x98\x58\xd8\x38\xb8\x78\xf8" + + "\x04\x84\x44\xc4\x24\xa4\x64\xe4\x14\x94\x54\xd4\x34\xb4\x74\xf4" + + "\x0c\x8c\x4c\xcc\x2c\xac\x6c\xec\x1c\x9c\x5c\xdc\x3c\xbc\x7c\xfc" + + "\x02\x82\x42\xc2\x22\xa2\x62\xe2\x12\x92\x52\xd2\x32\xb2\x72\xf2" + + "\x0a\x8a\x4a\xca\x2a\xaa\x6a\xea\x1a\x9a\x5a\xda\x3a\xba\x7a\xfa" + + "\x06\x86\x46\xc6\x26\xa6\x66\xe6\x16\x96\x56\xd6\x36\xb6\x76\xf6" + + "\x0e\x8e\x4e\xce\x2e\xae\x6e\xee\x1e\x9e\x5e\xde\x3e\xbe\x7e\xfe" + + "\x01\x81\x41\xc1\x21\xa1\x61\xe1\x11\x91\x51\xd1\x31\xb1\x71\xf1" + + "\x09\x89\x49\xc9\x29\xa9\x69\xe9\x19\x99\x59\xd9\x39\xb9\x79\xf9" + + "\x05\x85\x45\xc5\x25\xa5\x65\xe5\x15\x95\x55\xd5\x35\xb5\x75\xf5" + + "\x0d\x8d\x4d\xcd\x2d\xad\x6d\xed\x1d\x9d\x5d\xdd\x3d\xbd\x7d\xfd" + + "\x03\x83\x43\xc3\x23\xa3\x63\xe3\x13\x93\x53\xd3\x33\xb3\x73\xf3" + + "\x0b\x8b\x4b\xcb\x2b\xab\x6b\xeb\x1b\x9b\x5b\xdb\x3b\xbb\x7b\xfb" + + "\x07\x87\x47\xc7\x27\xa7\x67\xe7\x17\x97\x57\xd7\x37\xb7\x77\xf7" + + "\x0f\x8f\x4f\xcf\x2f\xaf\x6f\xef\x1f\x9f\x5f\xdf\x3f\xbf\x7f\xff" + +const len8tab = "" + + "\x00\x01\x02\x02\x03\x03\x03\x03\x04\x04\x04\x04\x04\x04\x04\x04" + + "\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05" + + "\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06" + + "\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06" + + "\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" + + "\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" + + "\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" + + "\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" + + "\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" diff --git a/contract/p/gnoswap/uint256/bitwise.gno b/contract/p/gnoswap/uint256/bitwise.gno new file mode 100644 index 000000000..fc6e098dc --- /dev/null +++ b/contract/p/gnoswap/uint256/bitwise.gno @@ -0,0 +1,264 @@ +// bitwise contains bitwise operations for Uint instances. +// This file includes functions to perform bitwise AND, OR, XOR, and NOT operations, as well as bit shifting. +// These operations are crucial for manipulating individual bits within a 256-bit unsigned integer. +package uint256 + +// Or sets z = x | y and returns z. +func (z *Uint) Or(x, y *Uint) *Uint { + z.arr[0] = x.arr[0] | y.arr[0] + z.arr[1] = x.arr[1] | y.arr[1] + z.arr[2] = x.arr[2] | y.arr[2] + z.arr[3] = x.arr[3] | y.arr[3] + return z +} + +// And sets z = x & y and returns z. +func (z *Uint) And(x, y *Uint) *Uint { + z.arr[0] = x.arr[0] & y.arr[0] + z.arr[1] = x.arr[1] & y.arr[1] + z.arr[2] = x.arr[2] & y.arr[2] + z.arr[3] = x.arr[3] & y.arr[3] + return z +} + +// Not sets z = ^x and returns z. +func (z *Uint) Not(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = ^x.arr[3], ^x.arr[2], ^x.arr[1], ^x.arr[0] + return z +} + +// AndNot sets z = x &^ y and returns z. +func (z *Uint) AndNot(x, y *Uint) *Uint { + z.arr[0] = x.arr[0] &^ y.arr[0] + z.arr[1] = x.arr[1] &^ y.arr[1] + z.arr[2] = x.arr[2] &^ y.arr[2] + z.arr[3] = x.arr[3] &^ y.arr[3] + return z +} + +// Xor sets z = x ^ y and returns z. +func (z *Uint) Xor(x, y *Uint) *Uint { + z.arr[0] = x.arr[0] ^ y.arr[0] + z.arr[1] = x.arr[1] ^ y.arr[1] + z.arr[2] = x.arr[2] ^ y.arr[2] + z.arr[3] = x.arr[3] ^ y.arr[3] + return z +} + +// Lsh sets z = x << n and returns z. +func (z *Uint) Lsh(x *Uint, n uint) *Uint { + // n % 64 == 0 + if n&0x3f == 0 { + switch n { + case 0: + return z.Set(x) + case 64: + return z.lsh64(x) + case 128: + return z.lsh128(x) + case 192: + return z.lsh192(x) + default: + return z.Clear() + } + } + var a, b uint64 + // Big swaps first + switch { + case n > 192: + if n > 256 { + return z.Clear() + } + z.lsh192(x) + n -= 192 + goto sh192 + case n > 128: + z.lsh128(x) + n -= 128 + goto sh128 + case n > 64: + z.lsh64(x) + n -= 64 + goto sh64 + default: + z.Set(x) + } + + // remaining shifts + a = z.arr[0] >> (64 - n) + z.arr[0] = z.arr[0] << n + +sh64: + b = z.arr[1] >> (64 - n) + z.arr[1] = (z.arr[1] << n) | a + +sh128: + a = z.arr[2] >> (64 - n) + z.arr[2] = (z.arr[2] << n) | b + +sh192: + z.arr[3] = (z.arr[3] << n) | a + + return z +} + +// Rsh sets z = x >> n and returns z. +func (z *Uint) Rsh(x *Uint, n uint) *Uint { + // n % 64 == 0 + if n&0x3f == 0 { + switch n { + case 0: + return z.Set(x) + case 64: + return z.rsh64(x) + case 128: + return z.rsh128(x) + case 192: + return z.rsh192(x) + default: + return z.Clear() + } + } + var a, b uint64 + // Big swaps first + switch { + case n > 192: + if n > 256 { + return z.Clear() + } + z.rsh192(x) + n -= 192 + goto sh192 + case n > 128: + z.rsh128(x) + n -= 128 + goto sh128 + case n > 64: + z.rsh64(x) + n -= 64 + goto sh64 + default: + z.Set(x) + } + + // remaining shifts + a = z.arr[3] << (64 - n) + z.arr[3] = z.arr[3] >> n + +sh64: + b = z.arr[2] << (64 - n) + z.arr[2] = (z.arr[2] >> n) | a + +sh128: + a = z.arr[1] << (64 - n) + z.arr[1] = (z.arr[1] >> n) | b + +sh192: + z.arr[0] = (z.arr[0] >> n) | a + + return z +} + +// SRsh (Signed/Arithmetic right shift) +// considers z to be a signed integer, during right-shift +// and sets z = x >> n and returns z. +func (z *Uint) SRsh(x *Uint, n uint) *Uint { + // If the MSB is 0, SRsh is same as Rsh. + if !x.isBitSet(255) { + return z.Rsh(x, n) + } + if n%64 == 0 { + switch n { + case 0: + return z.Set(x) + case 64: + return z.srsh64(x) + case 128: + return z.srsh128(x) + case 192: + return z.srsh192(x) + default: + return z.SetAllOne() + } + } + var a uint64 = MaxUint64 << (64 - n%64) + // Big swaps first + switch { + case n > 192: + if n > 256 { + return z.SetAllOne() + } + z.srsh192(x) + n -= 192 + goto sh192 + case n > 128: + z.srsh128(x) + n -= 128 + goto sh128 + case n > 64: + z.srsh64(x) + n -= 64 + goto sh64 + default: + z.Set(x) + } + + // remaining shifts + z.arr[3], a = (z.arr[3]>>n)|a, z.arr[3]<<(64-n) + +sh64: + z.arr[2], a = (z.arr[2]>>n)|a, z.arr[2]<<(64-n) + +sh128: + z.arr[1], a = (z.arr[1]>>n)|a, z.arr[1]<<(64-n) + +sh192: + z.arr[0] = (z.arr[0] >> n) | a + + return z +} + +func (z *Uint) lsh64(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[2], x.arr[1], x.arr[0], 0 + return z +} + +func (z *Uint) lsh128(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[1], x.arr[0], 0, 0 + return z +} + +func (z *Uint) lsh192(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = x.arr[0], 0, 0, 0 + return z +} + +func (z *Uint) rsh64(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, x.arr[3], x.arr[2], x.arr[1] + return z +} + +func (z *Uint) rsh128(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, x.arr[3], x.arr[2] + return z +} + +func (z *Uint) rsh192(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, x.arr[3] + return z +} + +func (z *Uint) srsh64(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, x.arr[3], x.arr[2], x.arr[1] + return z +} + +func (z *Uint) srsh128(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, MaxUint64, x.arr[3], x.arr[2] + return z +} + +func (z *Uint) srsh192(x *Uint) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, MaxUint64, MaxUint64, x.arr[3] + return z +} diff --git a/contract/p/gnoswap/uint256/bitwise_test.gno b/contract/p/gnoswap/uint256/bitwise_test.gno new file mode 100644 index 000000000..4437e7d5c --- /dev/null +++ b/contract/p/gnoswap/uint256/bitwise_test.gno @@ -0,0 +1,346 @@ +package uint256 + +import ( + "testing" +) + +type logicOpTest struct { + name string + x Uint + y Uint + want Uint +} + +func TestOr(t *testing.T) { + tests := []logicOpTest{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "one operand all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).Or(&tc.x, &tc.y) + if *res != tc.want { + t.Errorf("Or(%s, %s) = %s, want %s", tc.x.ToString(), tc.y.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestAnd(t *testing.T) { + tests := []logicOpTest{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed 2", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed 3", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "one operand zero", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "one operand all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + want: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).And(&tc.x, &tc.y) + if *res != tc.want { + t.Errorf("And(%s, %s) = %s, want %s", tc.x.ToString(), tc.y.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestNot(t *testing.T) { + tests := []struct { + name string + x Uint + want Uint + }{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).Not(&tc.x) + if *res != tc.want { + t.Errorf("Not(%s) = %s, want %s", tc.x.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestAndNot(t *testing.T) { + tests := []logicOpTest{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed 2", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed 3", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + }, + { + name: "one operand zero", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "one operand all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + want: Uint{arr: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0x0000000000000000, ^uint64(0)}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).AndNot(&tc.x, &tc.y) + if *res != tc.want { + t.Errorf("AndNot(%s, %s) = %s, want %s", tc.x.ToString(), tc.y.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestXor(t *testing.T) { + tests := []logicOpTest{ + { + name: "all zeros", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{0, 0, 0, 0}}, + }, + { + name: "mixed", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed 2", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "mixed 3", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), 0, 0}}, + y: Uint{arr: [4]uint64{0, 0, ^uint64(0), ^uint64(0)}}, + want: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + }, + { + name: "one operand zero", + x: Uint{arr: [4]uint64{0, 0, 0, 0}}, + y: Uint{arr: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}}, + want: Uint{arr: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}}, + }, + { + name: "one operand all ones", + x: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + y: Uint{arr: [4]uint64{0x5555555555555555, 0xAAAAAAAAAAAAAAAA, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000}}, + want: Uint{arr: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0x0000000000000000, ^uint64(0)}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res := new(Uint).Xor(&tc.x, &tc.y) + if *res != tc.want { + t.Errorf("Xor(%s, %s) = %s, want %s", tc.x.ToString(), tc.y.ToString(), res.ToString(), (tc.want).ToString()) + } + }) + } +} + +func TestLsh(t *testing.T) { + tests := []struct { + x string + y uint + want string + }{ + {"0", 0, "0"}, + {"0", 1, "0"}, + {"0", 64, "0"}, + {"1", 0, "1"}, + {"1", 1, "2"}, + {"1", 64, "18446744073709551616"}, + {"1", 128, "340282366920938463463374607431768211456"}, + {"1", 192, "6277101735386680763835789423207666416102355444464034512896"}, + {"1", 255, "57896044618658097711785492504343953926634992332820282019728792003956564819968"}, + {"1", 256, "0"}, + {"31337", 0, "31337"}, + {"31337", 1, "62674"}, + {"31337", 64, "578065619037836218990592"}, + {"31337", 128, "10663428532201448629551770073089320442396672"}, + {"31337", 192, "196705537081812415096322133155058642481399512563169449530621952"}, + {"31337", 193, "393411074163624830192644266310117284962799025126338899061243904"}, + {"31337", 255, "57896044618658097711785492504343953926634992332820282019728792003956564819968"}, + {"31337", 256, "0"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Lsh(x, tc.y) + + if got.Neq(want) { + t.Errorf("Lsh(%s, %d) = %s, want %s", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} + +func TestRsh(t *testing.T) { + tests := []struct { + x string + y uint + want string + }{ + {"0", 0, "0"}, + {"0", 1, "0"}, + {"0", 64, "0"}, + {"1", 0, "1"}, + {"1", 1, "0"}, + {"1", 64, "0"}, + {"1", 128, "0"}, + {"1", 192, "0"}, + {"1", 255, "0"}, + {"57896044618658097711785492504343953926634992332820282019728792003956564819968", 255, "1"}, + {"6277101735386680763835789423207666416102355444464034512896", 192, "1"}, + {"340282366920938463463374607431768211456", 128, "1"}, + {"18446744073709551616", 64, "1"}, + {"393411074163624830192644266310117284962799025126338899061243904", 193, "31337"}, + {"196705537081812415096322133155058642481399512563169449530621952", 192, "31337"}, + {"10663428532201448629551770073089320442396672", 128, "31337"}, + {"578065619037836218990592", 64, "31337"}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + want, err := FromDecimal(tc.want) + if err != nil { + t.Error(err) + continue + } + + got := &Uint{} + got.Rsh(x, tc.y) + + if got.Neq(want) { + t.Errorf("Rsh(%s, %d) = %s, want %s", tc.x, tc.y, got.ToString(), want.ToString()) + } + } +} diff --git a/contract/p/gnoswap/uint256/cmp.gno b/contract/p/gnoswap/uint256/cmp.gno new file mode 100644 index 000000000..a43a31e5f --- /dev/null +++ b/contract/p/gnoswap/uint256/cmp.gno @@ -0,0 +1,125 @@ +// cmp (or, comparisons) includes methods for comparing Uint instances. +// These comparison functions cover a range of operations including equality checks, less than/greater than +// evaluations, and specialized comparisons such as signed greater than. These are fundamental for logical +// decision making based on Uint values. +package uint256 + +import ( + "math/bits" +) + +// Cmp compares z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Uint) Cmp(x *Uint) (r int) { + // z < x <=> z - x < 0 i.e. when subtraction overflows. + d0, carry := bits.Sub64(z.arr[0], x.arr[0], 0) + d1, carry := bits.Sub64(z.arr[1], x.arr[1], carry) + d2, carry := bits.Sub64(z.arr[2], x.arr[2], carry) + d3, carry := bits.Sub64(z.arr[3], x.arr[3], carry) + if carry == 1 { + return -1 + } + if d0|d1|d2|d3 == 0 { + return 0 + } + return 1 +} + +// IsZero returns true if z == 0 +func (z *Uint) IsZero() bool { + return (z.arr[0] | z.arr[1] | z.arr[2] | z.arr[3]) == 0 +} + +// Sign returns: +// +// -1 if z < 0 +// 0 if z == 0 +// +1 if z > 0 +// +// Where z is interpreted as a two's complement signed number +func (z *Uint) Sign() int { + if z.IsZero() { + return 0 + } + if z.arr[3] < 0x8000000000000000 { + return 1 + } + return -1 +} + +// LtUint64 returns true if z is smaller than n +func (z *Uint) LtUint64(n uint64) bool { + return z.arr[0] < n && (z.arr[1]|z.arr[2]|z.arr[3]) == 0 +} + +// GtUint64 returns true if z is larger than n +func (z *Uint) GtUint64(n uint64) bool { + return z.arr[0] > n || (z.arr[1]|z.arr[2]|z.arr[3]) != 0 +} + +// Lt returns true if z < x +func (z *Uint) Lt(x *Uint) bool { + // z < x <=> z - x < 0 i.e. when subtraction overflows. + _, carry := bits.Sub64(z.arr[0], x.arr[0], 0) + _, carry = bits.Sub64(z.arr[1], x.arr[1], carry) + _, carry = bits.Sub64(z.arr[2], x.arr[2], carry) + _, carry = bits.Sub64(z.arr[3], x.arr[3], carry) + + return carry != 0 +} + +// Gt returns true if z > x +func (z *Uint) Gt(x *Uint) bool { + return x.Lt(z) +} + +// Lte returns true if z <= x +func (z *Uint) Lte(x *Uint) bool { + cond1 := z.Lt(x) + cond2 := z.Eq(x) + + if cond1 || cond2 { + return true + } + return false +} + +// Gte returns true if z >= x +func (z *Uint) Gte(x *Uint) bool { + cond1 := z.Gt(x) + cond2 := z.Eq(x) + + if cond1 || cond2 { + return true + } + return false +} + +// Eq returns true if z == x +func (z *Uint) Eq(x *Uint) bool { + return (z.arr[0] == x.arr[0]) && (z.arr[1] == x.arr[1]) && (z.arr[2] == x.arr[2]) && (z.arr[3] == x.arr[3]) +} + +// Neq returns true if z != x +func (z *Uint) Neq(x *Uint) bool { + return !z.Eq(x) +} + +// Sgt interprets z and x as signed integers, and returns +// true if z > x +func (z *Uint) Sgt(x *Uint) bool { + zSign := z.Sign() + xSign := x.Sign() + + switch { + case zSign >= 0 && xSign < 0: + return true + case zSign < 0 && xSign >= 0: + return false + default: + return z.Gt(x) + } +} diff --git a/contract/p/gnoswap/uint256/cmp_test.gno b/contract/p/gnoswap/uint256/cmp_test.gno new file mode 100644 index 000000000..930079f70 --- /dev/null +++ b/contract/p/gnoswap/uint256/cmp_test.gno @@ -0,0 +1,163 @@ +package uint256 + +import ( + "strings" + "testing" +) + +func TestCmp(t *testing.T) { + tests := []struct { + x, y string + want int + }{ + {"0", "0", 0}, + {"0", "1", -1}, + {"1", "0", 1}, + {"1", "1", 0}, + {"10", "10", 0}, + {"10", "11", -1}, + {"11", "10", 1}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Cmp(y) + if got != tc.want { + t.Errorf("Cmp(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestIsZero(t *testing.T) { + tests := []struct { + x string + want bool + }{ + {"0", true}, + {"1", false}, + {"10", false}, + } + + for _, tc := range tests { + x, err := FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + + got := x.IsZero() + if got != tc.want { + t.Errorf("IsZero(%s) = %v, want %v", tc.x, got, tc.want) + } + } +} + +func TestLtUint64(t *testing.T) { + tests := []struct { + x string + y uint64 + want bool + }{ + {"0", 1, true}, + {"1", 0, false}, + {"10", 10, false}, + {"0xffffffffffffffff", 0, false}, + {"0x10000000000000000", 10000000000000000, false}, + } + + for _, tc := range tests { + var x *Uint + var err error + + if strings.HasPrefix(tc.x, "0x") { + x, err = FromHex(tc.x) + if err != nil { + t.Error(err) + continue + } + } else { + x, err = FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + } + + got := x.LtUint64(tc.y) + + if got != tc.want { + t.Errorf("LtUint64(%s, %d) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} + +func TestSGT(t *testing.T) { + x := MustFromHex("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe") + y := MustFromHex("0x0") + actual := x.Sgt(y) + if actual { + t.Fatalf("Expected %v false", actual) + } + + x = MustFromHex("0x0") + y = MustFromHex("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe") + actual = x.Sgt(y) + if !actual { + t.Fatalf("Expected %v true", actual) + } +} + +func TestEq(t *testing.T) { + tests := []struct { + x string + y string + want bool + }{ + {"0xffffffffffffffff", "18446744073709551615", true}, + {"0x10000000000000000", "18446744073709551616", true}, + {"0", "0", true}, + {"115792089237316195423570985008687907853269984665640564039457584007913129639935", "115792089237316195423570985008687907853269984665640564039457584007913129639935", true}, + } + + for i, tc := range tests { + var x *Uint + var err error + + if strings.HasPrefix(tc.x, "0x") { + x, err = FromHex(tc.x) + if err != nil { + t.Error(err) + continue + } + } else { + x, err = FromDecimal(tc.x) + if err != nil { + t.Error(err) + continue + } + } + + y, err := FromDecimal(tc.y) + if err != nil { + t.Error(err) + continue + } + + got := x.Eq(y) + + if got != tc.want { + t.Errorf("Eq(%s, %s) = %v, want %v", tc.x, tc.y, got, tc.want) + } + } +} diff --git a/contract/p/gnoswap/uint256/conversion.gno b/contract/p/gnoswap/uint256/conversion.gno new file mode 100644 index 000000000..4ef90602a --- /dev/null +++ b/contract/p/gnoswap/uint256/conversion.gno @@ -0,0 +1,570 @@ +// conversions contains methods for converting Uint instances to other types and vice versa. +// This includes conversions to and from basic types such as uint64 and int32, as well as string representations +// and byte slices. Additionally, it covers marshaling and unmarshaling for JSON and other text formats. +package uint256 + +import ( + "encoding/binary" + "errors" + "strconv" + "strings" +) + +// Uint64 returns the lower 64-bits of z +func (z *Uint) Uint64() uint64 { + return z.arr[0] +} + +// Uint64WithOverflow returns the lower 64-bits of z and bool whether overflow occurred +func (z *Uint) Uint64WithOverflow() (uint64, bool) { + return z.arr[0], (z.arr[1] | z.arr[2] | z.arr[3]) != 0 +} + +// SetUint64 sets z to the value x +func (z *Uint) SetUint64(x uint64) *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, x + return z +} + +// IsUint64 reports whether z can be represented as a uint64. +func (z *Uint) IsUint64() bool { + return (z.arr[1] | z.arr[2] | z.arr[3]) == 0 +} + +// Dec returns the decimal representation of z. +func (z *Uint) Dec() string { + if z.IsZero() { + return "0" + } + if z.IsUint64() { + return strconv.FormatUint(z.Uint64(), 10) + } + + // The max uint64 value being 18446744073709551615, the largest + // power-of-ten below that is 10000000000000000000. + // When we do a DivMod using that number, the remainder that we + // get back is the lower part of the output. + // + // The ascii-output of remainder will never exceed 19 bytes (since it will be + // below 10000000000000000000). + // + // Algorithm example using 100 as divisor + // + // 12345 % 100 = 45 (rem) + // 12345 / 100 = 123 (quo) + // -> output '45', continue iterate on 123 + var ( + // out is 98 bytes long: 78 (max size of a string without leading zeroes, + // plus slack so we can copy 19 bytes every iteration). + // We init it with zeroes, because when strconv appends the ascii representations, + // it will omit leading zeroes. + out = []byte("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + divisor = NewUint(10000000000000000000) // 20 digits + y = new(Uint).Set(z) // copy to avoid modifying z + pos = len(out) // position to write to + buf = make([]byte, 0, 19) // buffer to write uint64:s to + ) + for { + // Obtain Q and R for divisor + var quot Uint + rem := udivrem(quot.arr[:], y.arr[:], divisor) + y.Set(") // Set Q for next loop + // Convert the R to ascii representation + buf = strconv.AppendUint(buf[:0], rem.Uint64(), 10) + // Copy in the ascii digits + copy(out[pos-len(buf):], buf) + if y.IsZero() { + break + } + // Move 19 digits left + pos -= 19 + } + // skip leading zeroes by only using the 'used size' of buf + return string(out[pos-len(buf):]) +} + +func (z *Uint) Scan(src interface{}) error { + if src == nil { + z.Clear() + return nil + } + + switch src := src.(type) { + case string: + return z.scanScientificFromString(src) + case []byte: + return z.scanScientificFromString(string(src)) + } + return errors.New("default // unsupported type: can't convert to uint256.Uint") +} + +func (z *Uint) scanScientificFromString(src string) error { + if len(src) == 0 { + z.Clear() + return nil + } + + idx := strings.IndexByte(src, 'e') + if idx == -1 { + return z.SetFromDecimal(src) + } + if err := z.SetFromDecimal(src[:idx]); err != nil { + return err + } + if src[(idx+1):] == "0" { + return nil + } + exp := new(Uint) + if err := exp.SetFromDecimal(src[(idx + 1):]); err != nil { + return err + } + if exp.GtUint64(77) { // 10**78 is larger than 2**256 + return ErrBig256Range + } + exp.Exp(NewUint(10), exp) + if _, overflow := z.MulOverflow(z, exp); overflow { + return ErrBig256Range + } + return nil +} + +// ToString returns the decimal string representation of z. It returns an empty string if z is nil. +// OBS: doesn't exist from holiman's uint256 +func (z *Uint) ToString() string { + if z == nil { + return "" + } + + return z.Dec() +} + +// MarshalJSON implements json.Marshaler. +// MarshalJSON marshals using the 'decimal string' representation. This is _not_ compatible +// with big.Uint: big.Uint marshals into JSON 'native' numeric format. +// +// The JSON native format is, on some platforms, (e.g. javascript), limited to 53-bit large +// integer space. Thus, U256 uses string-format, which is not compatible with +// big.int (big.Uint refuses to unmarshal a string representation). +func (z *Uint) MarshalJSON() ([]byte, error) { + return []byte(`"` + z.Dec() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler. UnmarshalJSON accepts either +// - Quoted string: either hexadecimal OR decimal +// - Not quoted string: only decimal +func (z *Uint) UnmarshalJSON(input []byte) error { + if len(input) < 2 || input[0] != '"' || input[len(input)-1] != '"' { + // if not quoted, it must be decimal + return z.fromDecimal(string(input)) + } + return z.UnmarshalText(input[1 : len(input)-1]) +} + +// MarshalText implements encoding.TextMarshaler +// MarshalText marshals using the decimal representation (compatible with big.Uint) +func (z *Uint) MarshalText() ([]byte, error) { + return []byte(z.Dec()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. This method +// can unmarshal either hexadecimal or decimal. +// - For hexadecimal, the input _must_ be prefixed with 0x or 0X +func (z *Uint) UnmarshalText(input []byte) error { + if len(input) >= 2 && input[0] == '0' && (input[1] == 'x' || input[1] == 'X') { + return z.fromHex(string(input)) + } + return z.fromDecimal(string(input)) +} + +// SetBytes interprets buf as the bytes of a big-endian unsigned +// integer, sets z to that value, and returns z. +// If buf is larger than 32 bytes, the last 32 bytes is used. +func (z *Uint) SetBytes(buf []byte) *Uint { + switch l := len(buf); l { + case 0: + z.Clear() + case 1: + z.SetBytes1(buf) + case 2: + z.SetBytes2(buf) + case 3: + z.SetBytes3(buf) + case 4: + z.SetBytes4(buf) + case 5: + z.SetBytes5(buf) + case 6: + z.SetBytes6(buf) + case 7: + z.SetBytes7(buf) + case 8: + z.SetBytes8(buf) + case 9: + z.SetBytes9(buf) + case 10: + z.SetBytes10(buf) + case 11: + z.SetBytes11(buf) + case 12: + z.SetBytes12(buf) + case 13: + z.SetBytes13(buf) + case 14: + z.SetBytes14(buf) + case 15: + z.SetBytes15(buf) + case 16: + z.SetBytes16(buf) + case 17: + z.SetBytes17(buf) + case 18: + z.SetBytes18(buf) + case 19: + z.SetBytes19(buf) + case 20: + z.SetBytes20(buf) + case 21: + z.SetBytes21(buf) + case 22: + z.SetBytes22(buf) + case 23: + z.SetBytes23(buf) + case 24: + z.SetBytes24(buf) + case 25: + z.SetBytes25(buf) + case 26: + z.SetBytes26(buf) + case 27: + z.SetBytes27(buf) + case 28: + z.SetBytes28(buf) + case 29: + z.SetBytes29(buf) + case 30: + z.SetBytes30(buf) + case 31: + z.SetBytes31(buf) + default: + z.SetBytes32(buf[l-32:]) + } + return z +} + +// SetBytes1 is identical to SetBytes(in[:1]), but panics is input is too short +func (z *Uint) SetBytes1(in []byte) *Uint { + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = uint64(in[0]) + return z +} + +// SetBytes2 is identical to SetBytes(in[:2]), but panics is input is too short +func (z *Uint) SetBytes2(in []byte) *Uint { + _ = in[1] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = uint64(binary.BigEndian.Uint16(in[0:2])) + return z +} + +// SetBytes3 is identical to SetBytes(in[:3]), but panics is input is too short +func (z *Uint) SetBytes3(in []byte) *Uint { + _ = in[2] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = uint64(binary.BigEndian.Uint16(in[1:3])) | uint64(in[0])<<16 + return z +} + +// SetBytes4 is identical to SetBytes(in[:4]), but panics is input is too short +func (z *Uint) SetBytes4(in []byte) *Uint { + _ = in[3] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = uint64(binary.BigEndian.Uint32(in[0:4])) + return z +} + +// SetBytes5 is identical to SetBytes(in[:5]), but panics is input is too short +func (z *Uint) SetBytes5(in []byte) *Uint { + _ = in[4] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = bigEndianUint40(in[0:5]) + return z +} + +// SetBytes6 is identical to SetBytes(in[:6]), but panics is input is too short +func (z *Uint) SetBytes6(in []byte) *Uint { + _ = in[5] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = bigEndianUint48(in[0:6]) + return z +} + +// SetBytes7 is identical to SetBytes(in[:7]), but panics is input is too short +func (z *Uint) SetBytes7(in []byte) *Uint { + _ = in[6] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = bigEndianUint56(in[0:7]) + return z +} + +// SetBytes8 is identical to SetBytes(in[:8]), but panics is input is too short +func (z *Uint) SetBytes8(in []byte) *Uint { + _ = in[7] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + z.arr[0] = binary.BigEndian.Uint64(in[0:8]) + return z +} + +// SetBytes9 is identical to SetBytes(in[:9]), but panics is input is too short +func (z *Uint) SetBytes9(in []byte) *Uint { + _ = in[8] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = uint64(in[0]) + z.arr[0] = binary.BigEndian.Uint64(in[1:9]) + return z +} + +// SetBytes10 is identical to SetBytes(in[:10]), but panics is input is too short +func (z *Uint) SetBytes10(in []byte) *Uint { + _ = in[9] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = uint64(binary.BigEndian.Uint16(in[0:2])) + z.arr[0] = binary.BigEndian.Uint64(in[2:10]) + return z +} + +// SetBytes11 is identical to SetBytes(in[:11]), but panics is input is too short +func (z *Uint) SetBytes11(in []byte) *Uint { + _ = in[10] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = uint64(binary.BigEndian.Uint16(in[1:3])) | uint64(in[0])<<16 + z.arr[0] = binary.BigEndian.Uint64(in[3:11]) + return z +} + +// SetBytes12 is identical to SetBytes(in[:12]), but panics is input is too short +func (z *Uint) SetBytes12(in []byte) *Uint { + _ = in[11] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = uint64(binary.BigEndian.Uint32(in[0:4])) + z.arr[0] = binary.BigEndian.Uint64(in[4:12]) + return z +} + +// SetBytes13 is identical to SetBytes(in[:13]), but panics is input is too short +func (z *Uint) SetBytes13(in []byte) *Uint { + _ = in[12] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = bigEndianUint40(in[0:5]) + z.arr[0] = binary.BigEndian.Uint64(in[5:13]) + return z +} + +// SetBytes14 is identical to SetBytes(in[:14]), but panics is input is too short +func (z *Uint) SetBytes14(in []byte) *Uint { + _ = in[13] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = bigEndianUint48(in[0:6]) + z.arr[0] = binary.BigEndian.Uint64(in[6:14]) + return z +} + +// SetBytes15 is identical to SetBytes(in[:15]), but panics is input is too short +func (z *Uint) SetBytes15(in []byte) *Uint { + _ = in[14] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = bigEndianUint56(in[0:7]) + z.arr[0] = binary.BigEndian.Uint64(in[7:15]) + return z +} + +// SetBytes16 is identical to SetBytes(in[:16]), but panics is input is too short +func (z *Uint) SetBytes16(in []byte) *Uint { + _ = in[15] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3], z.arr[2] = 0, 0 + z.arr[1] = binary.BigEndian.Uint64(in[0:8]) + z.arr[0] = binary.BigEndian.Uint64(in[8:16]) + return z +} + +// SetBytes17 is identical to SetBytes(in[:17]), but panics is input is too short +func (z *Uint) SetBytes17(in []byte) *Uint { + _ = in[16] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = uint64(in[0]) + z.arr[1] = binary.BigEndian.Uint64(in[1:9]) + z.arr[0] = binary.BigEndian.Uint64(in[9:17]) + return z +} + +// SetBytes18 is identical to SetBytes(in[:18]), but panics is input is too short +func (z *Uint) SetBytes18(in []byte) *Uint { + _ = in[17] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = uint64(binary.BigEndian.Uint16(in[0:2])) + z.arr[1] = binary.BigEndian.Uint64(in[2:10]) + z.arr[0] = binary.BigEndian.Uint64(in[10:18]) + return z +} + +// SetBytes19 is identical to SetBytes(in[:19]), but panics is input is too short +func (z *Uint) SetBytes19(in []byte) *Uint { + _ = in[18] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = uint64(binary.BigEndian.Uint16(in[1:3])) | uint64(in[0])<<16 + z.arr[1] = binary.BigEndian.Uint64(in[3:11]) + z.arr[0] = binary.BigEndian.Uint64(in[11:19]) + return z +} + +// SetBytes20 is identical to SetBytes(in[:20]), but panics is input is too short +func (z *Uint) SetBytes20(in []byte) *Uint { + _ = in[19] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = uint64(binary.BigEndian.Uint32(in[0:4])) + z.arr[1] = binary.BigEndian.Uint64(in[4:12]) + z.arr[0] = binary.BigEndian.Uint64(in[12:20]) + return z +} + +// SetBytes21 is identical to SetBytes(in[:21]), but panics is input is too short +func (z *Uint) SetBytes21(in []byte) *Uint { + _ = in[20] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = bigEndianUint40(in[0:5]) + z.arr[1] = binary.BigEndian.Uint64(in[5:13]) + z.arr[0] = binary.BigEndian.Uint64(in[13:21]) + return z +} + +// SetBytes22 is identical to SetBytes(in[:22]), but panics is input is too short +func (z *Uint) SetBytes22(in []byte) *Uint { + _ = in[21] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = bigEndianUint48(in[0:6]) + z.arr[1] = binary.BigEndian.Uint64(in[6:14]) + z.arr[0] = binary.BigEndian.Uint64(in[14:22]) + return z +} + +// SetBytes23 is identical to SetBytes(in[:23]), but panics is input is too short +func (z *Uint) SetBytes23(in []byte) *Uint { + _ = in[22] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = bigEndianUint56(in[0:7]) + z.arr[1] = binary.BigEndian.Uint64(in[7:15]) + z.arr[0] = binary.BigEndian.Uint64(in[15:23]) + return z +} + +// SetBytes24 is identical to SetBytes(in[:24]), but panics is input is too short +func (z *Uint) SetBytes24(in []byte) *Uint { + _ = in[23] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = 0 + z.arr[2] = binary.BigEndian.Uint64(in[0:8]) + z.arr[1] = binary.BigEndian.Uint64(in[8:16]) + z.arr[0] = binary.BigEndian.Uint64(in[16:24]) + return z +} + +// SetBytes25 is identical to SetBytes(in[:25]), but panics is input is too short +func (z *Uint) SetBytes25(in []byte) *Uint { + _ = in[24] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = uint64(in[0]) + z.arr[2] = binary.BigEndian.Uint64(in[1:9]) + z.arr[1] = binary.BigEndian.Uint64(in[9:17]) + z.arr[0] = binary.BigEndian.Uint64(in[17:25]) + return z +} + +// SetBytes26 is identical to SetBytes(in[:26]), but panics is input is too short +func (z *Uint) SetBytes26(in []byte) *Uint { + _ = in[25] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = uint64(binary.BigEndian.Uint16(in[0:2])) + z.arr[2] = binary.BigEndian.Uint64(in[2:10]) + z.arr[1] = binary.BigEndian.Uint64(in[10:18]) + z.arr[0] = binary.BigEndian.Uint64(in[18:26]) + return z +} + +// SetBytes27 is identical to SetBytes(in[:27]), but panics is input is too short +func (z *Uint) SetBytes27(in []byte) *Uint { + _ = in[26] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = uint64(binary.BigEndian.Uint16(in[1:3])) | uint64(in[0])<<16 + z.arr[2] = binary.BigEndian.Uint64(in[3:11]) + z.arr[1] = binary.BigEndian.Uint64(in[11:19]) + z.arr[0] = binary.BigEndian.Uint64(in[19:27]) + return z +} + +// SetBytes28 is identical to SetBytes(in[:28]), but panics is input is too short +func (z *Uint) SetBytes28(in []byte) *Uint { + _ = in[27] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = uint64(binary.BigEndian.Uint32(in[0:4])) + z.arr[2] = binary.BigEndian.Uint64(in[4:12]) + z.arr[1] = binary.BigEndian.Uint64(in[12:20]) + z.arr[0] = binary.BigEndian.Uint64(in[20:28]) + return z +} + +// SetBytes29 is identical to SetBytes(in[:29]), but panics is input is too short +func (z *Uint) SetBytes29(in []byte) *Uint { + _ = in[23] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = bigEndianUint40(in[0:5]) + z.arr[2] = binary.BigEndian.Uint64(in[5:13]) + z.arr[1] = binary.BigEndian.Uint64(in[13:21]) + z.arr[0] = binary.BigEndian.Uint64(in[21:29]) + return z +} + +// SetBytes30 is identical to SetBytes(in[:30]), but panics is input is too short +func (z *Uint) SetBytes30(in []byte) *Uint { + _ = in[29] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = bigEndianUint48(in[0:6]) + z.arr[2] = binary.BigEndian.Uint64(in[6:14]) + z.arr[1] = binary.BigEndian.Uint64(in[14:22]) + z.arr[0] = binary.BigEndian.Uint64(in[22:30]) + return z +} + +// SetBytes31 is identical to SetBytes(in[:31]), but panics is input is too short +func (z *Uint) SetBytes31(in []byte) *Uint { + _ = in[30] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = bigEndianUint56(in[0:7]) + z.arr[2] = binary.BigEndian.Uint64(in[7:15]) + z.arr[1] = binary.BigEndian.Uint64(in[15:23]) + z.arr[0] = binary.BigEndian.Uint64(in[23:31]) + return z +} + +// SetBytes32 sets z to the value of the big-endian 256-bit unsigned integer in. +func (z *Uint) SetBytes32(in []byte) *Uint { + _ = in[31] // bounds check hint to compiler; see golang.org/issue/14808 + z.arr[3] = binary.BigEndian.Uint64(in[0:8]) + z.arr[2] = binary.BigEndian.Uint64(in[8:16]) + z.arr[1] = binary.BigEndian.Uint64(in[16:24]) + z.arr[0] = binary.BigEndian.Uint64(in[24:32]) + return z +} + +// Utility methods that are "missing" among the bigEndian.UintXX methods. + +// bigEndianUint40 returns the uint64 value represented by the 5 bytes in big-endian order. +func bigEndianUint40(b []byte) uint64 { + _ = b[4] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[4]) | uint64(b[3])<<8 | uint64(b[2])<<16 | uint64(b[1])<<24 | + uint64(b[0])<<32 +} + +// bigEndianUint56 returns the uint64 value represented by the 7 bytes in big-endian order. +func bigEndianUint56(b []byte) uint64 { + _ = b[6] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[6]) | uint64(b[5])<<8 | uint64(b[4])<<16 | uint64(b[3])<<24 | + uint64(b[2])<<32 | uint64(b[1])<<40 | uint64(b[0])<<48 +} + +// bigEndianUint48 returns the uint64 value represented by the 6 bytes in big-endian order. +func bigEndianUint48(b []byte) uint64 { + _ = b[5] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[5]) | uint64(b[4])<<8 | uint64(b[3])<<16 | uint64(b[2])<<24 | + uint64(b[1])<<32 | uint64(b[0])<<40 +} diff --git a/contract/p/gnoswap/uint256/conversion_test.gno b/contract/p/gnoswap/uint256/conversion_test.gno new file mode 100644 index 000000000..12ae99cc0 --- /dev/null +++ b/contract/p/gnoswap/uint256/conversion_test.gno @@ -0,0 +1,60 @@ +package uint256 + +import ( + "testing" +) + +func TestIsUint64(t *testing.T) { + tests := []struct { + x string + want bool + }{ + {"0x0", true}, + {"0x1", true}, + {"0x10", true}, + {"0xffffffffffffffff", true}, + {"0x10000000000000000", false}, + } + + for _, tc := range tests { + x := MustFromHex(tc.x) + got := x.IsUint64() + + if got != tc.want { + t.Errorf("IsUint64(%s) = %v, want %v", tc.x, got, tc.want) + } + } +} + +func TestDec(t *testing.T) { + testCases := []struct { + name string + z Uint + want string + }{ + { + name: "zero", + z: Uint{arr: [4]uint64{0, 0, 0, 0}}, + want: "0", + }, + { + name: "less than 20 digits", + z: Uint{arr: [4]uint64{1234567890, 0, 0, 0}}, + want: "1234567890", + }, + { + name: "max possible value", + z: Uint{arr: [4]uint64{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}}, + want: "115792089237316195423570985008687907853269984665640564039457584007913129639935", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.z.Dec() + if result != tc.want { + t.Errorf("Dec(%v) = %s, want %s", tc.z, result, tc.want) + } + }) + } +} diff --git a/contract/p/gnoswap/uint256/error.gno b/contract/p/gnoswap/uint256/error.gno new file mode 100644 index 000000000..d200bb9cc --- /dev/null +++ b/contract/p/gnoswap/uint256/error.gno @@ -0,0 +1,73 @@ +package uint256 + +import ( + "errors" +) + +var ( + ErrEmptyString = errors.New("empty hex string") + ErrSyntax = errors.New("invalid hex string") + ErrRange = errors.New("number out of range") + ErrMissingPrefix = errors.New("hex string without 0x prefix") + ErrEmptyNumber = errors.New("hex string \"0x\"") + ErrLeadingZero = errors.New("hex number with leading zero digits") + ErrBig256Range = errors.New("hex number > 256 bits") + ErrBadBufferLength = errors.New("bad ssz buffer length") + ErrBadEncodedLength = errors.New("bad ssz encoded length") + ErrInvalidBase = errors.New("invalid base") + ErrInvalidBitSize = errors.New("invalid bit size") +) + +type u256Error struct { + fn string // function name + input string + err error +} + +func (e *u256Error) Error() string { + return e.fn + ": " + e.input + ": " + e.err.Error() +} + +func (e *u256Error) Unwrap() error { + return e.err +} + +func errEmptyString(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrEmptyString} +} + +func errSyntax(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrSyntax} +} + +func errMissingPrefix(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrMissingPrefix} +} + +func errEmptyNumber(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrEmptyNumber} +} + +func errLeadingZero(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrLeadingZero} +} + +func errRange(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrRange} +} + +func errBig256Range(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrBig256Range} +} + +func errBadBufferLength(fn, input string) error { + return &u256Error{fn: fn, input: input, err: ErrBadBufferLength} +} + +func errInvalidBase(fn string, base int) error { + return &u256Error{fn: fn, input: string(base), err: ErrInvalidBase} +} + +func errInvalidBitSize(fn string, bitSize int) error { + return &u256Error{fn: fn, input: string(bitSize), err: ErrInvalidBitSize} +} diff --git a/contract/p/gnoswap/uint256/fullmath.gno b/contract/p/gnoswap/uint256/fullmath.gno new file mode 100644 index 000000000..b8b079943 --- /dev/null +++ b/contract/p/gnoswap/uint256/fullmath.gno @@ -0,0 +1,151 @@ +// REF: https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/FullMath.sol +package uint256 + +import ( + "gno.land/p/demo/ufmt" +) + +const ( + MAX_UINT256 = "115792089237316195423570985008687907853269984665640564039457584007913129639935" +) + +func MulDiv( + a, b, denominator *Uint, +) *Uint { + prod0 := Zero() + prod1 := Zero() + + { + mm := new(Uint).MulMod(a, b, new(Uint).Not(Zero())) + prod0 = new(Uint).Mul(a, b) + + ltBool := mm.Lt(prod0) + ltUint := Zero() + if ltBool { + ltUint = One() + } + prod1 = new(Uint).Sub(new(Uint).Sub(mm, prod0), ltUint) + } + + // Handle non-overflow cases, 256 by 256 division + if prod1.IsZero() { + if !(denominator.Gt(Zero())) { // require(denominator > 0); + panic(ufmt.Sprintf("uint256_MulDiv()__denominator(%s) > 0", denominator.ToString())) + } + + result := new(Uint).Div(prod0, denominator) + return result + } + + // Make sure the result is less than 2**256. + // Also prevents denominator == 0 + if !(denominator.Gt(prod1)) { // require(denominator > prod1) + panic(ufmt.Sprintf("uint256_MulDiv()__denominator(%s) > prod1(%s)", denominator.ToString(), prod1.ToString())) + } + + /////////////////////////////////////////////// + // 512 by 256 division. + /////////////////////////////////////////////// + + // Make division exact by subtracting the remainder from [prod1 prod0] + // Compute remainder using mulmod + remainder := Zero() + remainder = new(Uint).MulMod(a, b, denominator) + + // Subtract 256 bit number from 512 bit number + gtBool := remainder.Gt(prod0) + gtUint := Zero() + if gtBool { + gtUint = One() + } + prod1 = new(Uint).Sub(prod1, gtUint) + prod0 = new(Uint).Sub(prod0, remainder) + + // Factor powers of two out of denominator + // Compute largest power of two divisor of denominator. + // Always >= 1. + twos := Zero() + twos = new(Uint).And(new(Uint).Neg(denominator), denominator) + + // Divide denominator by power of two + denominator = new(Uint).Div(denominator, twos) + + // Divide [prod1 prod0] by the factors of two + prod0 = new(Uint).Div(prod0, twos) + + // Shift in bits from prod1 into prod0. For this we need + // to flip `twos` such that it is 2**256 / twos. + // If twos is zero, then it becomes one + twos = new(Uint).Add( + new(Uint).Div( + new(Uint).Sub(Zero(), twos), + twos, + ), + One(), + ) + prod0 = new(Uint).Or(prod0, new(Uint).Mul(prod1, twos)) + + // Invert denominator mod 2**256 + // Now that denominator is an odd number, it has an inverse + // modulo 2**256 such that denominator * inv = 1 mod 2**256. + // Compute the inverse by starting with a seed that is correct + // correct for four bits. That is, denominator * inv = 1 mod 2**4 + inv := Zero() + inv = new(Uint).Mul(NewUint(3), denominator) + inv = new(Uint).Xor(inv, NewUint(2)) + + // Now use Newton-Raphson iteration to improve the precision. + // Thanks to Hensel's lifting lemma, this also works in modular + // arithmetic, doubling the correct bits in each step. + + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**8 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**16 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**32 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**64 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**128 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**256 + + // Because the division is now exact we can divide by multiplying + // with the modular inverse of denominator. This will give us the + // correct result modulo 2**256. Since the precoditions guarantee + // that the outcome is less than 2**256, this is the final result. + // We don't need to compute the high bits of the result and prod1 + // is no longer required. + result := new(Uint).Mul(prod0, inv) + return result +} + +func MulDivRoundingUp( + a, b, denominator *Uint, +) *Uint { + result := MulDiv(a, b, denominator) + + if new(Uint).MulMod(a, b, denominator).Gt(Zero()) { + if !(result.Lt(MustFromDecimal(MAX_UINT256))) { // require(result < MAX_UINT256) + panic(ufmt.Sprintf("uint256_MulDivRoundingUp()__result(%s) < MAX_UINT256", result.ToString())) + } + + result = new(Uint).Add(result, One()) + } + + return result +} + +// UnsafeMath +// https://github.com/Uniswap/v3-core/blob/d8b1c635c275d2a9450bd6a78f3fa2484fef73eb/contracts/libraries/UnsafeMath.sol +func DivRoundingUp( + x, y *Uint, +) *Uint { + div := new(Uint).Div(x, y) + mod := new(Uint).Mod(x, y) + + z := new(Uint).Add(div, gt(mod, Zero())) + return z +} + +func gt(x, y *Uint) *Uint { + if x.Gt(y) { + return One() + } + return Zero() +} diff --git a/contract/p/gnoswap/uint256/fullmath_test.gno b/contract/p/gnoswap/uint256/fullmath_test.gno new file mode 100644 index 000000000..71e3917e9 --- /dev/null +++ b/contract/p/gnoswap/uint256/fullmath_test.gno @@ -0,0 +1,256 @@ +package uint256 + +import ( + "testing" +) + +var Q128 *Uint + +func init() { + Q128 = MustFromDecimal("340282366920938463463374607431768211456") // 2**128 +} + +func TestFullMathMulDiv(t *testing.T) { + t.Run("reverts if denominator is 0", func(t *testing.T) { + x := NewUint(5) + y := Zero() + + shouldPanic( + t, + func() { + MulDiv(Q128, x, y) + }, + ) + }) + + t.Run("reverts if denominator is 0 and numerator overflows", func(t *testing.T) { + y := Zero() + + shouldPanic( + t, + func() { + MulDiv(Q128, Q128, y) + }, + ) + }) + + t.Run("reverts if output overflows uint256", func(t *testing.T) { + y := One() + + shouldPanic( + t, + func() { + MulDiv(Q128, Q128, y) + }, + ) + }) + + t.Run("reverts on overflow with all max inputs", func(t *testing.T) { + MaxUint256 := MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") + MaxUint256Sub1 := MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639934") + + shouldPanic( + t, + func() { + MulDiv(MaxUint256, MaxUint256, MaxUint256Sub1) + }, + ) + }) + + t.Run("all max inputs", func(t *testing.T) { + MaxUint256 := MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") + + got := MulDiv(MaxUint256, MaxUint256, MaxUint256) + + rst := got.Eq(MaxUint256) + if !rst { + t.Errorf("MaxUin256*MaxUint256/MaxUint256 is not same to MaxUint256") + } + }) + + t.Run("accurate without phantom overflow", func(t *testing.T) { + + result := new(Uint).Div(Q128, NewUint(3)) + + x := NewUint(50) + x = x.Mul(x, Q128) + x = x.Div(x, NewUint(100)) /*0.5=*/ + + y := NewUint(150) + y = y.Mul(y, Q128) + y = y.Div(y, NewUint(100)) /*1.5=*/ + + got := MulDiv(Q128, x, y) + if !got.Eq(result) { + t.Errorf("Q128/3 is not same to Q128 * (50*Q128/100) / (150*Q128/100)") + } + }) + + t.Run("accurate with phantom overflow", func(t *testing.T) { + + x := NewUint(4375) + expected := MulDiv(x, Q128, NewUint(1000)) + + y := NewUint(35) + y = y.Mul(y, Q128) + + denom := NewUint(8) + denom = denom.Mul(denom, Q128) + + got := MulDiv(Q128, y, denom) + if !got.Eq(expected) { + t.Errorf("4375*Q128/1000 is not same to Q128 * 35*Q128 / 8*Q128") + } + }) + + t.Run("accurate with phantom overflow and repeating decimal", func(t *testing.T) { + + expected := MulDiv(One(), Q128, NewUint(3)) + + y := NewUint(1000) + y = y.Mul(y, Q128) + + denom := NewUint(3000) + denom = denom.Mul(denom, Q128) + + got := MulDiv(Q128, y, denom) + if !got.Eq(expected) { + t.Errorf("1*Q128/3 is not same to Q128 * 1000*Q128 / 3000*Q128") + } + }) +} + +func TestMulDivRoundingUp(t *testing.T) { + t.Run("reverts if denominator is 0", func(t *testing.T) { + x := NewUint(5) + y := Zero() + + shouldPanic( + t, + func() { + MulDivRoundingUp(Q128, x, y) + }, + ) + }) + + t.Run("reverts if denominator is 0 and numerator overflows", func(t *testing.T) { + y := Zero() + + shouldPanic( + t, + func() { + MulDivRoundingUp(Q128, Q128, y) + }, + ) + }) + + t.Run("reverts if output overflows uint256", func(t *testing.T) { + y := One() + + shouldPanic( + t, + func() { + MulDivRoundingUp(Q128, Q128, y) + }, + ) + }) + + t.Run("reverts on overflow with all max inputs", func(t *testing.T) { + MaxUint256 := MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") + MaxUint256Sub1 := MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639934") + + shouldPanic( + t, + func() { + MulDivRoundingUp(MaxUint256, MaxUint256, MaxUint256Sub1) + }, + ) + }) + + t.Run("reverts if mulDiv overflows 256 bits after rounding up", func(t *testing.T) { + x := MustFromDecimal("535006138814359") + y := MustFromDecimal("432862656469423142931042426214547535783388063929571229938474969") + + shouldPanic( + t, + func() { + MulDivRoundingUp(x, y, NewUint(2)) + }, + ) + }) + + t.Run("reverts if mulDiv overflows 256 bits after rounding up case 2", func(t *testing.T) { + x := MustFromDecimal("115792089237316195423570985008687907853269984659341747863450311749907997002549") + y := MustFromDecimal("115792089237316195423570985008687907853269984659341747863450311749907997002550") + z := MustFromDecimal("115792089237316195423570985008687907853269984653042931687443039491902864365164") + + shouldPanic( + t, + func() { + MulDivRoundingUp(x, y, z) + }, + ) + }) + + t.Run("all max inputs", func(t *testing.T) { + MaxUint256 := MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") + + got := MulDivRoundingUp(MaxUint256, MaxUint256, MaxUint256) + + rst := got.Eq(MaxUint256) + if !rst { + t.Errorf("MaxUin256*MaxUint256/MaxUint256 RoudingUp is not same to MaxUint256") + } + }) + + t.Run("accurate without phantom overflow", func(t *testing.T) { + + result := new(Uint).Div(Q128, NewUint(3)) + result = result.Add(result, One()) + + x := NewUint(50) + x = x.Mul(x, Q128) + x = x.Div(x, NewUint(100)) /*0.5=*/ + + y := NewUint(150) + y = y.Mul(y, Q128) + y = y.Div(y, NewUint(100)) /*1.5=*/ + + got := MulDivRoundingUp(Q128, x, y) + if !got.Eq(result) { + t.Errorf("Q128/3 is not same to Q128 * (50*Q128/100) / (150*Q128/100)") + } + }) + + t.Run("accurate with phantom overflow", func(t *testing.T) { + x := NewUint(4375) + expected := MulDiv(x, Q128, NewUint(1000)) + + y := NewUint(35) + y = y.Mul(y, Q128) + + denom := NewUint(8) + denom = denom.Mul(denom, Q128) + + got := MulDivRoundingUp(Q128, y, denom) + if !got.Eq(expected) { + t.Errorf("4375*Q128/1000 is not same to Q128 * 35*Q128 / 8*Q128") + } + }) + + t.Run("accurate with phantom overflow and repeating decimal", func(t *testing.T) { + expected := MulDiv(One(), Q128, NewUint(3)) + expected = expected.Add(expected, One()) + + y := NewUint(1000) + y = y.Mul(y, Q128) + + denom := NewUint(3000) + denom = denom.Mul(denom, Q128) + + got := MulDivRoundingUp(Q128, y, denom) + if !got.Eq(expected) { + t.Errorf("1*Q128/3+1 is not same to Q128 * 1000*Q128 / 3000*Q128") + } + }) +} diff --git a/contract/p/gnoswap/uint256/gno.mod b/contract/p/gnoswap/uint256/gno.mod new file mode 100644 index 000000000..880113430 --- /dev/null +++ b/contract/p/gnoswap/uint256/gno.mod @@ -0,0 +1 @@ +module gno.land/p/gnoswap/uint256 diff --git a/contract/p/gnoswap/uint256/gs_pointer.gno b/contract/p/gnoswap/uint256/gs_pointer.gno new file mode 100644 index 000000000..e24e6bf33 --- /dev/null +++ b/contract/p/gnoswap/uint256/gs_pointer.gno @@ -0,0 +1,8 @@ +package uint256 + +func (z *Uint) NilToZero() *Uint { + if z == nil { + z = NewUint(0) + } + return z +} diff --git a/contract/p/gnoswap/uint256/mod.gno b/contract/p/gnoswap/uint256/mod.gno new file mode 100644 index 000000000..f6ff0967e --- /dev/null +++ b/contract/p/gnoswap/uint256/mod.gno @@ -0,0 +1,605 @@ +package uint256 + +import ( + "math/bits" +) + +// Some utility functions + +// Reciprocal computes a 320-bit value representing 1/m +// +// Notes: +// - specialized for m.arr[3] != 0, hence limited to 2^192 <= m < 2^256 +// - returns zero if m.arr[3] == 0 +// - starts with a 32-bit division, refines with newton-raphson iterations +func Reciprocal(m *Uint) (mu [5]uint64) { + if m.arr[3] == 0 { + return mu + } + + s := bits.LeadingZeros64(m.arr[3]) // Replace with leadingZeros(m) for general case + p := 255 - s // floor(log_2(m)), m>0 + + // 0 or a power of 2? + + // Check if at least one bit is set in m.arr[2], m.arr[1] or m.arr[0], + // or at least two bits in m.arr[3] + + if m.arr[0]|m.arr[1]|m.arr[2]|(m.arr[3]&(m.arr[3]-1)) == 0 { + + mu[4] = ^uint64(0) >> uint(p&63) + mu[3] = ^uint64(0) + mu[2] = ^uint64(0) + mu[1] = ^uint64(0) + mu[0] = ^uint64(0) + + return mu + } + + // Maximise division precision by left-aligning divisor + + var ( + y Uint // left-aligned copy of m + r0 uint32 // estimate of 2^31/y + ) + + y.Lsh(m, uint(s)) // 1/2 < y < 1 + + // Extract most significant 32 bits + + yh := uint32(y.arr[3] >> 32) + + if yh == 0x80000000 { // Avoid overflow in division + r0 = 0xffffffff + } else { + r0, _ = bits.Div32(0x80000000, 0, yh) + } + + // First iteration: 32 -> 64 + + t1 := uint64(r0) // 2^31/y + t1 *= t1 // 2^62/y^2 + t1, _ = bits.Mul64(t1, y.arr[3]) // 2^62/y^2 * 2^64/y / 2^64 = 2^62/y + + r1 := uint64(r0) << 32 // 2^63/y + r1 -= t1 // 2^63/y - 2^62/y = 2^62/y + r1 *= 2 // 2^63/y + + if (r1 | (y.arr[3] << 1)) == 0 { + r1 = ^uint64(0) + } + + // Second iteration: 64 -> 128 + + // square: 2^126/y^2 + a2h, a2l := bits.Mul64(r1, r1) + + // multiply by y: e2h:e2l:b2h = 2^126/y^2 * 2^128/y / 2^128 = 2^126/y + b2h, _ := bits.Mul64(a2l, y.arr[2]) + c2h, c2l := bits.Mul64(a2l, y.arr[3]) + d2h, d2l := bits.Mul64(a2h, y.arr[2]) + e2h, e2l := bits.Mul64(a2h, y.arr[3]) + + b2h, c := bits.Add64(b2h, c2l, 0) + e2l, c = bits.Add64(e2l, c2h, c) + e2h, _ = bits.Add64(e2h, 0, c) + + _, c = bits.Add64(b2h, d2l, 0) + e2l, c = bits.Add64(e2l, d2h, c) + e2h, _ = bits.Add64(e2h, 0, c) + + // subtract: t2h:t2l = 2^127/y - 2^126/y = 2^126/y + t2l, b := bits.Sub64(0, e2l, 0) + t2h, _ := bits.Sub64(r1, e2h, b) + + // double: r2h:r2l = 2^127/y + r2l, c := bits.Add64(t2l, t2l, 0) + r2h, _ := bits.Add64(t2h, t2h, c) + + if (r2h | r2l | (y.arr[3] << 1)) == 0 { + r2h = ^uint64(0) + r2l = ^uint64(0) + } + + // Third iteration: 128 -> 192 + + // square r2 (keep 256 bits): 2^190/y^2 + a3h, a3l := bits.Mul64(r2l, r2l) + b3h, b3l := bits.Mul64(r2l, r2h) + c3h, c3l := bits.Mul64(r2h, r2h) + + a3h, c = bits.Add64(a3h, b3l, 0) + c3l, c = bits.Add64(c3l, b3h, c) + c3h, _ = bits.Add64(c3h, 0, c) + + a3h, c = bits.Add64(a3h, b3l, 0) + c3l, c = bits.Add64(c3l, b3h, c) + c3h, _ = bits.Add64(c3h, 0, c) + + // multiply by y: q = 2^190/y^2 * 2^192/y / 2^192 = 2^190/y + + x0 := a3l + x1 := a3h + x2 := c3l + x3 := c3h + + var q0, q1, q2, q3, q4, t0 uint64 + + q0, _ = bits.Mul64(x2, y.arr[0]) + q1, t0 = bits.Mul64(x3, y.arr[0]) + q0, c = bits.Add64(q0, t0, 0) + q1, _ = bits.Add64(q1, 0, c) + + t1, _ = bits.Mul64(x1, y.arr[1]) + q0, c = bits.Add64(q0, t1, 0) + q2, t0 = bits.Mul64(x3, y.arr[1]) + q1, c = bits.Add64(q1, t0, c) + q2, _ = bits.Add64(q2, 0, c) + + t1, t0 = bits.Mul64(x2, y.arr[1]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + q2, _ = bits.Add64(q2, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[2]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + q3, t0 = bits.Mul64(x3, y.arr[2]) + q2, c = bits.Add64(q2, t0, c) + q3, _ = bits.Add64(q3, 0, c) + + t1, _ = bits.Mul64(x0, y.arr[2]) + q0, c = bits.Add64(q0, t1, 0) + t1, t0 = bits.Mul64(x2, y.arr[2]) + q1, c = bits.Add64(q1, t0, c) + q2, c = bits.Add64(q2, t1, c) + q3, _ = bits.Add64(q3, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[3]) + q1, c = bits.Add64(q1, t0, 0) + q2, c = bits.Add64(q2, t1, c) + q4, t0 = bits.Mul64(x3, y.arr[3]) + q3, c = bits.Add64(q3, t0, c) + q4, _ = bits.Add64(q4, 0, c) + + t1, t0 = bits.Mul64(x0, y.arr[3]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + t1, t0 = bits.Mul64(x2, y.arr[3]) + q2, c = bits.Add64(q2, t0, c) + q3, c = bits.Add64(q3, t1, c) + q4, _ = bits.Add64(q4, 0, c) + + // subtract: t3 = 2^191/y - 2^190/y = 2^190/y + _, b = bits.Sub64(0, q0, 0) + _, b = bits.Sub64(0, q1, b) + t3l, b := bits.Sub64(0, q2, b) + t3m, b := bits.Sub64(r2l, q3, b) + t3h, _ := bits.Sub64(r2h, q4, b) + + // double: r3 = 2^191/y + r3l, c := bits.Add64(t3l, t3l, 0) + r3m, c := bits.Add64(t3m, t3m, c) + r3h, _ := bits.Add64(t3h, t3h, c) + + // Fourth iteration: 192 -> 320 + + // square r3 + + a4h, a4l := bits.Mul64(r3l, r3l) + b4h, b4l := bits.Mul64(r3l, r3m) + c4h, c4l := bits.Mul64(r3l, r3h) + d4h, d4l := bits.Mul64(r3m, r3m) + e4h, e4l := bits.Mul64(r3m, r3h) + f4h, f4l := bits.Mul64(r3h, r3h) + + b4h, c = bits.Add64(b4h, c4l, 0) + e4l, c = bits.Add64(e4l, c4h, c) + e4h, _ = bits.Add64(e4h, 0, c) + + a4h, c = bits.Add64(a4h, b4l, 0) + d4l, c = bits.Add64(d4l, b4h, c) + d4h, c = bits.Add64(d4h, e4l, c) + f4l, c = bits.Add64(f4l, e4h, c) + f4h, _ = bits.Add64(f4h, 0, c) + + a4h, c = bits.Add64(a4h, b4l, 0) + d4l, c = bits.Add64(d4l, b4h, c) + d4h, c = bits.Add64(d4h, e4l, c) + f4l, c = bits.Add64(f4l, e4h, c) + f4h, _ = bits.Add64(f4h, 0, c) + + // multiply by y + + x1, x0 = bits.Mul64(d4h, y.arr[0]) + x3, x2 = bits.Mul64(f4h, y.arr[0]) + t1, t0 = bits.Mul64(f4l, y.arr[0]) + x1, c = bits.Add64(x1, t0, 0) + x2, c = bits.Add64(x2, t1, c) + x3, _ = bits.Add64(x3, 0, c) + + t1, t0 = bits.Mul64(d4h, y.arr[1]) + x1, c = bits.Add64(x1, t0, 0) + x2, c = bits.Add64(x2, t1, c) + x4, t0 := bits.Mul64(f4h, y.arr[1]) + x3, c = bits.Add64(x3, t0, c) + x4, _ = bits.Add64(x4, 0, c) + t1, t0 = bits.Mul64(d4l, y.arr[1]) + x0, c = bits.Add64(x0, t0, 0) + x1, c = bits.Add64(x1, t1, c) + t1, t0 = bits.Mul64(f4l, y.arr[1]) + x2, c = bits.Add64(x2, t0, c) + x3, c = bits.Add64(x3, t1, c) + x4, _ = bits.Add64(x4, 0, c) + + t1, t0 = bits.Mul64(a4h, y.arr[2]) + x0, c = bits.Add64(x0, t0, 0) + x1, c = bits.Add64(x1, t1, c) + t1, t0 = bits.Mul64(d4h, y.arr[2]) + x2, c = bits.Add64(x2, t0, c) + x3, c = bits.Add64(x3, t1, c) + x5, t0 := bits.Mul64(f4h, y.arr[2]) + x4, c = bits.Add64(x4, t0, c) + x5, _ = bits.Add64(x5, 0, c) + t1, t0 = bits.Mul64(d4l, y.arr[2]) + x1, c = bits.Add64(x1, t0, 0) + x2, c = bits.Add64(x2, t1, c) + t1, t0 = bits.Mul64(f4l, y.arr[2]) + x3, c = bits.Add64(x3, t0, c) + x4, c = bits.Add64(x4, t1, c) + x5, _ = bits.Add64(x5, 0, c) + + t1, t0 = bits.Mul64(a4h, y.arr[3]) + x1, c = bits.Add64(x1, t0, 0) + x2, c = bits.Add64(x2, t1, c) + t1, t0 = bits.Mul64(d4h, y.arr[3]) + x3, c = bits.Add64(x3, t0, c) + x4, c = bits.Add64(x4, t1, c) + x6, t0 := bits.Mul64(f4h, y.arr[3]) + x5, c = bits.Add64(x5, t0, c) + x6, _ = bits.Add64(x6, 0, c) + t1, t0 = bits.Mul64(a4l, y.arr[3]) + x0, c = bits.Add64(x0, t0, 0) + x1, c = bits.Add64(x1, t1, c) + t1, t0 = bits.Mul64(d4l, y.arr[3]) + x2, c = bits.Add64(x2, t0, c) + x3, c = bits.Add64(x3, t1, c) + t1, t0 = bits.Mul64(f4l, y.arr[3]) + x4, c = bits.Add64(x4, t0, c) + x5, c = bits.Add64(x5, t1, c) + x6, _ = bits.Add64(x6, 0, c) + + // subtract + _, b = bits.Sub64(0, x0, 0) + _, b = bits.Sub64(0, x1, b) + r4l, b := bits.Sub64(0, x2, b) + r4k, b := bits.Sub64(0, x3, b) + r4j, b := bits.Sub64(r3l, x4, b) + r4i, b := bits.Sub64(r3m, x5, b) + r4h, _ := bits.Sub64(r3h, x6, b) + + // Multiply candidate for 1/4y by y, with full precision + + x0 = r4l + x1 = r4k + x2 = r4j + x3 = r4i + x4 = r4h + + q1, q0 = bits.Mul64(x0, y.arr[0]) + q3, q2 = bits.Mul64(x2, y.arr[0]) + q5, q4 := bits.Mul64(x4, y.arr[0]) + + t1, t0 = bits.Mul64(x1, y.arr[0]) + q1, c = bits.Add64(q1, t0, 0) + q2, c = bits.Add64(q2, t1, c) + t1, t0 = bits.Mul64(x3, y.arr[0]) + q3, c = bits.Add64(q3, t0, c) + q4, c = bits.Add64(q4, t1, c) + q5, _ = bits.Add64(q5, 0, c) + + t1, t0 = bits.Mul64(x0, y.arr[1]) + q1, c = bits.Add64(q1, t0, 0) + q2, c = bits.Add64(q2, t1, c) + t1, t0 = bits.Mul64(x2, y.arr[1]) + q3, c = bits.Add64(q3, t0, c) + q4, c = bits.Add64(q4, t1, c) + q6, t0 := bits.Mul64(x4, y.arr[1]) + q5, c = bits.Add64(q5, t0, c) + q6, _ = bits.Add64(q6, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[1]) + q2, c = bits.Add64(q2, t0, 0) + q3, c = bits.Add64(q3, t1, c) + t1, t0 = bits.Mul64(x3, y.arr[1]) + q4, c = bits.Add64(q4, t0, c) + q5, c = bits.Add64(q5, t1, c) + q6, _ = bits.Add64(q6, 0, c) + + t1, t0 = bits.Mul64(x0, y.arr[2]) + q2, c = bits.Add64(q2, t0, 0) + q3, c = bits.Add64(q3, t1, c) + t1, t0 = bits.Mul64(x2, y.arr[2]) + q4, c = bits.Add64(q4, t0, c) + q5, c = bits.Add64(q5, t1, c) + q7, t0 := bits.Mul64(x4, y.arr[2]) + q6, c = bits.Add64(q6, t0, c) + q7, _ = bits.Add64(q7, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[2]) + q3, c = bits.Add64(q3, t0, 0) + q4, c = bits.Add64(q4, t1, c) + t1, t0 = bits.Mul64(x3, y.arr[2]) + q5, c = bits.Add64(q5, t0, c) + q6, c = bits.Add64(q6, t1, c) + q7, _ = bits.Add64(q7, 0, c) + + t1, t0 = bits.Mul64(x0, y.arr[3]) + q3, c = bits.Add64(q3, t0, 0) + q4, c = bits.Add64(q4, t1, c) + t1, t0 = bits.Mul64(x2, y.arr[3]) + q5, c = bits.Add64(q5, t0, c) + q6, c = bits.Add64(q6, t1, c) + q8, t0 := bits.Mul64(x4, y.arr[3]) + q7, c = bits.Add64(q7, t0, c) + q8, _ = bits.Add64(q8, 0, c) + + t1, t0 = bits.Mul64(x1, y.arr[3]) + q4, c = bits.Add64(q4, t0, 0) + q5, c = bits.Add64(q5, t1, c) + t1, t0 = bits.Mul64(x3, y.arr[3]) + q6, c = bits.Add64(q6, t0, c) + q7, c = bits.Add64(q7, t1, c) + q8, _ = bits.Add64(q8, 0, c) + + // Final adjustment + + // subtract q from 1/4 + _, b = bits.Sub64(0, q0, 0) + _, b = bits.Sub64(0, q1, b) + _, b = bits.Sub64(0, q2, b) + _, b = bits.Sub64(0, q3, b) + _, b = bits.Sub64(0, q4, b) + _, b = bits.Sub64(0, q5, b) + _, b = bits.Sub64(0, q6, b) + _, b = bits.Sub64(0, q7, b) + _, b = bits.Sub64(uint64(1)<<62, q8, b) + + // decrement the result + x0, t := bits.Sub64(r4l, 1, 0) + x1, t = bits.Sub64(r4k, 0, t) + x2, t = bits.Sub64(r4j, 0, t) + x3, t = bits.Sub64(r4i, 0, t) + x4, _ = bits.Sub64(r4h, 0, t) + + // commit the decrement if the subtraction underflowed (reciprocal was too large) + if b != 0 { + r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0 + } + + // Shift to correct bit alignment, truncating excess bits + + p = (p & 63) - 1 + + x0, c = bits.Add64(r4l, r4l, 0) + x1, c = bits.Add64(r4k, r4k, c) + x2, c = bits.Add64(r4j, r4j, c) + x3, c = bits.Add64(r4i, r4i, c) + x4, _ = bits.Add64(r4h, r4h, c) + + if p < 0 { + r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0 + p = 0 // avoid negative shift below + } + + { + r := uint(p) // right shift + l := uint(64 - r) // left shift + + x0 = (r4l >> r) | (r4k << l) + x1 = (r4k >> r) | (r4j << l) + x2 = (r4j >> r) | (r4i << l) + x3 = (r4i >> r) | (r4h << l) + x4 = (r4h >> r) + } + + if p > 0 { + r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0 + } + + mu[0] = r4l + mu[1] = r4k + mu[2] = r4j + mu[3] = r4i + mu[4] = r4h + + return mu +} + +// reduce4 computes the least non-negative residue of x modulo m +// +// requires a four-word modulus (m.arr[3] > 1) and its inverse (mu) +func reduce4(x [8]uint64, m *Uint, mu [5]uint64) (z Uint) { + // NB: Most variable names in the comments match the pseudocode for + // Barrett reduction in the Handbook of Applied Cryptography. + + // q1 = x/2^192 + + x0 := x[3] + x1 := x[4] + x2 := x[5] + x3 := x[6] + x4 := x[7] + + // q2 = q1 * mu; q3 = q2 / 2^320 + + var q0, q1, q2, q3, q4, q5, t0, t1, c uint64 + + q0, _ = bits.Mul64(x3, mu[0]) + q1, t0 = bits.Mul64(x4, mu[0]) + q0, c = bits.Add64(q0, t0, 0) + q1, _ = bits.Add64(q1, 0, c) + + t1, _ = bits.Mul64(x2, mu[1]) + q0, c = bits.Add64(q0, t1, 0) + q2, t0 = bits.Mul64(x4, mu[1]) + q1, c = bits.Add64(q1, t0, c) + q2, _ = bits.Add64(q2, 0, c) + + t1, t0 = bits.Mul64(x3, mu[1]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + q2, _ = bits.Add64(q2, 0, c) + + t1, t0 = bits.Mul64(x2, mu[2]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + q3, t0 = bits.Mul64(x4, mu[2]) + q2, c = bits.Add64(q2, t0, c) + q3, _ = bits.Add64(q3, 0, c) + + t1, _ = bits.Mul64(x1, mu[2]) + q0, c = bits.Add64(q0, t1, 0) + t1, t0 = bits.Mul64(x3, mu[2]) + q1, c = bits.Add64(q1, t0, c) + q2, c = bits.Add64(q2, t1, c) + q3, _ = bits.Add64(q3, 0, c) + + t1, _ = bits.Mul64(x0, mu[3]) + q0, c = bits.Add64(q0, t1, 0) + t1, t0 = bits.Mul64(x2, mu[3]) + q1, c = bits.Add64(q1, t0, c) + q2, c = bits.Add64(q2, t1, c) + q4, t0 = bits.Mul64(x4, mu[3]) + q3, c = bits.Add64(q3, t0, c) + q4, _ = bits.Add64(q4, 0, c) + + t1, t0 = bits.Mul64(x1, mu[3]) + q0, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + t1, t0 = bits.Mul64(x3, mu[3]) + q2, c = bits.Add64(q2, t0, c) + q3, c = bits.Add64(q3, t1, c) + q4, _ = bits.Add64(q4, 0, c) + + t1, t0 = bits.Mul64(x0, mu[4]) + _, c = bits.Add64(q0, t0, 0) + q1, c = bits.Add64(q1, t1, c) + t1, t0 = bits.Mul64(x2, mu[4]) + q2, c = bits.Add64(q2, t0, c) + q3, c = bits.Add64(q3, t1, c) + q5, t0 = bits.Mul64(x4, mu[4]) + q4, c = bits.Add64(q4, t0, c) + q5, _ = bits.Add64(q5, 0, c) + + t1, t0 = bits.Mul64(x1, mu[4]) + q1, c = bits.Add64(q1, t0, 0) + q2, c = bits.Add64(q2, t1, c) + t1, t0 = bits.Mul64(x3, mu[4]) + q3, c = bits.Add64(q3, t0, c) + q4, c = bits.Add64(q4, t1, c) + q5, _ = bits.Add64(q5, 0, c) + + // Drop the fractional part of q3 + + q0 = q1 + q1 = q2 + q2 = q3 + q3 = q4 + q4 = q5 + + // r1 = x mod 2^320 + + x0 = x[0] + x1 = x[1] + x2 = x[2] + x3 = x[3] + x4 = x[4] + + // r2 = q3 * m mod 2^320 + + var r0, r1, r2, r3, r4 uint64 + + r4, r3 = bits.Mul64(q0, m.arr[3]) + _, t0 = bits.Mul64(q1, m.arr[3]) + r4, _ = bits.Add64(r4, t0, 0) + + t1, r2 = bits.Mul64(q0, m.arr[2]) + r3, c = bits.Add64(r3, t1, 0) + _, t0 = bits.Mul64(q2, m.arr[2]) + r4, _ = bits.Add64(r4, t0, c) + + t1, t0 = bits.Mul64(q1, m.arr[2]) + r3, c = bits.Add64(r3, t0, 0) + r4, _ = bits.Add64(r4, t1, c) + + t1, r1 = bits.Mul64(q0, m.arr[1]) + r2, c = bits.Add64(r2, t1, 0) + t1, t0 = bits.Mul64(q2, m.arr[1]) + r3, c = bits.Add64(r3, t0, c) + r4, _ = bits.Add64(r4, t1, c) + + t1, t0 = bits.Mul64(q1, m.arr[1]) + r2, c = bits.Add64(r2, t0, 0) + r3, c = bits.Add64(r3, t1, c) + _, t0 = bits.Mul64(q3, m.arr[1]) + r4, _ = bits.Add64(r4, t0, c) + + t1, r0 = bits.Mul64(q0, m.arr[0]) + r1, c = bits.Add64(r1, t1, 0) + t1, t0 = bits.Mul64(q2, m.arr[0]) + r2, c = bits.Add64(r2, t0, c) + r3, c = bits.Add64(r3, t1, c) + _, t0 = bits.Mul64(q4, m.arr[0]) + r4, _ = bits.Add64(r4, t0, c) + + t1, t0 = bits.Mul64(q1, m.arr[0]) + r1, c = bits.Add64(r1, t0, 0) + r2, c = bits.Add64(r2, t1, c) + t1, t0 = bits.Mul64(q3, m.arr[0]) + r3, c = bits.Add64(r3, t0, c) + r4, _ = bits.Add64(r4, t1, c) + + // r = r1 - r2 + + var b uint64 + + r0, b = bits.Sub64(x0, r0, 0) + r1, b = bits.Sub64(x1, r1, b) + r2, b = bits.Sub64(x2, r2, b) + r3, b = bits.Sub64(x3, r3, b) + r4, b = bits.Sub64(x4, r4, b) + + // if r<0 then r+=m + + if b != 0 { + r0, c = bits.Add64(r0, m.arr[0], 0) + r1, c = bits.Add64(r1, m.arr[1], c) + r2, c = bits.Add64(r2, m.arr[2], c) + r3, c = bits.Add64(r3, m.arr[3], c) + r4, _ = bits.Add64(r4, 0, c) + } + + // while (r>=m) r-=m + + for { + // q = r - m + q0, b = bits.Sub64(r0, m.arr[0], 0) + q1, b = bits.Sub64(r1, m.arr[1], b) + q2, b = bits.Sub64(r2, m.arr[2], b) + q3, b = bits.Sub64(r3, m.arr[3], b) + q4, b = bits.Sub64(r4, 0, b) + + // if borrow break + if b != 0 { + break + } + + // r = q + r4, r3, r2, r1, r0 = q4, q3, q2, q1, q0 + } + + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = r3, r2, r1, r0 + + return z +} diff --git a/contract/p/gnoswap/uint256/uint256.gno b/contract/p/gnoswap/uint256/uint256.gno new file mode 100644 index 000000000..35a399f22 --- /dev/null +++ b/contract/p/gnoswap/uint256/uint256.gno @@ -0,0 +1,291 @@ +// Ported from https://github.com/holiman/uint256 +// This package provides a 256-bit unsigned integer type, Uint256, and associated functions. +package uint256 + +import ( + "errors" + "strconv" + "math/bits" +) + +const ( + MaxUint64 = 1<<64 - 1 +) + +// Uint is represented as an array of 4 uint64, in little-endian order, +// so that Uint[3] is the most significant, and Uint[0] is the least significant +type Uint struct { + arr [4]uint64 +} + +// NewUint returns a new initialized Uint. +func NewUint(val uint64) *Uint { + z := &Uint{arr: [4]uint64{val, 0, 0, 0}} + return z +} + +// Zero returns a new Uint initialized to zero. +func Zero() *Uint { + return NewUint(0) +} + +// One returns a new Uint initialized to one. +func One() *Uint { + return NewUint(1) +} + +// SetAllOne sets all the bits of z to 1 +func (z *Uint) SetAllOne() *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = MaxUint64, MaxUint64, MaxUint64, MaxUint64 + return z +} + +// Set sets z to x and returns z. +func (z *Uint) Set(x *Uint) *Uint { + *z = *x + + return z +} + +// SetOne sets z to 1 +func (z *Uint) SetOne() *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, 1 + return z +} + +const twoPow256Sub1 = "115792089237316195423570985008687907853269984665640564039457584007913129639935" + +// SetFromDecimal sets z from the given string, interpreted as a decimal number. +// OBS! This method is _not_ strictly identical to the (*big.Uint).SetString(..., 10) method. +// Notable differences: +// - This method does not accept underscore input, e.g. "100_000", +// - This method does not accept negative zero as valid, e.g "-0", +// - (this method does not accept any negative input as valid)) +func (z *Uint) SetFromDecimal(s string) (err error) { + // Remove max one leading + + if len(s) > 0 && s[0] == '+' { + s = s[1:] + } + // Remove any number of leading zeroes + if len(s) > 0 && s[0] == '0' { + var i int + var c rune + for i, c = range s { + if c != '0' { + break + } + } + s = s[i:] + } + if len(s) < len(twoPow256Sub1) { + return z.fromDecimal(s) + } + if len(s) == len(twoPow256Sub1) { + if s > twoPow256Sub1 { + return ErrBig256Range + } + return z.fromDecimal(s) + } + return ErrBig256Range +} + +// FromDecimal is a convenience-constructor to create an Uint from a +// decimal (base 10) string. Numbers larger than 256 bits are not accepted. +func FromDecimal(decimal string) (*Uint, error) { + var z Uint + if err := z.SetFromDecimal(decimal); err != nil { + return nil, err + } + return &z, nil +} + +// MustFromDecimal is a convenience-constructor to create an Uint from a +// decimal (base 10) string. +// Returns a new Uint and panics if any error occurred. +func MustFromDecimal(decimal string) *Uint { + var z Uint + if err := z.SetFromDecimal(decimal); err != nil { + panic(err) + } + return &z +} + +// multipliers holds the values that are needed for fromDecimal +var multipliers = [5]*Uint{ + nil, // represents first round, no multiplication needed + {[4]uint64{10000000000000000000, 0, 0, 0}}, // 10 ^ 19 + {[4]uint64{687399551400673280, 5421010862427522170, 0, 0}}, // 10 ^ 38 + {[4]uint64{5332261958806667264, 17004971331911604867, 2938735877055718769, 0}}, // 10 ^ 57 + {[4]uint64{0, 8607968719199866880, 532749306367912313, 1593091911132452277}}, // 10 ^ 76 +} + +// fromDecimal is a helper function to only ever be called via SetFromDecimal +// this function takes a string and chunks it up, calling ParseUint on it up to 5 times +// these chunks are then multiplied by the proper power of 10, then added together. +func (z *Uint) fromDecimal(bs string) error { + // first clear the input + z.Clear() + // the maximum value of uint64 is 18446744073709551615, which is 20 characters + // one less means that a string of 19 9's is always within the uint64 limit + var ( + num uint64 + err error + remaining = len(bs) + ) + if remaining == 0 { + return errors.New("EOF") + } + // We proceed in steps of 19 characters (nibbles), from least significant to most significant. + // This means that the first (up to) 19 characters do not need to be multiplied. + // In the second iteration, our slice of 19 characters needs to be multipleied + // by a factor of 10^19. Et cetera. + for i, mult := range multipliers { + if remaining <= 0 { + return nil // Done + } else if remaining > 19 { + num, err = strconv.ParseUint(bs[remaining-19:remaining], 10, 64) + } else { + // Final round + num, err = strconv.ParseUint(bs, 10, 64) + } + if err != nil { + return err + } + // add that number to our running total + if i == 0 { + z.SetUint64(num) + } else { + base := NewUint(num) + z.Add(z, base.Mul(base, mult)) + } + // Chop off another 19 characters + if remaining > 19 { + bs = bs[0 : remaining-19] + } + remaining -= 19 + } + return nil +} + +// Byte sets z to the value of the byte at position n, +// with 'z' considered as a big-endian 32-byte integer +// if 'n' > 32, f is set to 0 +// Example: f = '5', n=31 => 5 +func (z *Uint) Byte(n *Uint) *Uint { + // in z, z.arr[0] is the least significant + if number, overflow := n.Uint64WithOverflow(); !overflow { + if number < 32 { + number := z.arr[4-1-number/8] + offset := (n.arr[0] & 0x7) << 3 // 8*(n.d % 8) + z.arr[0] = (number & (0xff00000000000000 >> offset)) >> (56 - offset) + z.arr[3], z.arr[2], z.arr[1] = 0, 0, 0 + return z + } + } + + return z.Clear() +} + +// BitLen returns the number of bits required to represent z +func (z *Uint) BitLen() int { + switch { + case z.arr[3] != 0: + return 192 + bits.Len64(z.arr[3]) + case z.arr[2] != 0: + return 128 + bits.Len64(z.arr[2]) + case z.arr[1] != 0: + return 64 + bits.Len64(z.arr[1]) + default: + return bits.Len64(z.arr[0]) + } +} + +// ByteLen returns the number of bytes required to represent z +func (z *Uint) ByteLen() int { + return (z.BitLen() + 7) / 8 +} + +// Clear sets z to 0 +func (z *Uint) Clear() *Uint { + z.arr[3], z.arr[2], z.arr[1], z.arr[0] = 0, 0, 0, 0 + return z +} + +const ( + // hextable = "0123456789abcdef" + bintable = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00\x01\x02\x03\x04\x05\x06\a\b\t\xff\xff\xff\xff\xff\xff\xff\n\v\f\r\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\n\v\f\r\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + badNibble = 0xff +) + +// SetFromHex sets z from the given string, interpreted as a hexadecimal number. +// OBS! This method is _not_ strictly identical to the (*big.Int).SetString(..., 16) method. +// Notable differences: +// - This method _require_ "0x" or "0X" prefix. +// - This method does not accept zero-prefixed hex, e.g. "0x0001" +// - This method does not accept underscore input, e.g. "100_000", +// - This method does not accept negative zero as valid, e.g "-0x0", +// - (this method does not accept any negative input as valid) +func (z *Uint) SetFromHex(hex string) error { + return z.fromHex(hex) +} + +// fromHex is the internal implementation of parsing a hex-string. +func (z *Uint) fromHex(hex string) error { + if err := checkNumberS(hex); err != nil { + return err + } + if len(hex) > 66 { + return ErrBig256Range + } + z.Clear() + end := len(hex) + for i := 0; i < 4; i++ { + start := end - 16 + if start < 2 { + start = 2 + } + for ri := start; ri < end; ri++ { + nib := bintable[hex[ri]] + if nib == badNibble { + return ErrSyntax + } + z.arr[i] = z.arr[i] << 4 + z.arr[i] += uint64(nib) + } + end = start + } + return nil +} + +// FromHex is a convenience-constructor to create an Uint from +// a hexadecimal string. The string is required to be '0x'-prefixed +// Numbers larger than 256 bits are not accepted. +func FromHex(hex string) (*Uint, error) { + var z Uint + if err := z.fromHex(hex); err != nil { + return nil, err + } + return &z, nil +} + +// MustFromHex is a convenience-constructor to create an Uint from +// a hexadecimal string. +// Returns a new Uint and panics if any error occurred. +func MustFromHex(hex string) *Uint { + var z Uint + if err := z.fromHex(hex); err != nil { + panic(err) + } + return &z +} + +// Clone creates a new Uint identical to z +func (z *Uint) Clone() *Uint { + var x Uint + x.arr[0] = z.arr[0] + x.arr[1] = z.arr[1] + x.arr[2] = z.arr[2] + x.arr[3] = z.arr[3] + + return &x +} diff --git a/contract/p/gnoswap/uint256/uint256_test.gno b/contract/p/gnoswap/uint256/uint256_test.gno new file mode 100644 index 000000000..c9646b812 --- /dev/null +++ b/contract/p/gnoswap/uint256/uint256_test.gno @@ -0,0 +1,321 @@ +package uint256 + +import ( + "testing" +) + +func TestMulDiv_1(t *testing.T) { + // reverts if denominator is 0 + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + + x := NewUint(5) + y := Zero() + + shouldPanic( + t, + func() { + MulDiv(Q128, x, y) + }, + ) +} + +func TestMulDiv_2(t *testing.T) { + // reverts if denominator is 0 and numerator overflows + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + x := NewUint(5) + y := Zero() + + shouldPanic( + t, + func() { + MulDiv(Q128, Q128, y) + }, + ) +} + +func TestMulDiv_3(t *testing.T) { + // reverts if output overflows uint256 + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + y := One() + + shouldPanic( + t, + func() { + MulDiv(Q128, Q128, y) + }, + ) +} + +func TestMulDiv_4(t *testing.T) { + // reverts on overflow with all max inputs + MaxUint256 := MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + MaxUint256Sub1 := MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639934") + if MaxUint256.ToString() != "115792089237316195423570985008687907853269984665640564039457584007913129639935" { + t.Errorf("MustFromHex is Failed") + } + shouldPanic( + t, + func() { + MulDiv(MaxUint256, MaxUint256, MaxUint256Sub1) + }, + ) +} + +func TestMulDiv_5(t *testing.T) { + // all max inputs + MaxUint256 := MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + + got := MulDiv(MaxUint256, MaxUint256, MaxUint256) + + rst := got.Eq(MaxUint256) + if !rst { + t.Errorf("MaxUint256 * MaxUint256 / MaxUint256 is not same to MaxUint256") + } +} + +func TestMulDiv_6(t *testing.T) { + denom := NewUint(100) + + // accurate without phantom overflow #1 + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + Q128Const := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + expected := Q128.Div(Q128, NewUint(3)) + x := NewUint(50) + x1 := MulDiv(x, Q128Const, denom) + y := NewUint(150) + y1 := MulDiv(y, Q128Const, denom) + + got := MulDiv(Q128Const, x1, y1) + + rst := got.Eq(expected) + if !rst { + t.Errorf("Q128/3 is not smae to Q128 * (50*Q128/100) / (150*Q128/100)") + } +} + +func TestMulDiv_6_1(t *testing.T) { + denom := NewUint(100) + + // accurate without phantom overflow #1_1 + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + expected := new(Uint).Div(Q128, NewUint(3)) + x := NewUint(50) + x1 := MulDiv(x, Q128, denom) + y := NewUint(150) + y1 := MulDiv(y, Q128, denom) + + got := MulDiv(Q128, x1, y1) + + rst := got.Eq(expected) + if !rst { + t.Errorf("Q128/3 is not smae to Q128 * (50*Q128/100) / (150*Q128/100)") + } +} + +func TestMulDiv_7(t *testing.T) { + // accurate with phantom overflow #2 + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + Q128Const := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + x := NewUint(4375) + expected := MulDiv(x, Q128Const, NewUint(1000)) + + y1 := NewUint(35) + y1.Mul(y1, Q128Const) + + denom := NewUint(8) + denom.Mul(denom, Q128Const) + + got := MulDiv(Q128Const, y1, denom) + + rst := got.Eq(expected) + if !rst { + t.Errorf("4375*Q128/1000 is not smae to Q128*35*Q128/(8*Q128)") + } +} + +func TestMulDiv_8(t *testing.T) { + // accurate with phantom overflow and repeating decimal + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + Q128Const := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + expected := MulDiv(One(), Q128Const, NewUint(3)) + + y1 := NewUint(1000) + y1.Mul(y1, Q128Const) + + denom := NewUint(3000) + denom.Mul(denom, Q128Const) + + got := MulDiv(Q128Const, y1, denom) + + rst := got.Eq(expected) + if !rst { + t.Errorf("1*Q128/3 is not smae to Q128*1000*Q128/(3000*Q128)") + } +} + +func TestMulDivRoundingUp_1(t *testing.T) { + // reverts if denominator is 0 + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + + x := NewUint(5) + y := Zero() + + shouldPanic( + t, + func() { + MulDivRoundingUp(Q128, x, y) + }, + ) +} + +func TestMulDivRoundingUp_2(t *testing.T) { + // reverts if denominator is 0 and numerator overflows + + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + y := Zero() + + shouldPanic( + t, + func() { + MulDivRoundingUp(Q128, Q128, y) + }, + ) +} + +func TestMulDivRoundingUp_3(t *testing.T) { + // reverts if output overflows uint256 + + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + y := One() + + shouldPanic( + t, + func() { + MulDivRoundingUp(Q128, Q128, y) + }, + ) +} + +func TestMulDivRoundingUp_4(t *testing.T) { + // reverts on overflow with all max inputs + + MaxUint256 := MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + MaxUint256Sub1 := MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639934") + + shouldPanic( + t, + func() { + MulDivRoundingUp(MaxUint256, MaxUint256, MaxUint256Sub1) + }, + ) +} + +func TestMulDivRoundingUp_5(t *testing.T) { + // reverts if mulDiv overflows 256 bits after rounding up + + x := MustFromDecimal("535006138814359") + y := MustFromDecimal("432862656469423142931042426214547535783388063929571229938474969") + + shouldPanic( + t, + func() { + MulDivRoundingUp(x, y, NewUint(2)) + }, + ) +} + +func TestMulDivRoundingUp_6(t *testing.T) { + // reverts if mulDiv overflows 256 bits after rounding up case 2 + + x := MustFromDecimal("115792089237316195423570985008687907853269984659341747863450311749907997002549") + y := MustFromDecimal("115792089237316195423570985008687907853269984659341747863450311749907997002550") + z := MustFromDecimal("115792089237316195423570985008687907853269984653042931687443039491902864365164") + + shouldPanic( + t, + func() { + MulDivRoundingUp(x, y, z) + }, + ) +} + +func TestMulDivRoundingUp_7(t *testing.T) { + // all max inputs + MaxUint256 := MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + + got := MulDivRoundingUp(MaxUint256, MaxUint256, MaxUint256) + + rst := got.Eq(MaxUint256) + + if !rst { + t.Errorf("MaxUin256*MaxUint256/MaxUint256 RoudingUp is not same to MaxUint256") + } +} + +func TestMulDivRoundingUp_8(t *testing.T) { + // accurate without phantom overflow + denom := NewUint(100) + + // accurate without phantom overflow #1 + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + Q128Const := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + Q128.Div(Q128, NewUint(3)) + expected := Q128.Add(Q128, NewUint(1)) + + x1 := NewUint(50) + x1.Mul(x1, Q128Const) + + y1 := NewUint(150) + y1.Mul(y1, Q128Const) + + got := MulDivRoundingUp(Q128Const, x1, y1) + + rst := got.Eq(expected) + if !rst { + t.Errorf("Q128*50*Q128/100/(150*Q128/100) should be equal to Q128/3") + } +} + +func TestMulDivRoundingUp_9(t *testing.T) { + // accurate with phantom overflow #2 + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + Q128Const := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + x := NewUint(4375) + expected := MulDiv(x, Q128Const, NewUint(1000)) + + y1 := NewUint(35) + y1.Mul(y1, Q128Const) + + denom := NewUint(8) + denom.Mul(denom, Q128Const) + + got := MulDiv(Q128Const, y1, denom) + + rst := got.Eq(expected) + if !rst { + t.Errorf("4375*Q128/1000 is not smae to Q128*35*Q128/(8*Q128)") + } +} + +func TestMulDivRoundingUp_10(t *testing.T) { + // accurate with phantom overflow and repeating decimal + + Q128 := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + Q128Const := MustFromDecimal("340282366920938463463374607431768211456") // 2**128 + + expected := MulDiv(One(), Q128Const, NewUint(3)) + expected.Add(expected, One()) + + y1 := NewUint(1000) + y1.Mul(y1, Q128Const) + + denom := NewUint(3000) + denom.Mul(denom, Q128Const) + + got := MulDivRoundingUp(Q128Const, y1, denom) + + rst := got.Eq(expected) + if !rst { + t.Errorf("Q128*100*Q128/300*Q128 should be eq to 1*Q128+1") + } +} diff --git a/contract/p/gnoswap/uint256/utils.gno b/contract/p/gnoswap/uint256/utils.gno new file mode 100644 index 000000000..bcc7bb283 --- /dev/null +++ b/contract/p/gnoswap/uint256/utils.gno @@ -0,0 +1,20 @@ +package uint256 + +func checkNumberS(input string) error { + const fn = "UnmarshalText" + l := len(input) + if l == 0 { + return errEmptyString(fn, input) + } + if l < 2 || input[0] != '0' || + (input[1] != 'x' && input[1] != 'X') { + return errMissingPrefix(fn, input) + } + if l == 2 { + return errEmptyNumber(fn, input) + } + if len(input) > 3 && input[2] == '0' { + return errLeadingZero(fn, input) + } + return nil +} diff --git a/contract/r/gnoswap/common/access.gno b/contract/r/gnoswap/common/access.gno new file mode 100644 index 000000000..fe61f41a6 --- /dev/null +++ b/contract/r/gnoswap/common/access.gno @@ -0,0 +1,107 @@ +package common + +import ( + "std" + + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" +) + +const ( + ErrNoPermission = "caller(%s) has no permission" +) + +// AssertCaller checks if the caller is the given address. +func AssertCaller(caller, addr std.Address) error { + if caller != addr { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +func SatisfyCond(cond bool) error { + if !cond { + return ufmt.Errorf("given condition is not satisfied the permission check") + } + return nil +} + +// AdminOnly checks if the caller is the admin. +func AdminOnly(caller std.Address) error { + if caller != consts.ADMIN { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// GovernanceOnly checks if the caller is the gov governance contract. +func GovernanceOnly(caller std.Address) error { + if caller != consts.GOV_GOVERNANCE_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// GovStakerOnly checks if the caller is the gov staker contract. +func GovStakerOnly(caller std.Address) error { + if caller != consts.GOV_STAKER_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// RouterOnly checks if the caller is the router contract. +func RouterOnly(caller std.Address) error { + if caller != consts.ROUTER_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// PoolOnly checks if the caller is the pool contract. +func PoolOnly(caller std.Address) error { + if caller != consts.POOL_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// PositionOnly checks if the caller is the position contract. +func PositionOnly(caller std.Address) error { + if caller != consts.POSITION_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// StakerOnly checks if the caller is the staker contract. +func StakerOnly(caller std.Address) error { + if caller != consts.STAKER_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// LaunchpadOnly checks if the caller is the launchpad contract. +func LaunchpadOnly(caller std.Address) error { + if caller != consts.LAUNCHPAD_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// EmissionOnly checks if the caller is the emission contract. +func EmissionOnly(caller std.Address) error { + if caller != consts.EMISSION_ADDR { + return ufmt.Errorf(ErrNoPermission, caller.String()) + } + return nil +} + +// UserOnly checks if the caller is a user. +func UserOnly(prev std.Realm) error { + if !prev.IsUser() { + return ufmt.Errorf("caller(%s) is not a user", prev.PkgPath()) + } + return nil +} diff --git a/contract/r/gnoswap/common/access_test.gno b/contract/r/gnoswap/common/access_test.gno new file mode 100644 index 000000000..f00f7c7ce --- /dev/null +++ b/contract/r/gnoswap/common/access_test.gno @@ -0,0 +1,126 @@ +package common + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" +) + +var ( + addr01 = testutils.TestAddress("addr01") + addr02 = testutils.TestAddress("addr02") +) + +func TestAssertCaller(t *testing.T) { + t.Run("same caller", func(t *testing.T) { + uassert.NoError(t, AssertCaller(addr01, addr01)) + }) + + t.Run("different caller", func(t *testing.T) { + uassert.Error(t, AssertCaller(addr01, addr02)) + }) +} + +func TestSatisfyCond(t *testing.T) { + t.Run("true", func(t *testing.T) { + uassert.NoError(t, SatisfyCond(true)) + }) + + t.Run("false", func(t *testing.T) { + uassert.Error(t, SatisfyCond(false)) + }) +} + +func TestAdminOnly(t *testing.T) { + t.Run("caller is admin", func(t *testing.T) { + uassert.NoError(t, AdminOnly(consts.ADMIN)) + }) + + t.Run("caller is not admin", func(t *testing.T) { + uassert.Error(t, AdminOnly(addr01)) + }) +} + +func TestGovernanceOnly(t *testing.T) { + t.Run("caller is governance", func(t *testing.T) { + uassert.NoError(t, GovernanceOnly(consts.GOV_GOVERNANCE_ADDR)) + }) + + t.Run("caller is not governance", func(t *testing.T) { + uassert.Error(t, GovernanceOnly(addr01)) + }) +} + +func TestGovStakerOnly(t *testing.T) { + t.Run("caller is gov staker", func(t *testing.T) { + uassert.NoError(t, GovStakerOnly(consts.GOV_STAKER_ADDR)) + }) + + t.Run("caller is not gov staker", func(t *testing.T) { + uassert.Error(t, GovStakerOnly(addr01)) + }) +} + +func TestRouterOnly(t *testing.T) { + t.Run("caller is router", func(t *testing.T) { + uassert.NoError(t, RouterOnly(consts.ROUTER_ADDR)) + }) + + t.Run("caller is not router", func(t *testing.T) { + uassert.Error(t, RouterOnly(addr01)) + }) +} + +func TestPositionOnly(t *testing.T) { + t.Run("caller is position", func(t *testing.T) { + uassert.NoError(t, PositionOnly(consts.POSITION_ADDR)) + }) + + t.Run("caller is not position", func(t *testing.T) { + uassert.Error(t, PositionOnly(addr01)) + }) +} + +func TestStakerOnly(t *testing.T) { + t.Run("caller is staker", func(t *testing.T) { + uassert.NoError(t, StakerOnly(consts.STAKER_ADDR)) + }) + + t.Run("caller is not staker", func(t *testing.T) { + uassert.Error(t, StakerOnly(addr01)) + }) +} + +func TestLaunchpadOnly(t *testing.T) { + t.Run("caller is launchpad", func(t *testing.T) { + uassert.NoError(t, LaunchpadOnly(consts.LAUNCHPAD_ADDR)) + }) + + t.Run("caller is not launchpad", func(t *testing.T) { + uassert.Error(t, LaunchpadOnly(addr01)) + }) +} + +func TestEmissionOnly(t *testing.T) { + t.Run("caller is emission", func(t *testing.T) { + uassert.NoError(t, EmissionOnly(consts.EMISSION_ADDR)) + }) + + t.Run("caller is not emission", func(t *testing.T) { + uassert.Error(t, EmissionOnly(addr01)) + }) +} + +func TestUserOnly(t *testing.T) { + t.Run("caller is user", func(t *testing.T) { + uassert.NoError(t, UserOnly(std.NewUserRealm(addr01))) + }) + + t.Run("caller is not user", func(t *testing.T) { + uassert.Error(t, UserOnly(std.NewCodeRealm("gno.land/r/realm"))) + }) +} diff --git a/contract/r/gnoswap/common/address_and_username.gno b/contract/r/gnoswap/common/address_and_username.gno new file mode 100644 index 000000000..bc8de54bb --- /dev/null +++ b/contract/r/gnoswap/common/address_and_username.gno @@ -0,0 +1,29 @@ +package common + +import ( + "std" + + pusers "gno.land/p/demo/users" + "gno.land/r/demo/users" +) + +// AddrToUser converts a type from address to AddressOrName. +// It panics if the address is invalid. +func AddrToUser(addr std.Address) pusers.AddressOrName { + AssertValidAddr(addr) + return pusers.AddressOrName(addr) +} + +// UserToAddr converts a type from AddressOrName to address. +// by resolving the user through the users realms. +func UserToAddr(user pusers.AddressOrName) std.Address { + return users.Resolve(user) +} + +// AssertValidAddr checks if the given address is valid. +// It panics with a detailed error message if the address is invalid. +func AssertValidAddr(addr std.Address) { + if !addr.IsValid() { + panic(newErrorWithDetail(errInvalidAddr, addr.String())) + } +} diff --git a/contract/r/gnoswap/common/address_and_username_test.gno b/contract/r/gnoswap/common/address_and_username_test.gno new file mode 100644 index 000000000..d988f047f --- /dev/null +++ b/contract/r/gnoswap/common/address_and_username_test.gno @@ -0,0 +1,115 @@ +package common + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + pusers "gno.land/p/demo/users" +) + +func TestAddrToUser(t *testing.T) { + tests := []struct { + name string + addr std.Address + want pusers.AddressOrName + shouldPanic bool + panicMsg string + }{ + { + name: "valid address", + addr: std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5"), + want: pusers.AddressOrName("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5"), + }, + { + name: "empty address", + addr: std.Address(""), + shouldPanic: true, + panicMsg: `[GNOSWAP-COMMON-005] invalid address || `, + }, + { + name: "invalid address", + addr: std.Address("invalid"), + shouldPanic: true, + panicMsg: `[GNOSWAP-COMMON-005] invalid address || invalid`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + AddrToUser(tt.addr) + }) + } else { + got := AddrToUser(tt.addr) + if got != tt.want { + t.Errorf("AddrToUser() = %v, want %v", got, tt.want) + } + } + }) + } +} + +func TestUserToAddr(t *testing.T) { + tests := []struct { + name string + user pusers.AddressOrName + want std.Address + }{ + { + name: "address string with user type", + user: pusers.AddressOrName("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5"), + want: std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := UserToAddr(tt.user) + if got != tt.want { + t.Errorf("UserToAddr() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAssertValidAddr(t *testing.T) { + tests := []struct { + name string + addr std.Address + shouldPanic bool + panicMsg string + }{ + { + name: "valid address", + addr: std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5"), + }, + { + name: "empty address", + addr: std.Address(""), + shouldPanic: true, + panicMsg: `[GNOSWAP-COMMON-005] invalid address || `, + }, + { + name: "invalid address", + addr: std.Address("invalid"), + shouldPanic: true, + panicMsg: `[GNOSWAP-COMMON-005] invalid address || invalid`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + AssertValidAddr(tt.addr) + }) + } else { + uassert.NotPanics(t, func() { + AssertValidAddr(tt.addr) + }) + } + }) + } +} diff --git a/contract/r/gnoswap/common/errors.gno b/contract/r/gnoswap/common/errors.gno new file mode 100644 index 000000000..f96124477 --- /dev/null +++ b/contract/r/gnoswap/common/errors.gno @@ -0,0 +1,32 @@ +package common + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-COMMON-001] caller has no permission") + errHalted = errors.New("[GNOSWAP-COMMON-002] halted") + errOutOfRange = errors.New("[GNOSWAP-COMMON-003] value out of range") + errNotRegistered = errors.New("[GNOSWAP-COMMON-004] token is not registered") + errInvalidAddr = errors.New("[GNOSWAP-COMMON-005] invalid address") + errOverflow = errors.New("[GNOSWAP-COMMON-006] overflow") + errInvalidTokenId = errors.New("[GNOSWAP-COMMON-007] invalid tokenId") + errInvalidInput = errors.New("[GNOSWAP-COMMON-008] invalid input data") + errOverFlow = errors.New("[GNOSWAP-COMMON-009] overflow") + errIdenticalTicks = errors.New("[GNOSWAP-COMMON-010] identical ticks") +) + +// newErrorWithDetail appends additional context or details to an existing error message. +// +// Parameters: +// - err: The original error (error). +// - detail: Additional context or detail to append to the error message (string). +// +// Returns: +// - string: The combined error message in the format " || ". +func newErrorWithDetail(err error, detail string) string { + return ufmt.Errorf("%s || %s", err.Error(), detail).Error() +} diff --git a/contract/r/gnoswap/common/gno.mod b/contract/r/gnoswap/common/gno.mod new file mode 100644 index 000000000..ad0e3a333 --- /dev/null +++ b/contract/r/gnoswap/common/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/common diff --git a/contract/r/gnoswap/common/grc20reg_helper.gno b/contract/r/gnoswap/common/grc20reg_helper.gno new file mode 100644 index 000000000..5f521a857 --- /dev/null +++ b/contract/r/gnoswap/common/grc20reg_helper.gno @@ -0,0 +1,100 @@ +package common + +import ( + "regexp" + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ufmt" + "gno.land/r/demo/grc20reg" +) + +var ( + re = regexp.MustCompile(`\[gno\.land/r/[^\]]+\]`) +) + +// GetToken returns a grc20.Token instance +// if token is not registered, it will panic +// token instance supports following methods: +// - GetName +// - GetSymbol +// - GetDecimals +// - TotalSupply +// - KnownAccounts +// - BalanceOf +// - Allowance +// - RenderHome +func GetToken(path string) *grc20.Token { + tokenGetter := grc20reg.MustGet(path) // if token is not registered, it will panic + + return tokenGetter() +} + +// GetTokenTeller returns a grc20.Teller instance +// if token is not registered, it will panic +// teller instance supports following methods: +// - Transfer +// - Approve +// - TransferFrom +func GetTokenTeller(path string) grc20.Teller { + tokenGetter := grc20reg.MustGet(path) // if token is not registered, it will panic + token := tokenGetter() + return token.CallerTeller() +} + +// IsRegistered returns nil if token is registered to grc20reg +// otherwise, it returns an error +func IsRegistered(path string) error { + getter := grc20reg.Get(path) + if getter == nil { + return ufmt.Errorf("token(%s) is not registered to grc20reg", path) + } + return nil +} + +// MustRegistered is a helper function to check if token is registered to grc20reg +// if token is not registered, it will panic +func MustRegistered(path string) { + if err := IsRegistered(path); err != nil { + panic(newErrorWithDetail( + errNotRegistered, + ufmt.Sprintf("token(%s)", path), + )) + } +} + +// ListRegisteredTokens returns the list of registered tokens +// NOTE: +// - Unfortunate, grc20reg doesn't support this. +// - We need to parse the rendered grc20reg page to get the list of registered tokens. +func ListRegisteredTokens() []string { + render := grc20reg.Render("") + return extractTokenPathsFromRender(render) +} + +func extractTokenPathsFromRender(render string) []string { + matches := re.FindAllString(render, -1) + + tokenPaths := make([]string, 0, len(matches)) + for _, match := range matches { + tokenPath := strings.Trim(match, "[]") // Remove the brackets + tokenPaths = append(tokenPaths, tokenPath) + } + return tokenPaths +} + +// TotalSupply returns the total supply of the token +func TotalSupply(path string) uint64 { + return GetToken(path).TotalSupply() +} + +// BalanceOf returns the balance of the token for the given address +func BalanceOf(path string, addr std.Address) uint64 { + return GetToken(path).BalanceOf(addr) +} + +// Allowance returns the allowance of the token for the given owner and spender +func Allowance(path string, owner, spender std.Address) uint64 { + return GetToken(path).Allowance(owner, spender) +} diff --git a/contract/r/gnoswap/common/grc20reg_helper_test.gno b/contract/r/gnoswap/common/grc20reg_helper_test.gno new file mode 100644 index 000000000..cbb0707f2 --- /dev/null +++ b/contract/r/gnoswap/common/grc20reg_helper_test.gno @@ -0,0 +1,261 @@ +package common + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/foo20" +) + +var ( + tokenPath = "gno.land/r/demo/foo20" +) + +func TestGetToken(t *testing.T) { + t.Run("get regsitered token", func(t *testing.T) { + token := GetToken(tokenPath) + if token == nil { + t.Error("Expected non-nil token for foo20") + } + }) + + t.Run("get non registered token", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for non-registered token") + } + }() + GetToken("not_registered_token") + }) +} + +func TestTokenMethod(t *testing.T) { + token := GetToken(tokenPath) + + t.Run("GetName()", func(t *testing.T) { + uassert.Equal(t, "Foo", token.GetName()) + }) + + t.Run("GetSymbol()", func(t *testing.T) { + uassert.Equal(t, "FOO", token.GetSymbol()) + }) + + t.Run("GetDecimals()", func(t *testing.T) { + uassert.Equal(t, uint(4), token.GetDecimals()) + }) + + t.Run("TotalSupply()", func(t *testing.T) { + uassert.Equal(t, uint64(10000000000), token.TotalSupply()) + }) + + t.Run("KnownAccounts()", func(t *testing.T) { + uassert.Equal(t, int(1), token.KnownAccounts()) + }) + + t.Run("BalanceOf()", func(t *testing.T) { + uassert.Equal(t, uint64(10000000000), token.BalanceOf(std.Address("g1manfred47kzduec920z88wfr64ylksmdcedlf5"))) + }) + + t.Run("Allowance()", func(t *testing.T) { + uassert.Equal(t, uint64(0), token.Allowance(std.Address("g1manfred47kzduec920z88wfr64ylksmdcedlf5"), std.Address("g1pf6dv9fjk3rn0m4jjcne306ga4he3mzmupfjl6"))) + }) + + t.Run("RenderHome()", func(t *testing.T) { + expected := "" + expected += ufmt.Sprintf("# %s ($%s)\n\n", "Foo", "FOO") + expected += ufmt.Sprintf("* **Decimals**: %d\n", 4) + expected += ufmt.Sprintf("* **Total supply**: %d\n", 10000000000) + expected += ufmt.Sprintf("* **Known accounts**: %d\n", 1) + uassert.Equal(t, expected, token.RenderHome()) + }) +} + +func TestGetTokenTeller(t *testing.T) { + t.Run("get registered token teller", func(t *testing.T) { + teller := GetTokenTeller(tokenPath) + if teller == nil { + t.Error("Expected non-nil teller for foo20") + } + }) + + t.Run("get non registered token teller", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for non-registered token teller") + } + }() + GetTokenTeller("not_registered_teller") + }) +} + +func TestTellerMethod(t *testing.T) { + teller := GetTokenTeller(tokenPath) + token := GetToken(tokenPath) + defaultHolder := std.Address("g1manfred47kzduec920z88wfr64ylksmdcedlf5") + addr01 := testutils.TestAddress("addr01") + + t.Run("Transfer()", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(defaultHolder)) + + uassert.Equal(t, uint64(10000000000), token.BalanceOf(defaultHolder)) + + uassert.NoError(t, teller.Transfer(addr01, uint64(10000000000))) // transfer all balance to addr01 + + uassert.Equal(t, uint64(0), token.BalanceOf(defaultHolder)) + uassert.Equal(t, uint64(10000000000), token.BalanceOf(addr01)) + + uassert.Error(t, teller.Transfer(addr01, uint64(10000000000))) // not enough balance + }) + + t.Run("Approve()", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(addr01)) + uassert.NoError(t, teller.Approve(defaultHolder, uint64(500))) + uassert.Equal(t, uint64(500), token.Allowance(addr01, defaultHolder)) + }) + + t.Run("TransferFrom()", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(defaultHolder)) + uassert.NoError(t, teller.TransferFrom(addr01, defaultHolder, uint64(500))) + uassert.Equal(t, uint64(9999999500), token.BalanceOf(addr01)) + uassert.Equal(t, uint64(500), token.BalanceOf(defaultHolder)) + uassert.Equal(t, uint64(0), token.Allowance(addr01, defaultHolder)) + + uassert.Error(t, teller.TransferFrom(addr01, defaultHolder, uint64(500))) // not enough allowance + }) +} + +func TestIsRegistered(t *testing.T) { + t.Run("check if token is registered", func(t *testing.T) { + uassert.NoError(t, IsRegistered(tokenPath)) + }) + + t.Run("check if token is not registered", func(t *testing.T) { + uassert.Error(t, IsRegistered("not_registered_token")) + }) +} + +func TestMustRegistered(t *testing.T) { + t.Run("must be registered", func(t *testing.T) { + MustRegistered(tokenPath) + }) + + t.Run("panic for non registered token", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for non-registered token") + } + }() + MustRegistered("not_registered") + }) +} + +func TestListRegisteredTokens(t *testing.T) { + t.Skip("skipping tests -> can not mock grc20reg.Render() || testing extractTokenPathsFromRender() does cover this") +} + +func TestExtractTokenPathsFromRender(t *testing.T) { + var ( + wugnotPath = "gno.land/r/demo/wugnot" + gnsPath = "gno.land/r/gnoswap/v1/gns" + fooPath = "gno.land/r/onbloc/foo" + quxPath = "gno.land/r/onbloc/qux" + ) + + // NOTE: following strings are return from grc20reg.Render() + renderList := []string{ + // no registered token + `No registered token.`, + + // 1 token + `- **wrapped GNOT** - [gno.land/r/demo/wugnot](/r/demo/wugnot) - [info](/r/demo/grc20reg:gno.land/r/demo/wugnot)`, + + // 2 tokens + `- **wrapped GNOT** - [gno.land/r/demo/wugnot](/r/demo/wugnot) - [info](/r/demo/grc20reg:gno.land/r/demo/wugnot) +- **Gnoswap** - [gno.land/r/gnoswap/v1/gns](/r/gnoswap/v1/gns) - [info](/r/demo/grc20reg:gno.land/r/gnoswap/v1/gns) +`, + + // 4 tokens + `- **wrapped GNOT** - [gno.land/r/demo/wugnot](/r/demo/wugnot) - [info](/r/demo/grc20reg:gno.land/r/demo/wugnot) +- **Gnoswap** - [gno.land/r/gnoswap/v1/gns](/r/gnoswap/v1/gns) - [info](/r/demo/grc20reg:gno.land/r/gnoswap/v1/gns) +- **Baz** - [gno.land/r/onbloc/foo](/r/onbloc/foo) - [info](/r/demo/grc20reg:gno.land/r/onbloc/foo) +- **Qux** - [gno.land/r/onbloc/qux](/r/onbloc/qux) - [info](/r/demo/grc20reg:gno.land/r/onbloc/qux) +`, + } + + tests := []struct { + name string + render string + expected []string + }{ + { + name: "no registered token", + render: renderList[0], + expected: []string{}, + }, + { + name: "1 registered token", + render: renderList[1], + expected: []string{wugnotPath}, + }, + { + name: "2 registered tokens", + render: renderList[2], + expected: []string{wugnotPath, gnsPath}, + }, + { + name: "4 registered tokens", + render: renderList[3], + expected: []string{wugnotPath, gnsPath, fooPath, quxPath}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + extracted := extractTokenPathsFromRender(tt.render) + uassert.True(t, areSlicesEqual(t, tt.expected, extracted)) + }) + } +} + +func TestTotalSupply(t *testing.T) { + // result from grc2reg and (direct import/call) should be the same + uassert.Equal(t, foo20.TotalSupply(), TotalSupply(tokenPath)) +} + +func TestBalanceOf(t *testing.T) { + defaultHolder := std.Address("g1manfred47kzduec920z88wfr64ylksmdcedlf5") + + // result from grc2reg and (direct import/call) should be the same + uassert.Equal(t, foo20.BalanceOf(AddrToUser(defaultHolder)), BalanceOf(tokenPath, defaultHolder)) +} + +func TestAllowance(t *testing.T) { + owner := std.Address("g1manfred47kzduec920z88wfr64ylksmdcedlf5") + spender := std.Address("g1pf6dv9fjk3rn0m4jjcne306ga4he3mzmupfjl6") + + // result from grc2reg and (direct import/call) should be the same + uassert.Equal(t, foo20.Allowance(AddrToUser(owner), AddrToUser(spender)), Allowance(tokenPath, owner, spender)) +} + +// areSlicesEqual compares two slices of strings +func areSlicesEqual(t *testing.T, a, b []string) bool { + t.Helper() + + // Check if lengths are different + if len(a) != len(b) { + return false + } + + // Compare each element + for i := range a { + if a[i] != b[i] { + return false + } + } + + return true +} diff --git a/contract/r/gnoswap/common/grc721_token_id.gno b/contract/r/gnoswap/common/grc721_token_id.gno new file mode 100644 index 000000000..033fb4991 --- /dev/null +++ b/contract/r/gnoswap/common/grc721_token_id.gno @@ -0,0 +1,41 @@ +package common + +import ( + "strconv" + + "gno.land/p/demo/ufmt" + + "gno.land/p/demo/grc/grc721" +) + +// TokenIdFrom converts tokenId to grc721.TokenID type +// NOTE: input parameter tokenId can be string, int, uint64, or grc721.TokenID +// if tokenId is nil or not supported, it will panic +// if input type is not supported, it will panic +// input: tokenId interface{} +// output: grc721.TokenID +func TokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic(newErrorWithDetail( + errInvalidTokenId, + "can not be nil", + )) + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + estimatedType := ufmt.Sprintf("%T", tokenId) + panic(newErrorWithDetail( + errInvalidTokenId, + ufmt.Sprintf("unsupported tokenId type: %s", estimatedType), + )) + } +} diff --git a/contract/r/gnoswap/common/grc721_token_id_test.gno b/contract/r/gnoswap/common/grc721_token_id_test.gno new file mode 100644 index 000000000..6f15e955e --- /dev/null +++ b/contract/r/gnoswap/common/grc721_token_id_test.gno @@ -0,0 +1,64 @@ +package common + +import ( + "testing" + + "gno.land/p/demo/grc/grc721" +) + +func TestTokenIdFrom(t *testing.T) { + tests := []struct { + name string + input interface{} + want grc721.TokenID + wantPanic bool + }{ + { + name: "string input", + input: "123", + want: grc721.TokenID("123"), + }, + { + name: "int input", + input: 123, + want: grc721.TokenID("123"), + }, + { + name: "uint64 input", + input: uint64(123), + want: grc721.TokenID("123"), + }, + { + name: "grc721.TokenID input", + input: grc721.TokenID("123"), + want: grc721.TokenID("123"), + }, + { + name: "nil input", + input: nil, + wantPanic: true, + }, + { + name: "unsupported type (byte)", + input: []byte("123"), + wantPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.wantPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("TokenIdFrom() should have panicked") + } + }() + } + + got := TokenIdFrom(tt.input) + if got != tt.want { + t.Errorf("TokenIdFrom() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/contract/r/gnoswap/common/halt.gno b/contract/r/gnoswap/common/halt.gno new file mode 100644 index 000000000..a90c05cc4 --- /dev/null +++ b/contract/r/gnoswap/common/halt.gno @@ -0,0 +1,100 @@ +package common + +import ( + "std" + "strconv" + + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" +) + +// halted is a global flag that indicates whether the GnoSwap is currently halted. +// When true, most operations are disabled to prevent further actions. +// Default value is false, meaning the GnoSwap is active by default. +var halted bool = false + +// GetHalt returns the current halted status of the GnoSwap. +// +// Returns: +// - bool: true if the GnoSwap is halted, false otherwise. +func GetHalt() bool { + return halted +} + +// IsHalted checks if the GnoSwap is currently halted. +// If the GnoSwap is halted, the function panics with an errHalted error. +// +// Panics: +// - If the halted flag is true, indicating that the GnoSwap is inactive. +func IsHalted() { + if halted { + panic(newErrorWithDetail( + errHalted, + "GnoSwap is halted", + )) + } +} + +// SetHaltByAdmin allows an admin to set the halt status of the GnoSwap. +// Only an admin can execute this function. If a non-admin attempts to call this function, +// the function panics with an errNoPermission error. +// +// Parameters: +// - halt (bool): The new halt status to set (true to halt, false to unhalt). +// +// Panics: +// - If the caller is not an admin, the function will panic with an errNoPermission error. +func SetHaltByAdmin(halt bool) { + caller := getPrevAddr() + if err := AdminOnly(caller); err != nil { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf( + "only admin(%s) can set halt, called from %s", + consts.ADMIN, + caller, + ), + )) + } + setHalt(halt) +} + +// SetHalt allows the governance contract to set the halt status of the GnoSwap. +// Only the governance contract can execute this function through a proposal process. +// +// Parameters: +// - halt (bool): The new halt status to set (true to halt, false to unhalt). +// +// Panics: +// - If the caller is not the governance contract, the function will panic with an errNoPermission error. +func SetHalt(halt bool) { + caller := getPrevAddr() + if err := GovernanceOnly(caller); err != nil { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf( + "only governance(%s) can set halt, called from %s", + consts.GOV_GOVERNANCE_ADDR, + caller, + ), + )) + } + setHalt(halt) +} + +// setHalt updates the halted flag to the specified value. +// This is an internal function that should only be called by SetHalt or SetHaltByAdmin. +// +// Parameters: +// - halt (bool): The new halt status to set. +func setHalt(halt bool) { + halted = halt + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "setHalt", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "halt", strconv.FormatBool(halt), + ) +} diff --git a/contract/r/gnoswap/common/halt_test.gno b/contract/r/gnoswap/common/halt_test.gno new file mode 100644 index 000000000..5e2ce8259 --- /dev/null +++ b/contract/r/gnoswap/common/halt_test.gno @@ -0,0 +1,81 @@ +package common + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" +) + +var ( + adminRealm = std.NewUserRealm(consts.ADMIN) + govRealm = std.NewCodeRealm(consts.GOV_GOVERNANCE_PATH) +) + +func TestHalts(t *testing.T) { + t.Run("GetHalt() initial value", func(t *testing.T) { + uassert.False(t, GetHalt()) + }) + + t.Run("IsHalted() success", func(t *testing.T) { + uassert.NotPanics(t, IsHalted) + }) +} + +func TestSetHaltByAdmin(t *testing.T) { + t.Run("with non-admin privilege, panics", func(t *testing.T) { + uassert.PanicsWithMessage( + t, + `[GNOSWAP-COMMON-001] caller has no permission || only admin(g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d) can set halt, called from g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm`, + func() { SetHaltByAdmin(false) }, + ) + }) + + t.Run("with governance privilege, panics", func(t *testing.T) { + std.TestSetRealm(govRealm) + uassert.PanicsWithMessage( + t, + `[GNOSWAP-COMMON-001] caller has no permission || only admin(g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d) can set halt, called from g17s8w2ve7k85fwfnrk59lmlhthkjdted8whvqxd`, + func() { SetHaltByAdmin(false) }, + ) + }) + + t.Run("with admin privilege, success", func(t *testing.T) { + std.TestSetRealm(adminRealm) + + uassert.False(t, GetHalt()) + + SetHaltByAdmin(true) + uassert.True(t, GetHalt()) + }) +} + +func TestSetHalt(t *testing.T) { + t.Run("with non-governance privilege, panics", func(t *testing.T) { + uassert.PanicsWithMessage( + t, + `[GNOSWAP-COMMON-001] caller has no permission || only governance(g17s8w2ve7k85fwfnrk59lmlhthkjdted8whvqxd) can set halt, called from g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm`, + func() { SetHalt(false) }, + ) + }) + + t.Run("with admin privilege, panics", func(t *testing.T) { + std.TestSetRealm(adminRealm) + uassert.PanicsWithMessage( + t, + `[GNOSWAP-COMMON-001] caller has no permission || only governance(g17s8w2ve7k85fwfnrk59lmlhthkjdted8whvqxd) can set halt, called from g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d`, + func() { SetHalt(false) }, + ) + }) + + t.Run("with governance privilege, success", func(t *testing.T) { + std.TestSetRealm(govRealm) + + uassert.True(t, GetHalt()) + + SetHalt(false) + uassert.False(t, GetHalt()) + }) +} diff --git a/contract/r/gnoswap/common/limit_caller.gno b/contract/r/gnoswap/common/limit_caller.gno new file mode 100644 index 000000000..66c3588a7 --- /dev/null +++ b/contract/r/gnoswap/common/limit_caller.gno @@ -0,0 +1,53 @@ +package common + +import ( + "std" + "strconv" + + "gno.land/p/demo/ufmt" +) + +// limitCaller is a global boolean flag that controls whether function calls are restricted. +// Default value is true, meaning call restrictions are enabled by default. +var limitCaller bool = true + +// GetLimitCaller returns the current state of the limitCaller flag. +// If true, call restrictions are active; if false, call restrictions are disabled. +// +// Returns: +// - bool: Current state of the limitCaller (true if active, false if inactive). +func GetLimitCaller() bool { + return limitCaller +} + +// SetLimitCaller updates the limitCaller flag to either enable or disable call restrictions. +// This function can only be called by an admin. If a non-admin attempts to call this function, +// the function will panic. +// +// Parameters: +// - v (bool): The new state for the limitCaller flag (true to enable, false to disable). +// +// Panics: +// - If the caller is not an admin, the function panics with an errNoPermission error. +func SetLimitCaller(v bool) { + caller := getPrevAddr() + if err := AdminOnly(caller); err != nil { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf( + "only Admin can set halt, called from %s", + caller, + )), + ) + } + + limitCaller = v + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "SetLimitCaller", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "limitCaller", strconv.FormatBool(v), + ) +} diff --git a/contract/r/gnoswap/common/limit_caller_test.gno b/contract/r/gnoswap/common/limit_caller_test.gno new file mode 100644 index 000000000..b1c92ef2e --- /dev/null +++ b/contract/r/gnoswap/common/limit_caller_test.gno @@ -0,0 +1,28 @@ +package common + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" +) + +func TestSetLimitCaller(t *testing.T) { + t.Run("initial check", func(t *testing.T) { + uassert.True(t, GetLimitCaller()) + }) + + t.Run("with non-admin privilege, panics", func(t *testing.T) { + uassert.PanicsWithMessage(t, + `[GNOSWAP-COMMON-001] caller has no permission || only Admin can set halt, called from g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm`, + func() { SetLimitCaller(false) }, + ) + }) + + t.Run("with admin privilege, success", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + SetLimitCaller(false) + uassert.False(t, GetLimitCaller()) + }) +} diff --git a/contract/r/gnoswap/common/liquidity_amounts.gno b/contract/r/gnoswap/common/liquidity_amounts.gno new file mode 100644 index 000000000..76bc67cb9 --- /dev/null +++ b/contract/r/gnoswap/common/liquidity_amounts.gno @@ -0,0 +1,314 @@ +package common + +import ( + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + u256 "gno.land/p/gnoswap/uint256" +) + +// toAscendingOrder checks if the first value is greater than +// the second then swaps two values. +func toAscendingOrder(a, b *u256.Uint) (*u256.Uint, *u256.Uint) { + if a.Gt(b) { + return b, a + } + + return a, b +} + +// toUint128 ensures a *u256.Uint value fits within the uint128 range. +// +// This function validates that the given `value` is properly initialized (not nil) and checks whether +// it exceeds the maximum value of uint128. If the value exceeds the uint128 range, +// it applies a masking operation to truncate the value to fit within the uint128 limit. +// +// Parameters: +// - value: *u256.Uint, the value to be checked and possibly truncated. +// +// Returns: +// - *u256.Uint: A value guaranteed to fit within the uint128 range. +// +// Notes: +// - The function first checks if the value is not nil to avoid potential runtime errors. +// - The mask ensures that only the lower 128 bits of the value are retained. +// - If the input value is already within the uint128 range, it remains unchanged. +// - MAX_UINT128 is a constant representing `2^128 - 1`. +func toUint128(value *u256.Uint) *u256.Uint { + assertOnlyNotNil(value) + if value.Gt(u256.MustFromDecimal(consts.MAX_UINT128)) { + mask := new(u256.Uint).Lsh(u256.One(), consts.Q128_RESOLUTION) + mask = new(u256.Uint).Sub(mask, u256.One()) + value = value.And(value, mask) + } + return value +} + +// safeConvertToUint128 safely ensures a *u256.Uint value fits within the uint128 range. +// +// This function verifies that the provided unsigned 256-bit integer does not exceed the maximum value for uint128 (`2^128 - 1`). +// If the value is within the uint128 range, it is returned as is; otherwise, the function triggers a panic. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be checked. +// +// Returns: +// - *u256.Uint: The same value if it is within the uint128 range. +// +// Panics: +// - If the value exceeds the maximum uint128 value (`2^128 - 1`), the function will panic with a descriptive error +// indicating the overflow and the original value. +// +// Notes: +// - The constant `MAX_UINT128` is defined as `340282366920938463463374607431768211455` (the largest uint128 value). +// - No actual conversion occurs since the function works directly with *u256.Uint types. +// +// Example: +// validUint128 := safeConvertToUint128(u256.MustFromDecimal("340282366920938463463374607431768211455")) // Valid +// safeConvertToUint128(u256.MustFromDecimal("340282366920938463463374607431768211456")) // Panics due to overflow +func safeConvertToUint128(value *u256.Uint) *u256.Uint { + if value.Gt(u256.MustFromDecimal(consts.MAX_UINT128)) { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows uint128 range", + errOverFlow, value.ToString())) + } + return value +} + +// computeLiquidityForAmount0 calculates the liquidity for a given amount of token0. +// +// This function computes the maximum possible liquidity that can be provided for `token0` +// based on the provided price boundaries (sqrtRatioAX96 and sqrtRatioBX96) in Q64.96 format. +// +// Parameters: +// - sqrtRatioAX96: *u256.Uint - The square root price at the lower tick boundary (Q64.96). +// - sqrtRatioBX96: *u256.Uint - The square root price at the upper tick boundary (Q64.96). +// - amount0: *u256.Uint - The amount of token0 to be converted to liquidity. +// +// Returns: +// - *u256.Uint: The calculated liquidity, represented as an unsigned 128-bit integer (uint128). +// +// Panics: +// - If the resulting liquidity exceeds the uint128 range, `safeConvertToUint128` will trigger a panic. +func computeLiquidityForAmount0(sqrtRatioAX96, sqrtRatioBX96, amount0 *u256.Uint) *u256.Uint { + sqrtRatioAX96, sqrtRatioBX96 = toAscendingOrder(sqrtRatioAX96, sqrtRatioBX96) + intermediate := u256.MulDiv(sqrtRatioAX96, sqrtRatioBX96, u256.MustFromDecimal(consts.Q96)) + + diff := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) + if diff.IsZero() { + panic(newErrorWithDetail( + errIdenticalTicks, + ufmt.Sprintf("sqrtRatioAX96 (%s) and sqrtRatioBX96 (%s) are identical", sqrtRatioAX96.ToString(), sqrtRatioBX96.ToString()), + )) + } + res := u256.MulDiv(amount0, intermediate, diff) + return safeConvertToUint128(res) +} + +// computeLiquidityForAmount1 calculates liquidity based on the provided token1 amount and price range. +// +// This function computes the liquidity for a given amount of token1 by using the difference +// between the upper and lower square root price ratios. The calculation uses Q96 fixed-point +// arithmetic to maintain precision. +// +// Parameters: +// - sqrtRatioAX96: *u256.Uint - The square root ratio of price at the lower tick, represented in Q96 format. +// - sqrtRatioBX96: *u256.Uint - The square root ratio of price at the upper tick, represented in Q96 format. +// - amount1: *u256.Uint - The amount of token1 to calculate liquidity for. +// +// Returns: +// - *u256.Uint: The calculated liquidity based on the provided amount of token1 and price range. +// +// Notes: +// - The result is not directly limited to uint128, as liquidity values can exceed uint128 bounds. +// - If `sqrtRatioAX96 == sqrtRatioBX96`, the function will panic due to division by zero. +// - Q96 is a constant representing `2^96`, ensuring that precision is maintained during division. +// +// Panics: +// - If the resulting liquidity exceeds the uint128 range, `safeConvertToUint128` will trigger a panic. +func computeLiquidityForAmount1(sqrtRatioAX96, sqrtRatioBX96, amount1 *u256.Uint) *u256.Uint { + sqrtRatioAX96, sqrtRatioBX96 = toAscendingOrder(sqrtRatioAX96, sqrtRatioBX96) + + diff := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) + if diff.IsZero() { + panic(newErrorWithDetail( + errIdenticalTicks, + ufmt.Sprintf("sqrtRatioAX96 (%s) and sqrtRatioBX96 (%s) are identical", sqrtRatioAX96.ToString(), sqrtRatioBX96.ToString()), + )) + } + res := u256.MulDiv(amount1, u256.MustFromDecimal(consts.Q96), diff) + return safeConvertToUint128(res) +} + +// GetLiquidityForAmounts calculates the maximum liquidity given the current price (sqrtRatioX96), +// upper and lower price bounds (sqrtRatioAX96 and sqrtRatioBX96), and token amounts (amount0, amount1). +// +// This function evaluates how much liquidity can be obtained for specified amounts of token0 and token1 +// within the provided price range. It returns the lesser liquidity based on available token0 or token1 +// to ensure the pool remains balanced. +// +// Parameters: +// - sqrtRatioX96: The current price as a square root ratio in Q64.96 format (*u256.Uint). +// - sqrtRatioAX96: The lower bound of the price range as a square root ratio in Q64.96 format (*u256.Uint). +// - sqrtRatioBX96: The upper bound of the price range as a square root ratio in Q64.96 format (*u256.Uint). +// - amount0: The amount of token0 available to provide liquidity (*u256.Uint). +// - amount1: The amount of token1 available to provide liquidity (*u256.Uint). +// +// Returns: +// - *u256.Uint: The maximum possible liquidity that can be minted. +// +// Notes: +// - The `Clone` method is used to prevent modification of the original values during computation. +// - The function ensures that liquidity calculations handle edge cases when the current price +// is outside the specified range by returning liquidity based on the dominant token. +func GetLiquidityForAmounts(sqrtRatioX96, sqrtRatioAX96, sqrtRatioBX96, amount0, amount1 *u256.Uint) *u256.Uint { + if amount0.IsZero() || amount1.IsZero() { + return u256.Zero() + } + + sqrtRatioAX96, sqrtRatioBX96 = toAscendingOrder(sqrtRatioAX96.Clone(), sqrtRatioBX96.Clone()) + var liquidity *u256.Uint + + if sqrtRatioX96.Lte(sqrtRatioAX96) { + liquidity = computeLiquidityForAmount0(sqrtRatioAX96.Clone(), sqrtRatioBX96.Clone(), amount0.Clone()) + } else if sqrtRatioX96.Lt(sqrtRatioBX96) { + liquidity0 := computeLiquidityForAmount0(sqrtRatioX96.Clone(), sqrtRatioBX96.Clone(), amount0.Clone()) + liquidity1 := computeLiquidityForAmount1(sqrtRatioAX96.Clone(), sqrtRatioX96.Clone(), amount1.Clone()) + + if liquidity0.Lt(liquidity1) { + liquidity = liquidity0 + } else { + liquidity = liquidity1 + } + } else { + liquidity = computeLiquidityForAmount1(sqrtRatioAX96.Clone(), sqrtRatioBX96.Clone(), amount1.Clone()) + } + return liquidity +} + +// computeAmount0ForLiquidity calculates the required amount of token0 for a given liquidity level +// within a specified price range (represented by sqrt ratios). +// +// This function determines the amount of token0 needed to provide a specified amount of liquidity +// within a price range defined by sqrtRatioAX96 (lower bound) and sqrtRatioBX96 (upper bound). +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the price range as a square root ratio in Q64.96 format (*u256.Uint). +// - sqrtRatioBX96: The upper bound of the price range as a square root ratio in Q64.96 format (*u256.Uint). +// - liquidity: The liquidity to be provided (*u256.Uint). +// +// Returns: +// - *u256.Uint: The amount of token0 required to achieve the specified liquidity level. +// +// Notes: +// - This function assumes the price bounds are expressed in Q64.96 fixed-point format. +// - The function returns 0 if the liquidity is 0 or the price bounds are invalid. +// - Handles edge cases where sqrtRatioAX96 equals sqrtRatioBX96 by returning 0 (to prevent division by zero). +func computeAmount0ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint) *u256.Uint { + sqrtRatioAX96, sqrtRatioBX96 = toAscendingOrder(sqrtRatioAX96, sqrtRatioBX96) + if sqrtRatioAX96.IsZero() || sqrtRatioBX96.IsZero() || liquidity.IsZero() || sqrtRatioAX96.Eq(sqrtRatioBX96) { + return u256.Zero() + } + + val1 := new(u256.Uint).Lsh(liquidity, consts.Q96_RESOLUTION) + val2 := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) + + res := u256.MulDiv(val1, val2, sqrtRatioBX96) + res = new(u256.Uint).Div(res, sqrtRatioAX96) + + return res +} + +// computeAmount1ForLiquidity calculates the required amount of token1 for a given liquidity level +// within a specified price range (represented by sqrt ratios). +// +// This function determines the amount of token1 needed to provide liquidity between the +// lower (sqrtRatioAX96) and upper (sqrtRatioBX96) price bounds. The calculation is performed +// in Q64.96 fixed-point format, which is standard for many liquidity calculations. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the price range as a square root ratio in Q64.96 format (*u256.Uint). +// - sqrtRatioBX96: The upper bound of the price range as a square root ratio in Q64.96 format (*u256.Uint). +// - liquidity: The liquidity amount to be used in the calculation (*u256.Uint). +// +// Returns: +// - *u256.Uint: The amount of token1 required to achieve the specified liquidity level. +// +// Notes: +// - This function handles edge cases where the liquidity is zero or when sqrtRatioAX96 equals sqrtRatioBX96 +// to prevent division by zero. +// - The calculation assumes sqrtRatioAX96 is always less than or equal to sqrtRatioBX96 after the initial +// ascending order sorting. +func computeAmount1ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint) *u256.Uint { + sqrtRatioAX96, sqrtRatioBX96 = toAscendingOrder(sqrtRatioAX96, sqrtRatioBX96) + if liquidity.IsZero() || sqrtRatioAX96.Eq(sqrtRatioBX96) { + return u256.Zero() + } + + diff := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) + res := u256.MulDiv(liquidity, diff, u256.MustFromDecimal(consts.Q96)) + + return res +} + +// GetAmountsForLiquidity calculates the amounts of token0 and token1 required +// to provide a specified liquidity within a price range. +// +// This function determines the quantities of token0 and token1 necessary to achieve +// a given liquidity level, depending on the current price (sqrtRatioX96) and the +// bounds of the price range (sqrtRatioAX96 and sqrtRatioBX96). The function returns +// the calculated amounts of token0 and token1 as strings. +// +// If the current price is below the lower bound of the price range, only token0 is required. +// If the current price is above the upper bound, only token1 is required. When the +// price is within the range, both token0 and token1 are calculated. +// +// Parameters: +// - sqrtRatioX96: The current price represented as a square root ratio in Q64.96 format (*u256.Uint). +// - sqrtRatioAX96: The lower bound of the price range as a square root ratio in Q64.96 format (*u256.Uint). +// - sqrtRatioBX96: The upper bound of the price range as a square root ratio in Q64.96 format (*u256.Uint). +// - liquidity: The amount of liquidity to be provided (*u256.Uint). +// +// Returns: +// - string: The calculated amount of token0 required to achieve the specified liquidity. +// - string: The calculated amount of token1 required to achieve the specified liquidity. +// +// Notes: +// - If liquidity is zero, the function returns "0" for both token0 and token1. +// - The function guarantees that sqrtRatioAX96 is always the lower bound and +// sqrtRatioBX96 is the upper bound by calling toAscendingOrder(). +// - Edge cases where the current price is exactly on the bounds are handled without division by zero. +// +// Example: +// ``` +// amount0, amount1 := GetAmountsForLiquidity( +// +// u256.MustFromDecimal("79228162514264337593543950336"), // sqrtRatioX96 (1.0 in Q64.96) +// u256.MustFromDecimal("39614081257132168796771975168"), // sqrtRatioAX96 (0.5 in Q64.96) +// u256.MustFromDecimal("158456325028528675187087900672"), // sqrtRatioBX96 (2.0 in Q64.96) +// u256.MustFromDecimal("1000000"), // Liquidity +// +// ) +// fmt.Println("Token0:", amount0, "Token1:", amount1) +// // Example output: Token0: 500000, Token1: 250000 +// ``` +func GetAmountsForLiquidity(sqrtRatioX96, sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint) (string, string) { + if liquidity.IsZero() { + return "0", "0" + } + + sqrtRatioAX96, sqrtRatioBX96 = toAscendingOrder(sqrtRatioAX96, sqrtRatioBX96) + + amount0 := u256.Zero() + amount1 := u256.Zero() + + if sqrtRatioX96.Lte(sqrtRatioAX96) { + amount0 = computeAmount0ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity) + } else if sqrtRatioX96.Lt(sqrtRatioBX96) { + amount0 = computeAmount0ForLiquidity(sqrtRatioX96, sqrtRatioBX96, liquidity) + amount1 = computeAmount1ForLiquidity(sqrtRatioAX96, sqrtRatioX96, liquidity) + } else { + amount1 = computeAmount1ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity) + } + + return amount0.ToString(), amount1.ToString() +} diff --git a/contract/r/gnoswap/common/liquidity_amounts_test.gno b/contract/r/gnoswap/common/liquidity_amounts_test.gno new file mode 100644 index 000000000..83b73ccbd --- /dev/null +++ b/contract/r/gnoswap/common/liquidity_amounts_test.gno @@ -0,0 +1,506 @@ +package common + +import ( + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestToAscendingOrder(t *testing.T) { + tests := []struct { + name string + a *u256.Uint + b *u256.Uint + expectedA string + expectedB string + }{ + { + name: "Ascending order - a < b", + a: u256.MustFromDecimal("10"), + b: u256.MustFromDecimal("20"), + expectedA: "10", + expectedB: "20", + }, + { + name: "Descending order - a > b", + a: u256.MustFromDecimal("50"), + b: u256.MustFromDecimal("30"), + expectedA: "30", + expectedB: "50", + }, + { + name: "Equal values - a == b", + a: u256.MustFromDecimal("100"), + b: u256.MustFromDecimal("100"), + expectedA: "100", + expectedB: "100", + }, + { + name: "Large numbers", + a: u256.MustFromDecimal("340282366920938463463374607431768211455"), // 2^128 - 1 + b: u256.MustFromDecimal("170141183460469231731687303715884105727"), // 2^127 - 1 + expectedA: "170141183460469231731687303715884105727", + expectedB: "340282366920938463463374607431768211455", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + min, max := toAscendingOrder(tt.a, tt.b) + + if min.ToString() != tt.expectedA { + t.Errorf("Expected min to be %s, got %s", tt.expectedA, min.ToString()) + } + + if max.ToString() != tt.expectedB { + t.Errorf("Expected max to be %s, got %s", tt.expectedB, max.ToString()) + } + }) + } +} + +func TestSafeConvertToUint128(t *testing.T) { + tests := []struct { + name string + input *u256.Uint + expected *u256.Uint + shouldPanic bool + }{ + { + name: "Valid uint128 value", + input: u256.MustFromDecimal("340282366920938463463374607431768211455"), // MAX_UINT128 + expected: u256.MustFromDecimal("340282366920938463463374607431768211455"), + shouldPanic: false, + }, + { + name: "Value exceeding uint128 range", + input: u256.MustFromDecimal("340282366920938463463374607431768211456"), // MAX_UINT128 + 1 + shouldPanic: true, + }, + { + name: "Zero value", + input: u256.Zero(), + expected: u256.Zero(), + shouldPanic: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic but got none") + } + }() + safeConvertToUint128(tt.input) + } else { + got := safeConvertToUint128(tt.input) + if !got.Eq(tt.expected) { + t.Errorf("Expected %s, got %s", tt.expected.ToString(), got.ToString()) + } + } + }) + } +} + +func TestComputeLiquidityForAmount0(t *testing.T) { + testCases := []struct { + name string + sqrtRatioA string + sqrtRatioB string + amount0 string + expected string + expectPanic bool + }{ + { + name: "Basic liquidity calculation", + sqrtRatioA: "79228162514264337593543950336", // sqrt(1) << 96 + sqrtRatioB: "158456325028528675187087900672", // sqrt(4) << 96 + amount0: "1000000", + expected: "2000000", // Expected liquidity + }, + { + name: "No liquidity (zero amount)", + sqrtRatioA: "79228162514264337593543950336", + sqrtRatioB: "158456325028528675187087900672", + amount0: "0", + expected: "0", + }, + { + name: "Liquidity overflow (exceeds uint128)", + sqrtRatioA: "158456325028528675187087900672", + sqrtRatioB: "316912650057057350374175801344", + amount0: "340282366920938463463374607431768211456", // Exceeds uint128 + expectPanic: true, + }, + { + name: "Zero liquidity with equal ratios", + sqrtRatioA: "79228162514264337593543950336", + sqrtRatioB: "79228162514264337593543950336", + amount0: "1000000", + expected: "0", + expectPanic: true, + }, + { + name: "Panic with identical ticks", + sqrtRatioA: "79228162514264337593543950336", + sqrtRatioB: "79228162514264337593543950336", + amount0: "1000000", + expectPanic: true, + }, + { + name: "Large liquidity calculation", + sqrtRatioA: "79228162514264337593543950336", + sqrtRatioB: "158456325028528675187087900672", + amount0: "1000000000", + expected: "2000000000", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tc.expectPanic { + t.Errorf("Unexpected panic for test case: %s", tc.name) + } + } + }() + + sqrtRatioA := u256.MustFromDecimal(tc.sqrtRatioA) + sqrtRatioB := u256.MustFromDecimal(tc.sqrtRatioB) + amount0 := u256.MustFromDecimal(tc.amount0) + + result := computeLiquidityForAmount0(sqrtRatioA, sqrtRatioB, amount0) + if !tc.expectPanic { + if result.ToString() != tc.expected { + t.Errorf("Expected %s but got %s", tc.expected, result.ToString()) + } + } + }) + } +} + +func TestComputeLiquidityForAmount1(t *testing.T) { + q96 := u256.MustFromDecimal(consts.Q96) + amount1 := u256.MustFromDecimal("1000000") + + t.Run("Basic liquidity calculation", func(t *testing.T) { + sqrtRatioAX96 := q96 // 2^96 (1 in Q96) + sqrtRatioBX96 := new(u256.Uint).Mul(q96, u256.MustFromDecimal("4")) // 4^96 (4 in Q96) + + expected := u256.MustFromDecimal("333333") // Expected liquidity + result := computeLiquidityForAmount1(sqrtRatioAX96, sqrtRatioBX96, amount1) + + if !result.Eq(expected) { + t.Errorf("Expected %s but got %s", expected.ToString(), result.ToString()) + } + }) + + t.Run("Zero liquidity with equal ratios", func(t *testing.T) { + sqrtRatioAX96 := q96 // 2^96 (1 in Q96) + sqrtRatioBX96 := q96 // Same as lower tick + + uassert.PanicsWithMessage(t, + "[GNOSWAP-COMMON-010] identical ticks || sqrtRatioAX96 (79228162514264337593543950336) and sqrtRatioBX96 (79228162514264337593543950336) are identical", + func() { + _ = computeLiquidityForAmount1(sqrtRatioAX96, sqrtRatioBX96, amount1) + }) + }) + + t.Run("Large liquidity calculation", func(t *testing.T) { + sqrtRatioAX96 := q96 // 1x + sqrtRatioBX96 := new(u256.Uint).Mul(q96, u256.NewUint(16)) // 16x + largeAmount := u256.MustFromDecimal("1000000000") + + expected := u256.MustFromDecimal("66666666") // 1B / 16 = 62.5M + result := computeLiquidityForAmount1(sqrtRatioAX96, sqrtRatioBX96, largeAmount) + + if !result.Eq(expected) { + t.Errorf("Expected %s but got %s", expected.ToString(), result.ToString()) + } + }) +} + +func TestGetLiquidityForAmounts(t *testing.T) { + q96 := u256.MustFromDecimal(consts.Q96) + + tests := []struct { + name string + sqrtRatioX96 string + sqrtRatioAX96 string + sqrtRatioBX96 string + amount0 string + amount1 string + expected string + }{ + { + name: "Basic Liquidity Calculation - Token0 Dominant", + sqrtRatioX96: q96.ToString(), + sqrtRatioAX96: (new(u256.Uint).Mul(q96, u256.NewUint(4))).ToString(), + sqrtRatioBX96: (new(u256.Uint).Mul(q96, u256.NewUint(16))).ToString(), + amount0: "1000000", + amount1: "1000000", + expected: "5333333", + }, + { + name: "Within Range - Both Token0 and Token1", + sqrtRatioX96: (new(u256.Uint).Mul(q96, u256.NewUint(2))).ToString(), + sqrtRatioAX96: (new(u256.Uint).Mul(q96, u256.NewUint(4))).ToString(), + sqrtRatioBX96: (new(u256.Uint).Mul(q96, u256.NewUint(16))).ToString(), + amount0: "2000000", + amount1: "3000000", + expected: "10666666", + }, + { + name: "Token1 Dominant - Price Above Upper Bound", + sqrtRatioX96: (new(u256.Uint).Mul(q96, u256.NewUint(20))).ToString(), + sqrtRatioAX96: (new(u256.Uint).Mul(q96, u256.NewUint(4))).ToString(), + sqrtRatioBX96: (new(u256.Uint).Mul(q96, u256.NewUint(16))).ToString(), + amount0: "1500000", + amount1: "3000000", + expected: "250000", + }, + { + name: "Edge Case - sqrtRatioX96 = Lower Bound", + sqrtRatioX96: (new(u256.Uint).Mul(q96, u256.NewUint(4))).ToString(), + sqrtRatioAX96: (new(u256.Uint).Mul(q96, u256.NewUint(4))).ToString(), + sqrtRatioBX96: (new(u256.Uint).Mul(q96, u256.NewUint(16))).ToString(), + amount0: "500000", + amount1: "500000", + expected: "2666666", + }, + { + name: "Edge Case - sqrtRatioX96 = Upper Bound", + sqrtRatioX96: (new(u256.Uint).Mul(q96, u256.NewUint(16))).ToString(), + sqrtRatioAX96: (new(u256.Uint).Mul(q96, u256.NewUint(4))).ToString(), + sqrtRatioBX96: (new(u256.Uint).Mul(q96, u256.NewUint(16))).ToString(), + amount0: "1000000", + amount1: "1000000", + expected: "83333", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sqrtRatioX96 := u256.MustFromDecimal(tc.sqrtRatioX96) + sqrtRatioAX96 := u256.MustFromDecimal(tc.sqrtRatioAX96) + sqrtRatioBX96 := u256.MustFromDecimal(tc.sqrtRatioBX96) + amount0 := u256.MustFromDecimal(tc.amount0) + amount1 := u256.MustFromDecimal(tc.amount1) + + result := GetLiquidityForAmounts(sqrtRatioX96, sqrtRatioAX96, sqrtRatioBX96, amount0, amount1) + expected := u256.MustFromDecimal(tc.expected) + + uassert.Equal(t, expected.ToString(), result.ToString()) + }) + } +} + +func TestComputeAmount0ForLiquidity(t *testing.T) { + q96 := u256.MustFromDecimal("79228162514264337593543950336") // 2^96 + + tests := []struct { + name string + sqrtRatioAX96 string + sqrtRatioBX96 string + liquidity string + expected string + }{ + { + name: "Basic Case - Small Range", + sqrtRatioAX96: new(u256.Uint).Mul(q96, u256.NewUint(4)).ToString(), + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(8)).ToString(), + liquidity: "1000000", + expected: "125000", + }, + { + name: "Large Liquidity - Wide Range", + sqrtRatioAX96: new(u256.Uint).Mul(q96, u256.NewUint(2)).ToString(), + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(16)).ToString(), + liquidity: "5000000000", + expected: "2187500000", + }, + { + name: "Edge Case - Equal Bounds", + sqrtRatioAX96: new(u256.Uint).Mul(q96, u256.NewUint(8)).ToString(), + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(8)).ToString(), + liquidity: "1000000", + expected: "0", + }, + { + name: "Minimum Liquidity", + sqrtRatioAX96: new(u256.Uint).Mul(q96, u256.NewUint(5)).ToString(), + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(10)).ToString(), + liquidity: "1", + expected: "0", + }, + { + name: "Max Liquidity", + sqrtRatioAX96: new(u256.Uint).Mul(q96, u256.NewUint(1)).ToString(), + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(32)).ToString(), + liquidity: "1000000000000000000", + expected: "968750000000000000", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sqrtRatioAX96 := u256.MustFromDecimal(tc.sqrtRatioAX96) + sqrtRatioBX96 := u256.MustFromDecimal(tc.sqrtRatioBX96) + liquidity := u256.MustFromDecimal(tc.liquidity) + + result := computeAmount0ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity) + expected := u256.MustFromDecimal(tc.expected) + + if result.ToString() != expected.ToString() { + t.Errorf("expected %s but got %s", expected.ToString(), result.ToString()) + } + }) + } +} + +func TestComputeAmount1ForLiquidity(t *testing.T) { + q96 := u256.MustFromDecimal(consts.Q96) // 2^96 = 79228162514264337593543950336 + + tests := []struct { + name string + sqrtRatioAX96 *u256.Uint + sqrtRatioBX96 *u256.Uint + liquidity *u256.Uint + expectedAmount string + }{ + { + name: "Basic Case - Small Liquidity", + sqrtRatioAX96: q96, + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(4)), // sqrtRatioBX96 = 4 * Q96 + liquidity: u256.NewUint(1000000000), + expectedAmount: "3000000000", // (4-1)*liquidity = 3 * 10^9 + }, + { + name: "Edge Case - Equal Ratios", + sqrtRatioAX96: q96, + sqrtRatioBX96: q96, + liquidity: u256.NewUint(1000000), + expectedAmount: "0", + }, + { + name: "Zero Liquidity", + sqrtRatioAX96: q96, + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(2)), + liquidity: u256.Zero(), + expectedAmount: "0", + }, + { + name: "Large Liquidity", + sqrtRatioAX96: q96, + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(16)), // sqrtRatioBX96 = 16 * Q96 + liquidity: u256.NewUint(1000000000000000000), // 1e18 liquidity + expectedAmount: "15000000000000000000", // (16-1) * 1e18 = 15 * 1e18 + }, + { + name: "Descending Ratios (Order Correction)", + sqrtRatioAX96: new(u256.Uint).Mul(q96, u256.NewUint(8)), + sqrtRatioBX96: q96, + liquidity: u256.NewUint(500000), + expectedAmount: "3500000", // (8-1)*500000 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := computeAmount1ForLiquidity(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity) + uassert.Equal(t, tt.expectedAmount, result.ToString(), ufmt.Sprintf("expected %s but got %s", tt.expectedAmount, result.ToString())) + }) + } +} + +func TestGetAmountsForLiquidity(t *testing.T) { + q96 := u256.MustFromDecimal(consts.Q96) // 2^96 = 79228162514264337593543950336 + + tests := []struct { + name string + sqrtRatioX96 *u256.Uint + sqrtRatioAX96 *u256.Uint + sqrtRatioBX96 *u256.Uint + liquidity *u256.Uint + expectedAmount0 string + expectedAmount1 string + }{ + { + name: "Basic Case - Within Range", + sqrtRatioX96: new(u256.Uint).Mul(q96, u256.NewUint(2)), // Current price at 2 * Q96 + sqrtRatioAX96: q96, // Lower bound at 1 * Q96 + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(4)), // Upper bound at 4 * Q96 + liquidity: u256.NewUint(1000000), + expectedAmount0: "250000", + expectedAmount1: "1000000", + }, + { + name: "Edge Case - At Lower Bound (sqrtRatioX96 == sqrtRatioAX96)", + sqrtRatioX96: q96, + sqrtRatioAX96: q96, + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(8)), + liquidity: u256.NewUint(1000000), + expectedAmount0: "875000", + expectedAmount1: "0", + }, + { + name: "Edge Case - At Upper Bound (sqrtRatioX96 == sqrtRatioBX96)", + sqrtRatioX96: new(u256.Uint).Mul(q96, u256.NewUint(8)), + sqrtRatioAX96: q96, + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(8)), + liquidity: u256.NewUint(1000000), + expectedAmount0: "0", + expectedAmount1: "7000000", + }, + { + name: "Out of Range - Below Lower Bound", + sqrtRatioX96: new(u256.Uint).Div(q96, u256.NewUint(2)), // Current price at 0.5 * Q96 + sqrtRatioAX96: q96, + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(8)), + liquidity: u256.NewUint(500000), + expectedAmount0: "437500", + expectedAmount1: "0", + }, + { + name: "Out of Range - Above Upper Bound", + sqrtRatioX96: new(u256.Uint).Mul(q96, u256.NewUint(10)), // Current price at 10 * Q96 + sqrtRatioAX96: q96, + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(8)), + liquidity: u256.NewUint(2000000), + expectedAmount0: "0", + expectedAmount1: "14000000", + }, + { + name: "Zero Liquidity", + sqrtRatioX96: q96, + sqrtRatioAX96: q96, + sqrtRatioBX96: new(u256.Uint).Mul(q96, u256.NewUint(16)), + liquidity: u256.Zero(), + expectedAmount0: "0", + expectedAmount1: "0", + }, + { + name: "Descending Ratios (Order Correction)", + sqrtRatioX96: q96, + sqrtRatioAX96: new(u256.Uint).Mul(q96, u256.NewUint(8)), + sqrtRatioBX96: q96, + liquidity: u256.NewUint(1000000), + expectedAmount0: "875000", + expectedAmount1: "0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + amount0, amount1 := GetAmountsForLiquidity(tt.sqrtRatioX96, tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity) + uassert.Equal(t, tt.expectedAmount0, amount0, ufmt.Sprintf("expected %s but got %s for amount0", tt.expectedAmount0, amount0)) + uassert.Equal(t, tt.expectedAmount1, amount1, ufmt.Sprintf("expected %s but got %s for amount1", tt.expectedAmount1, amount1)) + }) + } +} diff --git a/contract/r/gnoswap/common/math.gno b/contract/r/gnoswap/common/math.gno new file mode 100644 index 000000000..4a03fb4be --- /dev/null +++ b/contract/r/gnoswap/common/math.gno @@ -0,0 +1,100 @@ +package common + +import ( + "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" +) + +// I64Min returns the minimum of two int64 values +func I64Min(x, y int64) int64 { + if x < y { + return x + } + return y +} + +// I64Max returns the maximum of two int64 values +func I64Max(x, y int64) int64 { + if x > y { + return x + } + return y +} + +// U64Min returns the minimum of two uint64 values +func U64Min(x, y uint64) uint64 { + if x < y { + return x + } + return y +} + +// U64Max returns the maximum of two uint64 values +func U64Max(x, y uint64) uint64 { + if x > y { + return x + } + return y +} + +// I256Min returns the minimum of two Int256 values +func I256Min(x, y *i256.Int) *i256.Int { + if x.Cmp(y) < 0 { + return x + } + return y +} + +// I256Max returns the maximum of two Int256 values +func I256Max(x, y *i256.Int) *i256.Int { + if x.Cmp(y) > 0 { + return x + } + return y +} + +// U256Min returns the minimum of two Uint256 values +func U256Min(x, y *u256.Uint) *u256.Uint { + if x.Cmp(y) < 0 { + return x + } + return y +} + +// U256Max returns the maximum of two Uint256 values +func U256Max(x, y *u256.Uint) *u256.Uint { + if x.Cmp(y) > 0 { + return x + } + return y +} + +// SafeConvertUint256ToInt256 converts a uint256.Uint to int256.Int and returns it. +// If the value is greater than the maximum int256 value, it panics. +func SafeConvertUint256ToInt256(x *u256.Uint) *i256.Int { + if x.Gt(u256.MustFromDecimal(consts.MAX_INT256)) { + panic(newErrorWithDetail( + errOverflow, + ufmt.Sprintf("can not convert %s to int256", x.ToString()), + )) + } + return i256.FromUint256(x) +} + +// SafeConvertUint256ToUint64 converts a uint256.Uint to uint64 and returns it. +// If the value is greater than the maximum uint64 value, it panics. +func SafeConvertUint256ToUint64(x *u256.Uint) uint64 { + value, overflow := x.Uint64WithOverflow() + if overflow { + panic(newErrorWithDetail( + errOverflow, + ufmt.Sprintf("can not convert %s to uint64", x.ToString()), + )) + } + + return value +} diff --git a/contract/r/gnoswap/common/math_test.gno b/contract/r/gnoswap/common/math_test.gno new file mode 100644 index 000000000..8c9088a96 --- /dev/null +++ b/contract/r/gnoswap/common/math_test.gno @@ -0,0 +1,234 @@ +package common + +import ( + "testing" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestI64Min(t *testing.T) { + tests := []struct { + x, y, want int64 + }{ + {1, 2, 1}, + {-1, 1, -1}, + {5, 5, 5}, + {-10, -5, -10}, + {9223372036854775807, 0, 0}, // Max int64 + } + + for _, tt := range tests { + got := I64Min(tt.x, tt.y) + if got != tt.want { + t.Errorf("I64Min(%v, %v) = %v, want %v", tt.x, tt.y, got, tt.want) + } + } +} + +func TestI64Max(t *testing.T) { + tests := []struct { + x, y, want int64 + }{ + {1, 2, 2}, + {-1, 1, 1}, + {5, 5, 5}, + {-10, -5, -5}, + {9223372036854775807, 0, 9223372036854775807}, // Max int64 + } + + for _, tt := range tests { + got := I64Max(tt.x, tt.y) + if got != tt.want { + t.Errorf("I64Max(%v, %v) = %v, want %v", tt.x, tt.y, got, tt.want) + } + } +} + +func TestU64Min(t *testing.T) { + tests := []struct { + x, y, want uint64 + }{ + {1, 2, 1}, + {0, 1, 0}, + {5, 5, 5}, + {10, 5, 5}, + {18446744073709551615, 0, 0}, // Max uint64 + } + + for _, tt := range tests { + got := U64Min(tt.x, tt.y) + if got != tt.want { + t.Errorf("U64Min(%v, %v) = %v, want %v", tt.x, tt.y, got, tt.want) + } + } +} + +func TestU64Max(t *testing.T) { + tests := []struct { + x, y, want uint64 + }{ + {1, 2, 2}, + {0, 1, 1}, + {5, 5, 5}, + {10, 5, 10}, + {18446744073709551615, 0, 18446744073709551615}, // Max uint64 + } + + for _, tt := range tests { + got := U64Max(tt.x, tt.y) + if got != tt.want { + t.Errorf("U64Max(%v, %v) = %v, want %v", tt.x, tt.y, got, tt.want) + } + } +} + +func TestI256Min(t *testing.T) { + tests := []struct { + x, y string // hex strings for creating Int + want string + }{ + {"1", "2", "1"}, + {"-1", "1", "-1"}, + {"5", "5", "5"}, + {"-10", "-5", "-10"}, + } + + for _, tt := range tests { + x, _ := i256.FromDecimal(tt.x) + y, _ := i256.FromDecimal(tt.y) + want, _ := i256.FromDecimal(tt.want) + got := I256Min(x, y) + if got.Cmp(want) != 0 { + t.Errorf("I256Min(%v, %v) = %v, want %v", tt.x, tt.y, got, want) + } + } +} + +func TestI256Max(t *testing.T) { + tests := []struct { + x, y string // hex strings for creating Int + want string + }{ + {"1", "2", "2"}, + {"-1", "1", "1"}, + {"5", "5", "5"}, + {"-10", "-5", "-5"}, + } + + for _, tt := range tests { + x, _ := i256.FromDecimal(tt.x) + y, _ := i256.FromDecimal(tt.y) + want, _ := i256.FromDecimal(tt.want) + got := I256Max(x, y) + if got.Cmp(want) != 0 { + t.Errorf("I256Max(%v, %v) = %v, want %v", tt.x, tt.y, got, want) + } + } +} + +func TestU256Min(t *testing.T) { + tests := []struct { + x, y string // decimal strings for creating Uint + want string + }{ + {"1", "2", "1"}, + {"0", "1", "0"}, + {"5", "5", "5"}, + {"10", "5", "5"}, + } + + for _, tt := range tests { + x, _ := u256.FromDecimal(tt.x) + y, _ := u256.FromDecimal(tt.y) + want, _ := u256.FromDecimal(tt.want) + got := U256Min(x, y) + if got.Cmp(want) != 0 { + t.Errorf("U256Min(%v, %v) = %v, want %v", tt.x, tt.y, got, want) + } + } +} + +func TestU256Max(t *testing.T) { + tests := []struct { + x, y string // decimal strings for creating Uint + want string + }{ + {"1", "2", "2"}, + {"0", "1", "1"}, + {"5", "5", "5"}, + {"10", "5", "10"}, + } + + for _, tt := range tests { + x, _ := u256.FromDecimal(tt.x) + y, _ := u256.FromDecimal(tt.y) + want, _ := u256.FromDecimal(tt.want) + got := U256Max(x, y) + if got.Cmp(want) != 0 { + t.Errorf("U256Max(%v, %v) = %v, want %v", tt.x, tt.y, got, want) + } + } +} + +func TestSafeConvertUint256ToInt256(t *testing.T) { + tests := []struct { + x, want string + shouldPanic bool + }{ + {"0", "0", false}, + {"1", "1", false}, + // max int256 + {"57896044618658097711785492504343953926634992332820282019728792003956564819968", "57896044618658097711785492504343953926634992332820282019728792003956564819968", false}, + // max int256 + 1 (overflow) + {"57896044618658097711785492504343953926634992332820282019728792003956564819969", "", true}, + } + + for _, tt := range tests { + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("SafeConvertUint256ToInt256(%v) did not panic", tt.x) + } + }() + } + + x := u256.MustFromDecimal(tt.x) + got := SafeConvertUint256ToInt256(x) + want := i256.MustFromDecimal(tt.want) + if got.Cmp(want) != 0 { + t.Errorf("SafeConvertUint256ToInt256(%v) = %v, want %v", tt.x, got, want) + } + } +} + +func TestSafeConvertUint256ToUint64(t *testing.T) { + tests := []struct { + x string + want uint64 + shouldPanic bool + }{ + {"0", 0, false}, + {"1", 1, false}, + // max uint64 + {"18446744073709551615", 18446744073709551615, false}, + // max uint64 + 1 (overflow) + {"18446744073709551616", 0, true}, + } + + for _, tt := range tests { + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("SafeConvertUint256ToUint64(%v) did not panic", tt.x) + } + }() + } + + x := u256.MustFromDecimal(tt.x) + got := SafeConvertUint256ToUint64(x) + if got != tt.want { + t.Errorf("SafeConvertUint256ToUint64(%v) = %v, want %v", tt.x, got, tt.want) + } + } +} diff --git a/contract/r/gnoswap/common/strings.gno b/contract/r/gnoswap/common/strings.gno new file mode 100644 index 000000000..c14b45303 --- /dev/null +++ b/contract/r/gnoswap/common/strings.gno @@ -0,0 +1,59 @@ +package common + +import ( + "strconv" + "strings" + + "gno.land/p/demo/ufmt" +) + +func Split(input string, sep string, length int) ([]string, error) { + result := strings.Split(input, sep) + if len(result) != length { + return nil, ufmt.Errorf("invalid length: %d", len(result)) + } + + return result, nil +} + +// EncodeUint converts an uint64 number into a zero-padded 20-character string. +// +// Parameters: +// - num (uint64): The number to encode. +// +// Returns: +// - string: A zero-padded string representation of the number. +// +// Example: +// Input: 12345 +// Output: "00000000000000012345" +func EncodeUint(num uint64) string { + // Convert the value to a decimal string. + s := strconv.FormatUint(num, 10) + + // Zero-pad to a total length of 20 characters. + zerosNeeded := 20 - len(s) + return strings.Repeat("0", zerosNeeded) + s +} + +// DecodeUint converts a zero-padded string back into a uint64 number. +// +// Parameters: +// - s (string): The zero-padded string. +// +// Returns: +// - uint64: The decoded number. +// +// Panics: +// - If the string cannot be parsed into a uint64. +// +// Example: +// Input: "00000000000000012345" +// Output: 12345 +func DecodeUint(s string) uint64 { + num, err := strconv.ParseUint(s, 10, 64) + if err != nil { + panic(err) + } + return num +} diff --git a/contract/r/gnoswap/common/tick_math.gno b/contract/r/gnoswap/common/tick_math.gno new file mode 100644 index 000000000..646573d77 --- /dev/null +++ b/contract/r/gnoswap/common/tick_math.gno @@ -0,0 +1,378 @@ +package common + +import ( + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +const ( + MAX_UINT8 string = "255" + MAX_UINT16 string = "65535" + MAX_UINT32 string = "4294967295" + MAX_UINT64 string = "18446744073709551615" + MAX_UINT128 string = "340282366920938463463374607431768211455" + MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" + MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819967" + + Q64 string = "18446744073709551616" // 2 ** 64 + Q96 string = "79228162514264337593543950336" // 2 ** 96 + Q128 string = "340282366920938463463374607431768211456" // 2 ** 128 +) + +var ( + tickRatioTree = NewTickRatioTree() + binaryLogTree = NewBinaryLogTree() + + shift1By32Left = u256.MustFromDecimal("4294967296") // (1 << 32) + maxTick = int32(887272) +) + +// TreeValue wraps u256.Uint to implement custom value type for avl.Tree +type TreeValue struct { + value *u256.Uint +} + +func (tv TreeValue) String() string { + return tv.value.ToString() +} + +// TickRatioTree manages tick ratios +type TickRatioTree struct { + tree *avl.Tree +} + +// NewTickRatioTree initializes a new TickRatioTree with predefined values +func NewTickRatioTree() *TickRatioTree { + tree := avl.NewTree() + + ratios := []struct { + key int32 + value string + }{ + {0x1, "340265354078544963557816517032075149313"}, + {0x2, "340248342086729790484326174814286782778"}, + {0x4, "340214320654664324051920982716015181260"}, + {0x8, "340146287995602323631171512101879684304"}, + {0x10, "340010263488231146823593991679159461444"}, + {0x20, "339738377640345403697157401104375502016"}, + {0x40, "339195258003219555707034227454543997025"}, + {0x80, "338111622100601834656805679988414885971"}, + {0x100, "335954724994790223023589805789778977700"}, + {0x200, "331682121138379247127172139078559817300"}, + {0x400, "323299236684853023288211250268160618739"}, + {0x800, "307163716377032989948697243942600083929"}, + {0x1000, "277268403626896220162999269216087595045"}, + {0x2000, "225923453940442621947126027127485391333"}, + {0x4000, "149997214084966997727330242082538205943"}, + {0x8000, "66119101136024775622716233608466517926"}, + {0x10000, "12847376061809297530290974190478138313"}, + {0x20000, "485053260817066172746253684029974020"}, + {0x40000, "691415978906521570653435304214168"}, + {0x80000, "1404880482679654955896180642"}, + } + + for _, ratio := range ratios { + tick := ufmt.Sprintf("%d", ratio.key) + value := TreeValue{u256.MustFromDecimal(ratio.value)} + tree.Set(tick, value) + } + + return &TickRatioTree{tree} +} + +// GetRatio retrieves the ratio for a given tick +func (t *TickRatioTree) GetRatio(key int32) (*u256.Uint, bool) { + strKey := ufmt.Sprintf("%d", key) + value, exists := t.tree.Get(strKey) + + if !exists { + return nil, false + } + + if tv, ok := value.(TreeValue); ok { + return tv.value, true + } + + return nil, false +} + +// BinaryLogTree manages binary log constants +type BinaryLogTree struct { + tree *avl.Tree +} + +// NewBinaryLogTree initializes a new BinaryLogTree +func NewBinaryLogTree() *BinaryLogTree { + tree := avl.NewTree() + + logs := [8]string{ + "0", + "3", + "15", + "255", + "65535", + "4294967295", + "18446744073709551615", + "340282366920938463463374607431768211455", + } + + for i, value := range logs { + key := ufmt.Sprintf("%d", i) + tree.Set(key, TreeValue{u256.MustFromDecimal(value)}) + } + + return &BinaryLogTree{tree} +} + +// GetLog retrieves the binary log constant at given index +func (t *BinaryLogTree) GetLog(idx int) (*u256.Uint, bool) { + strKey := ufmt.Sprintf("%d", idx) + value, exists := t.tree.Get(strKey) + + if !exists { + return nil, false + } + + if tv, ok := value.(TreeValue); ok { + return tv.value, true + } + + return nil, false +} + +// TickMathGetSqrtRatioAtTick calculates the square root price ratio for a given tick. +// +// This function computes the square root ratio (sqrt(price)) at a specific tick, +// using a precomputed mapping of ratios. The result is returned as a 160-bit +// fixed-point value (Q64.96 format). +// +// Parameters: +// - tick (int32): The tick index for which the square root ratio is calculated. +// +// Returns: +// - *u256.Uint: The square root price ratio at the given tick, represented as a 160-bit unsigned integer. +// +// Behavior: +// 1. Validates that the tick is within the acceptable range by asserting its absolute value. +// 2. Initializes the ratio based on whether the least significant bit of the tick is set, using a lookup table (`tickRatioMap`). +// 3. Iteratively adjusts the ratio by multiplying and right-shifting with precomputed values for each relevant bit set in the tick value. +// 4. If the tick is positive, the ratio is inverted by dividing a maximum uint256 value by the computed ratio. +// 5. The result is split into upper 128 bits and lower 32 bits for precision handling. +// 6. If the lower 32 bits are non-zero, the upper 128 bits are incremented by 1 to ensure rounding up. +// +// Example: +// - For a tick of `0`, the ratio represents `1.0000` in Q64.96 format. +// - For a tick of `-887272` (minimum tick), the ratio represents the smallest possible price. +// - For a tick of `887272` (maximum tick), the ratio represents the highest possible price. +// +// Panics: +// - If the absolute tick value exceeds the maximum allowed tick range. +// +// Notes: +// - The function relies on a precomputed map `tickRatioMap` to optimize calculations. +// - Handles rounding by adding 1 if the remainder of the division is non-zero. +func TickMathGetSqrtRatioAtTick(tick int32) *u256.Uint { // uint160 sqrtPriceX96 + absTick := abs(tick) + assertValidTickRange(absTick) + + ratio := u256.MustFromDecimal("340282366920938463463374607431768211456") // consts.Q128 + initialBit := int32(0x1) + + if val, exists := tickRatioTree.GetRatio(initialBit); exists && (absTick&initialBit) != 0 { + ratio = val + } + + calculateRatio := func(mask int32) *u256.Uint { + if value, exists := tickRatioTree.GetRatio(mask); exists && absTick&mask != 0 { + return new(u256.Uint).Rsh( + new(u256.Uint).Mul(ratio, value), + 128, + ) + } + return ratio + } + + for mask := int32(0x2); mask <= 0x80000; mask *= 2 { + ratio = calculateRatio(mask) + } + + if tick > 0 { + maxUint256 := u256.MustFromDecimal(MAX_UINT256) // consts.MAX_UINT256 + if ratio.IsZero() { + return u256.Zero() + } + ratio = new(u256.Uint).Div(maxUint256, ratio) + } + + upper128Bits := new(u256.Uint).Rsh(ratio, 32) // ratio >> 32 + lower32Bits := ratio.Mod(ratio, shift1By32Left) // ratio % (1 << 32) + + var roundUp *u256.Uint + if lower32Bits.IsZero() { + roundUp = u256.Zero() + } else { + roundUp = u256.One() + } + + return new(u256.Uint).Add(upper128Bits, roundUp) +} + +func TickMathGetTickAtSqrtRatio(sqrtPriceX96 *u256.Uint) int32 { + cond1 := sqrtPriceX96.Gte(u256.MustFromDecimal("4295128739")) // MIN_SQRT_RATIO + cond2 := sqrtPriceX96.Lt(u256.MustFromDecimal("1461446703485210103287273052203988822378723970342")) // MAX_SQRT_RATIO + if !(cond1 && cond2) { + panic(newErrorWithDetail( + errOutOfRange, + ufmt.Sprintf("sqrtPriceX96 is out of range, sqrtPriceX96: %s", sqrtPriceX96.ToString()), + )) + } + + ratio := new(u256.Uint).Lsh(sqrtPriceX96, 32) + + msb, adjustedRatio := findMSB(ratio) + adjustedRatio = adjustRatio(ratio, msb) + + log2 := calculateLog2(msb, adjustedRatio) + tick := getTickValue(log2, sqrtPriceX96) + + return tick +} + +// findMSB computes the MSB (most significant bit) of the given ratio. +func findMSB(ratio *u256.Uint) (*u256.Uint, *u256.Uint) { + msb := u256.Zero() + + calculateMSB := func(i int) (*u256.Uint, *u256.Uint) { + if logConst, exists := binaryLogTree.GetLog(i); exists { + f := new(u256.Uint).Lsh(gt(ratio, logConst), uint(i)) + msb = new(u256.Uint).Or(msb, f) + ratio = new(u256.Uint).Rsh(ratio, uint(f.Uint64())) + } + return msb, ratio + } + + for i := 7; i >= 1; i-- { + msb, ratio = calculateMSB(i) + } + + // handle the remaining bits + { + f := gt(ratio, u256.One()) // 0x1 + // msb = msb | f + msb = new(u256.Uint).Or(msb, f) + } + + return msb, ratio +} + +// adjustRatio adjusts the given ratio based on the MSB found. +// +// This adjustment ensures that the ratio falls within the specific range. +func adjustRatio(ratio, msb *u256.Uint) *u256.Uint { + if msb.Gte(u256.NewUint(128)) { + return new(u256.Uint).Rsh(ratio, uint(msb.Uint64()-127)) + } + + return new(u256.Uint).Lsh(ratio, uint(127-msb.Uint64())) +} + +// calculateLog2 calculates the binary logarith, of the adjusted ratio using a fixed-point arithmetic. +// +// This function iteratively squares the ratio and adjusts the result to compute the log base 2, which will determine the tick value. +func calculateLog2(msb, ratio *u256.Uint) *i256.Int { + _msb := i256.FromUint256(msb) + _128 := i256.NewInt(128) + + log_2 := i256.Zero().Sub(_msb, _128) + log_2 = log_2.Lsh(log_2, 64) + + for i := 63; i >= 51; i-- { + ratio = new(u256.Uint).Mul(ratio, ratio) + ratio = ratio.Rsh(ratio, 127) + + f := i256.FromUint256(new(u256.Uint).Rsh(ratio, 128)) + + // log_2 = log_2 | (f << i) + log_2 = i256.Zero().Or(log_2, i256.Zero().Lsh(f, uint(i))) + + // ratio = ratio >> uint64(f) + ratio = ratio.Rsh(ratio, uint(f.Uint64())) + } + + // handle the remaining bits + { + // ratio = ratio * ratio >> 127 + ratio = new(u256.Uint).Mul(ratio, ratio) + ratio = new(u256.Uint).Rsh(ratio, 127) + + f := i256.FromUint256(new(u256.Uint).Rsh(ratio, 128)) + + log_2 = i256.Zero().Or(log_2, i256.Zero().Lsh(f, 50)) + } + + return log_2 +} + +// getTickValue determines the tick value corresponding to a given sqrtPriveX96. +// +// It calculates the upper and lower bounds for each tick, and selects the appropriate tock value +// based on the given sqrtPriceX96. +func getTickValue(log2 *i256.Int, sqrtPriceX96 *u256.Uint) int32 { + // ref: https://github.com/Uniswap/v3-core/issues/500 + // 2^64 / log2 (√1.0001) = 255738958999603826347141 + log_sqrt10001 := i256.Zero().Mul(log2, i256.MustFromDecimal("255738958999603826347141")) + + // ref: https://ethereum.stackexchange.com/questions/113844/how-does-uniswap-v3s-logarithm-library-tickmath-sol-work/113912#113912 + // 0.010000497 x 2^128 = 3402992956809132418596140100660247210 + tickLow256 := i256.Zero().Sub(log_sqrt10001, i256.MustFromDecimal("3402992956809132418596140100660247210")) + tickLow256 = tickLow256.Rsh(tickLow256, 128) + tickLow := int32(tickLow256.Int64()) + + // ref: https://ethereum.stackexchange.com/questions/113844/how-does-uniswap-v3s-logarithm-library-tickmath-sol-work/113912#113912 + // 0.856 x 2^128 = 291339464771989622907027621153398088495 + tickHi256 := i256.Zero().Add(log_sqrt10001, i256.MustFromDecimal("291339464771989622907027621153398088495")) + tickHi256 = tickHi256.Rsh(tickHi256, 128) + tickHi := int32(tickHi256.Int64()) + + var tick int32 + if tickLow == tickHi { + tick = tickLow + } else if TickMathGetSqrtRatioAtTick(tickHi).Lte(sqrtPriceX96) { + tick = tickHi + } else { + tick = tickLow + } + + return tick +} + +func gt(x, y *u256.Uint) *u256.Uint { + if x.Gt(y) { + return u256.One() + } + + return u256.Zero() +} + +// abs returns the absolute value of the given integer. +func abs(x int32) int32 { + if x < 0 { + return -x + } + + return x +} + +// assertValidTickRange validates that the absolute tick value is within the acceptable range. +func assertValidTickRange(absTick int32) { + if absTick > maxTick { + panic(newErrorWithDetail( + errOutOfRange, + ufmt.Sprintf("abs tick is out of range (larger than 887272), abs tick: %d", absTick), + )) + } +} diff --git a/contract/r/gnoswap/common/tick_math_test.gno b/contract/r/gnoswap/common/tick_math_test.gno new file mode 100644 index 000000000..bc3a70365 --- /dev/null +++ b/contract/r/gnoswap/common/tick_math_test.gno @@ -0,0 +1,209 @@ +package common + +import ( + "testing" + + "gno.land/p/demo/uassert" + + u256 "gno.land/p/gnoswap/uint256" +) + +var ( + MIN_TICK = int32(-887272) + MIN_SQRT_RATIO = "4295128739" + + MAX_TICK = int32(887272) + MAX_SQRT_RATIO = "1461446703485210103287273052203988822378723970342" +) + +func TestTickMathGetSqrtRatioAtTick(t *testing.T) { + t.Run("throws for too low", func(t *testing.T) { + tick := MIN_TICK - 1 + + uassert.PanicsWithMessage( + t, + "[GNOSWAP-COMMON-003] value out of range || abs tick is out of range (larger than 887272), abs tick: 887273", + func() { + TickMathGetSqrtRatioAtTick(tick) + }, + ) + }) + + t.Run("throws for too high", func(t *testing.T) { + tick := MAX_TICK + 1 + + uassert.PanicsWithMessage( + t, + "[GNOSWAP-COMMON-003] value out of range || abs tick is out of range (larger than 887272), abs tick: 887273", + func() { + TickMathGetSqrtRatioAtTick(tick) + }, + ) + }) + + t.Run("min tick", func(t *testing.T) { + tick := MIN_TICK + sqrtPriceX96 := TickMathGetSqrtRatioAtTick(tick) + uassert.Equal(t, "4295128739", sqrtPriceX96.ToString()) + }) + + t.Run("min tick + 1", func(t *testing.T) { + tick := MIN_TICK + 1 + sqrtPriceX96 := TickMathGetSqrtRatioAtTick(tick) + uassert.Equal(t, "4295343490", sqrtPriceX96.ToString()) + }) + + t.Run("min tick ratio is less than js implementation", func(t *testing.T) { + // encodePriceSqrt(1, BigNumber.from(2).pow(127)) = 6085630636 + sqrtPriceX96 := TickMathGetSqrtRatioAtTick(MIN_TICK) + uassert.False(t, sqrtPriceX96.Cmp(u256.MustFromDecimal("6085630636")) > 0, "should be less than 6085630636") + }) + + t.Run("max tick ratio is greater than js implementation", func(t *testing.T) { + // encodePriceSqrt(BigNumber.from(2).pow(127),1) = 1033437718471923706666374484006904511252097097914 + sqrtPriceX96 := TickMathGetSqrtRatioAtTick(MAX_TICK) + uassert.False(t, sqrtPriceX96.Cmp(u256.MustFromDecimal("1033437718471923706666374484006904511252097097914")) < 0, "should be greater than 1033437718471923706666374484006904511252097097914") + }) + + t.Run("max tick ratio is equal to js implementation", func(t *testing.T) { + sqrtPriceX96 := TickMathGetSqrtRatioAtTick(MAX_TICK) + uassert.Equal(t, "1461446703485210103287273052203988822378723970342", sqrtPriceX96.ToString()) + }) + + t.Run("multiple ticks", func(t *testing.T) { + absTicks := []int32{50, 100, 250, 500, 1000, 2500, 3000, 4000, 5000, 50000, 150000, 250000, 500000, 738203} + expectedResults := []string{ + "79030349367926598376800521322", + "79426470787362580746886972461", + "78833030112140176575862854579", + "79625275426524748796330556128", + "78244023372248365697264290337", + "80224679980005306637834519095", + "77272108795590369356373805297", + "81233731461783161732293370115", + "75364347830767020784054125655", + "83290069058676223003182343270", + "69919044979842180277688105136", + "89776708723587163891445672585", + "68192822843687888778582228483", + "92049301871182272007977902845", + "64867181785621769311890333195", + "96768528593268422080558758223", + "61703726247759831737814779831", + "101729702841318637793976746270", + "6504256538020985011912221507", + "965075977353221155028623082916", + "43836292794701720435367485", + "143194173941309278083010301478497", + "295440463448801648376846", + "21246587762933397357449903968194344", + "1101692437043807371", + "5697689776495288729098254600827762987878", + "7409801140451", + "847134979253254120489401328389043031315994541", + } + + for i, absTick := range absTicks { + for j, tick := range []int32{-absTick, absTick} { + result := TickMathGetSqrtRatioAtTick(tick) + uassert.Equal(t, expectedResults[i*2+j], result.ToString()) + } + } + }) + + t.Run("min_sqrt_ratio", func(t *testing.T) { + sqrtPriceX96 := TickMathGetSqrtRatioAtTick(MIN_TICK) + uassert.Equal(t, "4295128739", sqrtPriceX96.ToString()) + }) + + t.Run("max_sqrt_ratio", func(t *testing.T) { + sqrtPriceX96 := TickMathGetSqrtRatioAtTick(MAX_TICK) + uassert.Equal(t, MAX_SQRT_RATIO, sqrtPriceX96.ToString()) + }) + + t.Run("throws for too low sqrt_ratio", func(t *testing.T) { + sqrtPriceX96 := u256.MustFromDecimal(MIN_SQRT_RATIO) + sqrtPriceX96.Sub(sqrtPriceX96, u256.One()) + + uassert.PanicsWithMessage( + t, + "[GNOSWAP-COMMON-003] value out of range || sqrtPriceX96 is out of range, sqrtPriceX96: 4295128738", + func() { + TickMathGetTickAtSqrtRatio(sqrtPriceX96) + }, + ) + }) + + t.Run("throws for too high sqrt_ratio", func(t *testing.T) { + sqrtPriceX96 := u256.MustFromDecimal(MAX_SQRT_RATIO) + sqrtPriceX96.Add(sqrtPriceX96, u256.One()) + + uassert.PanicsWithMessage( + t, + "[GNOSWAP-COMMON-003] value out of range || sqrtPriceX96 is out of range, sqrtPriceX96: 1461446703485210103287273052203988822378723970343", + func() { + TickMathGetTickAtSqrtRatio(sqrtPriceX96) + }, + ) + }) + + t.Run("ratio of min tick", func(t *testing.T) { + sqrtPriceX96 := TickMathGetSqrtRatioAtTick(MIN_TICK) + uassert.Equal(t, MIN_SQRT_RATIO, sqrtPriceX96.ToString()) + }) + + t.Run("ratio of min tick + 1", func(t *testing.T) { + sqrtPriceX96 := u256.MustFromDecimal("4295343490") + uassert.Equal(t, MIN_TICK+1, TickMathGetTickAtSqrtRatio(sqrtPriceX96)) + }) + + t.Run("multiple sqrt_ratios", func(t *testing.T) { + ratios := []string{ + "4295128739", + "79228162514264337593543950336000000", + "79228162514264337593543950336000", + "9903520314283042199192993792", + "28011385487393069959365969113", + "56022770974786139918731938227", + "79228162514264337593543950336", + "112045541949572279837463876454", + "224091083899144559674927752909", + "633825300114114700748351602688", + "79228162514264337593543950", + "79228162514264337593543", + "1461446703485210103287273052203988822378723970341", + } + expectedResults := []int32{ + -887272, + 276324, + 138162, + -41591, + -20796, + -6932, + 0, // got -1, expected 0 + 6931, + 20795, + 41590, + -138163, + -276325, + 887271, + } + + for i, ratio := range ratios { + sqrtPriceX96 := u256.MustFromDecimal(ratio) + rst := TickMathGetTickAtSqrtRatio(sqrtPriceX96) + uassert.Equal(t, expectedResults[i], rst) + } + }) + + t.Run("ratio closest to max tick", func(t *testing.T) { + maxSqrtRatio := u256.MustFromDecimal(MAX_SQRT_RATIO) + maxSqrtRatioSub1 := maxSqrtRatio.Sub(maxSqrtRatio, u256.One()) + + maxTick := MAX_TICK + maxTickSub1 := maxTick - 1 + + rst := TickMathGetTickAtSqrtRatio(maxSqrtRatioSub1) + uassert.Equal(t, maxTickSub1, rst) + }) +} diff --git a/contract/r/gnoswap/common/util.gno b/contract/r/gnoswap/common/util.gno new file mode 100644 index 000000000..1fe2f40af --- /dev/null +++ b/contract/r/gnoswap/common/util.gno @@ -0,0 +1,33 @@ +package common + +import ( + "std" + + u256 "gno.land/p/gnoswap/uint256" +) + +// assertOnlyNotNil panics if the value is nil. +func assertOnlyNotNil(value *u256.Uint) { + if value == nil { + panic(newErrorWithDetail( + errInvalidInput, + "value is nil", + )) + } +} + +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrevAsString returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { + prev := getPrevRealm() + return prev.Addr().String(), prev.PkgPath() +} diff --git a/contract/r/gnoswap/community_pool/community_pool.gno b/contract/r/gnoswap/community_pool/community_pool.gno new file mode 100644 index 000000000..35bf4f3a6 --- /dev/null +++ b/contract/r/gnoswap/community_pool/community_pool.gno @@ -0,0 +1,81 @@ +package community_pool + +import ( + "std" + "strconv" + + "gno.land/r/gnoswap/v1/common" +) + +// TransferTokenByAdmin transfers token to the given address. +func TransferTokenByAdmin(tokenPath string, to std.Address, amount uint64) { + assertOnlyNotHalted() + assertOnlyAdmin() + + transferToken(tokenPath, to, amount) +} + +// TransferToken transfers token to the given address. +// Only governance contract can execute this function via proposal +func TransferToken(tokenPath string, to std.Address, amount uint64) { + assertOnlyNotHalted() + assertOnlyGovernance() + + transferToken(tokenPath, to, amount) +} + +// transferToken transfers token to the given address. +func transferToken(tokenPath string, to std.Address, amount uint64) { + teller := common.GetTokenTeller(tokenPath) + checkErr(teller.Transfer(to, amount)) + + prevAddr, prevPkgPath := getPrevAsString() + + std.Emit( + "TransferToken", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "tokenPath", tokenPath, + "to", to.String(), + "amount", strconv.FormatUint(amount, 10), + ) +} + +// checkErr panics if the error is not nil. +func checkErr(err error) { + if err != nil { + panic(err.Error()) + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertOnlyAdmin panics if the caller is not the admin. +func assertOnlyAdmin() { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } +} + +// assertOnlyGovernance panics if the caller is not the governance. +func assertOnlyGovernance() { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err) + } +} + +// getPrevAddr returns the address of the caller. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrevAsString returns the address and realm of the caller as a string. +func getPrevAsString() (string, string) { + prev := std.PrevRealm() + return prev.Addr().String(), prev.PkgPath() +} diff --git a/contract/r/gnoswap/community_pool/community_pool_test.gno b/contract/r/gnoswap/community_pool/community_pool_test.gno new file mode 100644 index 000000000..995aefeb3 --- /dev/null +++ b/contract/r/gnoswap/community_pool/community_pool_test.gno @@ -0,0 +1,219 @@ +package community_pool + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/gnoswap/v1/gns" +) + +var ( + adminAddr = consts.ADMIN + adminRealm = std.NewUserRealm(adminAddr) + + govRealm = std.NewCodeRealm(consts.GOV_GOVERNANCE_PATH) + + dummyCaller = std.NewUserRealm(testutils.TestAddress("dummyCaller")) + dummyReceiver = testutils.TestAddress("dummyReceiver") +) + +func TestTransferTokenByAdmin(t *testing.T) { + tests := []struct { + name string + setup func() + caller std.Realm + tokenPath string + to std.Address + amount uint64 + shouldPanic bool + panicMsg string + }{ + { + name: "panic if halted", + setup: func() { + std.TestSetRealm(adminRealm) + common.SetHaltByAdmin(true) + }, + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 1000, + shouldPanic: true, + panicMsg: "[GNOSWAP-COMMON-002] halted || GnoSwap is halted", + }, + { + name: "panic if not admin", + setup: func() { + std.TestSetRealm(adminRealm) + common.SetHaltByAdmin(false) + }, + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 1000, + shouldPanic: true, + panicMsg: "caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission", + }, + { + name: "panic if not enough balance", + setup: func() { + std.TestSetRealm(adminRealm) + gns.Transfer(consts.COMMUNITY_POOL_ADDR, 10_000) + }, + caller: adminRealm, + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 10_001, + shouldPanic: true, + panicMsg: "insufficient balance", + }, + { + name: "success if enough balance", + caller: adminRealm, + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.caller != (std.Realm{}) { + std.TestSetRealm(tt.caller) + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + TransferTokenByAdmin(tt.tokenPath, tt.to, tt.amount) + }) + } else { + receiverOldBalance := gns.BalanceOf(tt.to) + TransferTokenByAdmin(tt.tokenPath, tt.to, tt.amount) + receiverNewBalance := gns.BalanceOf(tt.to) + uassert.Equal(t, receiverNewBalance-receiverOldBalance, tt.amount) + } + }) + } +} + +func TestTransferToken(t *testing.T) { + tests := []struct { + name string + setup func() + caller std.Realm + tokenPath string + to std.Address + amount uint64 + shouldPanic bool + panicMsg string + }{ + { + name: "panic if halted", + setup: func() { + std.TestSetRealm(adminRealm) + common.SetHaltByAdmin(true) + }, + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 1000, + shouldPanic: true, + panicMsg: "[GNOSWAP-COMMON-002] halted || GnoSwap is halted", + }, + { + name: "panic if not governance", + setup: func() { + std.TestSetRealm(adminRealm) + common.SetHaltByAdmin(false) + }, + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 1000, + shouldPanic: true, + panicMsg: "caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission", + }, + { + name: "governance can't transfer community pool token", + setup: func() { + std.TestSetRealm(adminRealm) + }, + caller: govRealm, + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 10_001, + shouldPanic: true, + panicMsg: "insufficient balance", + }, + { + name: "governance can transfer community pool token", + caller: govRealm, + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.caller != (std.Realm{}) { + std.TestSetRealm(tt.caller) + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + TransferToken(tt.tokenPath, tt.to, tt.amount) + }) + } else { + receiverOldBalance := gns.BalanceOf(tt.to) + TransferToken(tt.tokenPath, tt.to, tt.amount) + receiverNewBalance := gns.BalanceOf(tt.to) + uassert.Equal(t, receiverNewBalance-receiverOldBalance, tt.amount) + } + }) + } +} + +func TestPrivateTransferToken(t *testing.T) { + tests := []struct { + tokenPath string + to std.Address + amount uint64 + shouldPanic bool + panicMsg string + }{ + { + tokenPath: "not_registered_token", + to: dummyReceiver, + amount: 1, + shouldPanic: true, + panicMsg: "unknown token: not_registered_token", + }, + { + tokenPath: consts.GNS_PATH, + to: dummyReceiver, + amount: 1, + shouldPanic: false, + }, + } + + for _, tt := range tests { + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + transferToken(tt.tokenPath, tt.to, tt.amount) + }) + } else { + transferToken(tt.tokenPath, tt.to, tt.amount) + } + } +} diff --git a/contract/r/gnoswap/community_pool/errors.gno b/contract/r/gnoswap/community_pool/errors.gno new file mode 100644 index 000000000..d5edb5d17 --- /dev/null +++ b/contract/r/gnoswap/community_pool/errors.gno @@ -0,0 +1,19 @@ +package community_pool + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-COMMUNITY_POOL-001] caller has no permission") + errNotRegistered = errors.New("[GNOSWAP-COMMUNITY_POOL-002] not registered") + errAlreadyRegistered = errors.New("[GNOSWAP-COMMUNITY_POOL-003] already registered") + errLocked = errors.New("[GNOSWAP-COMMUNITY_POOL-004] can't transfer token while locked") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/community_pool/gno.mod b/contract/r/gnoswap/community_pool/gno.mod new file mode 100644 index 000000000..72d91102b --- /dev/null +++ b/contract/r/gnoswap/community_pool/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/community_pool diff --git a/contract/r/gnoswap/emission/_helper_test.gno b/contract/r/gnoswap/emission/_helper_test.gno new file mode 100644 index 000000000..dbce2c788 --- /dev/null +++ b/contract/r/gnoswap/emission/_helper_test.gno @@ -0,0 +1,44 @@ +package emission + +import ( + "std" + "strconv" + "testing" + + "gno.land/p/demo/avl" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +var ( + adminRealm = std.NewUserRealm(consts.ADMIN) + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + govRealm = std.NewCodeRealm(consts.GOV_GOVERNANCE_PATH) + govStakerRealm = std.NewCodeRealm(consts.GOV_STAKER_PATH) +) + +func resetObject(t *testing.T) { + t.Helper() + + distributionBpsPct = avl.NewTree() + distributionBpsPct.Set(strconv.Itoa(LIQUIDITY_STAKER), uint64(7500)) + distributionBpsPct.Set(strconv.Itoa(DEVOPS), uint64(2000)) + distributionBpsPct.Set(strconv.Itoa(COMMUNITY_POOL), uint64(500)) + distributionBpsPct.Set(strconv.Itoa(GOV_STAKER), uint64(0)) + + distributedToStaker = 0 + distributedToDevOps = 0 + distributedToCommunityPool = 0 + distributedToGovStaker = 0 + accuDistributedToStaker = 0 + accuDistributedToDevOps = 0 + accuDistributedToCommunityPool = 0 + accuDistributedToGovStaker = 0 +} + +func gnsBalance(t *testing.T, addr std.Address) uint64 { + t.Helper() + + return gns.BalanceOf(addr) +} diff --git a/contract/r/gnoswap/emission/callback.gno b/contract/r/gnoswap/emission/callback.gno new file mode 100644 index 000000000..1e436eeb3 --- /dev/null +++ b/contract/r/gnoswap/emission/callback.gno @@ -0,0 +1,25 @@ +package emission + +import ( + "gno.land/r/gnoswap/v1/gns" +) + +var callbackStakerEmissionChange func(amount uint64) + +func SetCallbackStakerEmissionChange(callback func(amount uint64)) { + if callbackStakerEmissionChange != nil { + panic("callbackStakerEmissionChange already set") + } + callbackStakerEmissionChange = callback +} + +// Called when per-block emission is changed from the gns side. +// It does not process non-immediate emission changes, such as halving. +func callbackEmissionChange(amount uint64) { + callbackStakerEmissionChange(calculateAmount(amount, GetDistributionBpsPct(LIQUIDITY_STAKER))) +} + +func init() { + gns.SetCallbackEmissionChange(callbackEmissionChange) +} + diff --git a/contract/r/gnoswap/emission/distribution.gno b/contract/r/gnoswap/emission/distribution.gno new file mode 100644 index 000000000..5ac11d67c --- /dev/null +++ b/contract/r/gnoswap/emission/distribution.gno @@ -0,0 +1,304 @@ +package emission + +import ( + "std" + "strconv" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +const ( + LIQUIDITY_STAKER int = iota + 1 + DEVOPS + COMMUNITY_POOL + GOV_STAKER +) + +var ( + // Stores the percentage (in basis points) for each distribution target + // 1 basis point = 0.01% + distributionBpsPct *avl.Tree + + distributedToStaker uint64 // can be cleared by staker contract + distributedToDevOps uint64 + distributedToCommunityPool uint64 + distributedToGovStaker uint64 // can be cleared by governance staker + + // Historical total distributions (never reset) + accuDistributedToStaker uint64 + accuDistributedToDevOps uint64 + accuDistributedToCommunityPool uint64 + accuDistributedToGovStaker uint64 +) + +// Initialize default distribution percentages: +// - Liquidity Stakers: 75% +// - DevOps: 20% +// - Community Pool: 5% +// - Governance Stakers: 0% +func init() { + distributionBpsPct = avl.NewTree() + distributionBpsPct.Set(strconv.Itoa(LIQUIDITY_STAKER), uint64(7500)) + distributionBpsPct.Set(strconv.Itoa(DEVOPS), uint64(2000)) + distributionBpsPct.Set(strconv.Itoa(COMMUNITY_POOL), uint64(500)) + distributionBpsPct.Set(strconv.Itoa(GOV_STAKER), uint64(0)) +} + +// ChangeDistributionPctByAdmin changes the distribution percentage for the given targets. +// Panics if following conditions are not met: +// - caller is not admin +// - invalid target +// - sum of percentages is not 10000 +// - swap is halted +func ChangeDistributionPctByAdmin( + target01 int, pct01 uint64, + target02 int, pct02 uint64, + target03 int, pct03 uint64, + target04 int, pct04 uint64, +) { + assertOnlyAdmin() + assertDistributionTarget(target01) + assertDistributionTarget(target02) + assertDistributionTarget(target03) + assertDistributionTarget(target04) + assertSumDistributionPct(pct01, pct02, pct03, pct04) + assertOnlyNotHalted() + + changeDistributionPcts( + target01, pct01, + target02, pct02, + target03, pct03, + target04, pct04, + ) +} + +// ChangeDistributionPct changes the distribution percentage for the given targets. +// Panics if following conditions are not met: +// - caller is not governance +// - invalid target +// - sum of percentages is not 10000 +// - swap is halted +func ChangeDistributionPct( + target01 int, pct01 uint64, + target02 int, pct02 uint64, + target03 int, pct03 uint64, + target04 int, pct04 uint64, +) { + assertOnlyGovernance() + assertDistributionTarget(target01) + assertDistributionTarget(target02) + assertDistributionTarget(target03) + assertDistributionTarget(target04) + assertSumDistributionPct(pct01, pct02, pct03, pct04) + assertOnlyNotHalted() + + changeDistributionPcts( + target01, pct01, + target02, pct02, + target03, pct03, + target04, pct04, + ) +} + +func changeDistributionPcts( + target01 int, pct01 uint64, + target02 int, pct02 uint64, + target03 int, pct03 uint64, + target04 int, pct04 uint64, +) { + // First, cache the percentage of the staker just before it changes Callback if needed + // (check if the LIQUIDITY_STAKER was located between target01 and 04) + setDistributionBpsPct(target01, pct01) + setDistributionBpsPct(target02, pct02) + setDistributionBpsPct(target03, pct03) + setDistributionBpsPct(target04, pct04) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "ChangeDistributionPct", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "target01", targetToStr(target01), + "pct01", formatUint(pct01), + "target02", targetToStr(target02), + "pct02", formatUint(pct02), + "target03", targetToStr(target03), + "pct03", formatUint(pct03), + "target04", targetToStr(target04), + "pct04", formatUint(pct04), + ) +} + +// distributeToTarget splits an amount according to the configured percentages +// and sends tokens to each target address. Returns the total amount distributed. +func distributeToTarget(amount uint64) uint64 { + totalSent := uint64(0) + + distributionBpsPct.Iterate("", "", func(targetStr string, iPct interface{}) bool { + targetInt, err := strconv.Atoi(targetStr) + if err != nil { + panic(addDetailToError( + errInvalidEmissionTarget, + ufmt.Sprintf("invalid target(%s)", targetStr), + )) + } + + pct := iPct.(uint64) + distAmount := calculateAmount(amount, uint64(pct)) + totalSent += distAmount + + transferToTarget(targetInt, distAmount) + + return false + }) + + leftAmount := amount - totalSent + if leftGNSAmount > 0 { + setLeftGNSAmount(leftAmount) + } + + return totalSent +} + +// calculateAmount converts a basis point percentage to actual token amount +// bptPct is in basis points (1/100th of 1%) +// Example: 7500 basis points = 75% +func calculateAmount(amount, bptPct uint64) uint64 { + return amount * bptPct / 10000 +} + +// transferToTarget sends tokens to the appropriate address based on target type +// and updates both current and accumulated distribution tracking +func transferToTarget(target int, amount uint64) { + switch target { + case LIQUIDITY_STAKER: + gns.Transfer(consts.STAKER_ADDR, amount) + distributedToStaker += amount + accuDistributedToStaker += amount + + case DEVOPS: + gns.Transfer(consts.DEV_OPS, amount) + distributedToDevOps += amount + accuDistributedToDevOps += amount + + case COMMUNITY_POOL: + gns.Transfer(consts.COMMUNITY_POOL_ADDR, amount) + distributedToCommunityPool += amount + accuDistributedToCommunityPool += amount + + case GOV_STAKER: + gns.Transfer(consts.GOV_STAKER_ADDR, amount) + distributedToGovStaker += amount + accuDistributedToGovStaker += amount + + default: + panic(addDetailToError( + errInvalidEmissionTarget, + ufmt.Sprintf("invalid target(%d)", target), + )) + } +} + +func GetDistributionBpsPct(target int) uint64 { + assertDistributionTarget(target) + iUint64, exist := distributionBpsPct.Get(strconv.Itoa(target)) + if !exist { + panic(addDetailToError( + errInvalidEmissionTarget, + ufmt.Sprintf("invalid target(%d)", target), + )) + } + + return iUint64.(uint64) +} + +func GetDistributedToStaker() uint64 { + return distributedToStaker +} + +func GetDistributedToDevOps() uint64 { + return distributedToDevOps +} + +func GetDistributedToCommunityPool() uint64 { + return distributedToCommunityPool +} + +func GetDistributedToGovStaker() uint64 { + return distributedToGovStaker +} + +func GetAccuDistributedToStaker() uint64 { + return accuDistributedToStaker +} + +func GetAccuDistributedToDevOps() uint64 { + return accuDistributedToDevOps +} + +func GetAccuDistributedToCommunityPool() uint64 { + return accuDistributedToCommunityPool +} + +func GetAccuDistributedToGovStaker() uint64 { + return accuDistributedToGovStaker +} + +func ClearDistributedToStaker() { + assertStakerOnly() + distributedToStaker = 0 +} + +func ClearDistributedToGovStaker() { + assertOnlyGovStaker() + distributedToGovStaker = 0 +} + +func setDistributionBpsPct(target int, pct uint64) { + if target == LIQUIDITY_STAKER { + oldPct, exist := distributionBpsPct.Get(strconv.Itoa(target)) + if !exist { + panic("should not happen") + } + + if oldPct.(uint64) != pct { + callbackStakerEmissionChange(calculateAmount(gns.GetEmission(), pct)) + } + } + + distributionBpsPct.Set(strconv.Itoa(target), pct) +} + +func GetEmission() uint64 { + return calculateAmount(gns.GetEmission(), GetDistributionBpsPct(LIQUIDITY_STAKER)) +} + +// When there is a staker % change, it will trigger the StakerCallback first, which should "flush" the halving blocks first(by poolTier calling the GetHalvingBlocksInRange before it updates the emission). +// Therefore, if there are halving blocks exist from what we got from gns side, there must be no % change event between those blocks, which means we can use the current distribution pct to calculate the emissions for the past halving blocks. +func GetHalvingBlocksInRange(start, end int64) ([]int64, []uint64) { + halvingBlocks, halvingEmissions := gns.GetHalvingBlocksInRange(start, end) + for i := range halvingBlocks { + // Applying staker ratio for past halving blocks + halvingEmissions[i] = calculateAmount(halvingEmissions[i], GetDistributionBpsPct(LIQUIDITY_STAKER)) + } + return halvingBlocks, halvingEmissions +} + +func targetToStr(target int) string { + switch target { + case LIQUIDITY_STAKER: + return "LIQUIDITY_STAKER" + case DEVOPS: + return "DEVOPS" + case COMMUNITY_POOL: + return "COMMUNITY_POOL" + case GOV_STAKER: + return "GOV_STAKER" + default: + return "UNKNOWN" + } +} diff --git a/contract/r/gnoswap/emission/distribution_test.gno b/contract/r/gnoswap/emission/distribution_test.gno new file mode 100644 index 000000000..f81076c8e --- /dev/null +++ b/contract/r/gnoswap/emission/distribution_test.gno @@ -0,0 +1,470 @@ +package emission + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +func TestChangeDistributionPctByAdmin(t *testing.T) { + resetObject(t) + + originCallback := callbackStakerEmissionChange + callbackStakerEmissionChange = func(amount uint64) {} + + tests := []struct { + name string + shouldPanic bool + panicMsg string + setup func() + callerRealm std.Realm + targets []int + pcts []uint64 + verify func() + }{ + { + name: "panic if caller is not admin", + shouldPanic: true, + panicMsg: `caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission`, + targets: []int{1, 2, 3, 4}, + pcts: []uint64{1000, 2000, 3000, 4000}, + }, + { + name: "panic if target is invalid", + shouldPanic: true, + panicMsg: "[GNOSWAP-EMISSION-002] invalid emission target || invalid target(9)", + callerRealm: adminRealm, + targets: []int{1, 2, 3, 9}, + pcts: []uint64{1000, 2000, 3000, 4000}, + }, + { + name: "panic if sum of percentages is not 100%", + shouldPanic: true, + panicMsg: "[GNOSWAP-EMISSION-003] invalid emission percentage || sum of all pct should be 100% (10000 bps), got 10001", + callerRealm: adminRealm, + targets: []int{1, 2, 3, 4}, + pcts: []uint64{1000, 2000, 3000, 4001}, + }, + { + name: "success if admin", + shouldPanic: false, + callerRealm: adminRealm, + targets: []int{1, 2, 3, 4}, + pcts: []uint64{1000, 2000, 3000, 4000}, + verify: func() { + uassert.Equal(t, uint64(1000), GetDistributionBpsPct(int(1))) + uassert.Equal(t, uint64(2000), GetDistributionBpsPct(int(2))) + uassert.Equal(t, uint64(3000), GetDistributionBpsPct(int(3))) + uassert.Equal(t, uint64(4000), GetDistributionBpsPct(int(4))) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.callerRealm != (std.Realm{}) { + std.TestSetRealm(tt.callerRealm) + } + + target01, target02, target03, target04 := sliceToFourInt(t, tt.targets) + pct01, pct02, pct03, pct04 := sliceToFourUint64(t, tt.pcts) + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + ChangeDistributionPctByAdmin( + target01, pct01, + target02, pct02, + target03, pct03, + target04, pct04, + ) + }) + } else { + uassert.NotPanics(t, func() { + ChangeDistributionPctByAdmin( + target01, pct01, + target02, pct02, + target03, pct03, + target04, pct04, + ) + }) + tt.verify() + } + }) + } + callbackStakerEmissionChange = originCallback +} + +func TestChangeDistributionPct(t *testing.T) { + resetObject(t) + + originCallback := callbackStakerEmissionChange + callbackStakerEmissionChange = func(amount uint64) {} + + tests := []struct { + name string + shouldPanic bool + panicMsg string + setup func() + callerRealm std.Realm + targets []int + pcts []uint64 + verify func() + }{ + { + name: "panic if caller is not governance", + shouldPanic: true, + panicMsg: `caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission`, + targets: []int{1, 2, 3, 4}, + pcts: []uint64{1000, 2000, 3000, 4000}, + }, + { + name: "panic if target is invalid", + shouldPanic: true, + panicMsg: "[GNOSWAP-EMISSION-002] invalid emission target || invalid target(9)", + callerRealm: govRealm, + targets: []int{1, 2, 3, 9}, + pcts: []uint64{1000, 2000, 3000, 4000}, + }, + { + name: "panic if sum of percentages is not 100%", + shouldPanic: true, + panicMsg: "[GNOSWAP-EMISSION-003] invalid emission percentage || sum of all pct should be 100% (10000 bps), got 10001", + callerRealm: govRealm, + targets: []int{1, 2, 3, 4}, + pcts: []uint64{1000, 2000, 3000, 4001}, + }, + { + name: "success if governance", + shouldPanic: false, + callerRealm: govRealm, + targets: []int{1, 2, 3, 4}, + pcts: []uint64{1000, 2000, 3000, 4000}, + verify: func() { + uassert.Equal(t, uint64(1000), GetDistributionBpsPct(1)) + uassert.Equal(t, uint64(2000), GetDistributionBpsPct(2)) + uassert.Equal(t, uint64(3000), GetDistributionBpsPct(3)) + uassert.Equal(t, uint64(4000), GetDistributionBpsPct(4)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.callerRealm != (std.Realm{}) { + std.TestSetRealm(tt.callerRealm) + } + + target01, target02, target03, target04 := sliceToFourInt(t, tt.targets) + pct01, pct02, pct03, pct04 := sliceToFourUint64(t, tt.pcts) + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + ChangeDistributionPct( + target01, pct01, + target02, pct02, + target03, pct03, + target04, pct04, + ) + }) + } else { + uassert.NotPanics(t, func() { + ChangeDistributionPct( + target01, pct01, + target02, pct02, + target03, pct03, + target04, pct04, + ) + }) + tt.verify() + } + }) + } + callbackStakerEmissionChange = originCallback +} + +func TestChangeDistributionPcts(t *testing.T) { + resetObject(t) + + originCallback := callbackStakerEmissionChange + callbackStakerEmissionChange = func(amount uint64) {} + + changeDistributionPcts( + 1, 1000, + 2, 2000, + 3, 3000, + 4, 4000, + ) + uassert.Equal(t, uint64(1000), GetDistributionBpsPct(1)) + uassert.Equal(t, uint64(2000), GetDistributionBpsPct(2)) + uassert.Equal(t, uint64(3000), GetDistributionBpsPct(3)) + uassert.Equal(t, uint64(4000), GetDistributionBpsPct(4)) + + callbackStakerEmissionChange = originCallback +} + +func TestCalculateAmount(t *testing.T) { + tests := []struct { + name string + pct uint64 + expected uint64 + }{ + {name: "5% of 1_000", pct: 500, expected: 50}, + {name: "10% of 1_000", pct: 1000, expected: 100}, + {name: "55% of 1_000", pct: 5500, expected: 550}, + {name: "100% of 1_000", pct: 10000, expected: 1000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uassert.Equal(t, tt.expected, calculateAmount(uint64(1000), tt.pct)) + }) + } +} + +func TestTransferToTarget(t *testing.T) { + resetObject(t) + + tests := []struct { + name string + shouldPanic bool + panicMsg string + setup func() + target int + amount uint64 + verify func() + }{ + { + name: "invalid target", + shouldPanic: true, + panicMsg: "[GNOSWAP-EMISSION-002] invalid emission target || invalid target(9)", + target: 9, + amount: 100, + }, + { + name: "not enough balance for emission", + shouldPanic: true, + panicMsg: "insufficient balance", + target: LIQUIDITY_STAKER, + amount: 1, + }, + { + name: "transfer to LIQUIDITY_STAKER", + target: LIQUIDITY_STAKER, + setup: func() { + std.TestSetRealm(adminRealm) + gns.Transfer(consts.EMISSION_ADDR, 100000) // give enough balance for emission + }, + amount: 100, + verify: func() { + uassert.Equal(t, uint64(100), distributedToStaker) + uassert.Equal(t, uint64(100), accuDistributedToStaker) + }, + }, + { + name: "transfer to DEVOPS", + target: DEVOPS, + amount: 200, + verify: func() { + uassert.Equal(t, uint64(200), distributedToDevOps) + uassert.Equal(t, uint64(200), accuDistributedToDevOps) + }, + }, + { + name: "transfer to COMMUNITY_POOL", + target: COMMUNITY_POOL, + amount: 300, + verify: func() { + uassert.Equal(t, uint64(300), distributedToCommunityPool) + uassert.Equal(t, uint64(300), accuDistributedToCommunityPool) + }, + }, + { + name: "transfer to GOV_STAKER", + target: GOV_STAKER, + amount: 400, + verify: func() { + uassert.Equal(t, uint64(400), distributedToGovStaker) + uassert.Equal(t, uint64(400), accuDistributedToGovStaker) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + transferToTarget(tt.target, tt.amount) + }) + } else { + uassert.NotPanics(t, func() { + transferToTarget(tt.target, tt.amount) + }) + tt.verify() + } + }) + } +} + +func TestDistributeToTarget(t *testing.T) { + resetObject(t) + + tests := []struct { + name string + shouldPanic bool + panicMsg string + setup func() + amount uint64 + expectedLeft uint64 + }{ + { + name: "invalid target", + shouldPanic: true, + panicMsg: "[GNOSWAP-EMISSION-002] invalid emission target || invalid target(a)", + setup: func() { + distributionBpsPct.Set("a", uint64(1000)) + }, + amount: 100, + }, + { + name: "distributed all amount", + setup: func() { + distributionBpsPct.Remove("a") + }, + amount: 1000, + expectedLeft: 0, + }, + { + name: "distributed partial amount", + amount: 1001, + expectedLeft: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + distributeToTarget(tt.amount) + }) + } else { + distributed := distributeToTarget(tt.amount) + left := tt.amount - distributed + uassert.Equal(t, tt.expectedLeft, left) + } + }) + } +} + +func TestClearDistributedToStaker(t *testing.T) { + distributedToStaker = 100 + + tests := []struct { + name string + expected uint64 + callerRealm std.Realm + shouldPanic bool + panicMsg string + }{ + { + name: "can not clear is caller is not staker", + shouldPanic: true, + panicMsg: `caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission`, + }, + { + name: "can clear if caller is staker", + callerRealm: stakerRealm, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.callerRealm != (std.Realm{}) { + std.TestSetRealm(tt.callerRealm) + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + ClearDistributedToStaker() + }) + } else { + ClearDistributedToStaker() + uassert.Equal(t, uint64(0), distributedToStaker) + } + }) + } + +} + +func TestClearClearDistributedToGovStaker(t *testing.T) { + distributedToGovStaker = 100 + + tests := []struct { + name string + expected uint64 + callerRealm std.Realm + shouldPanic bool + panicMsg string + }{ + { + name: "can not clear is caller is not gov/staker", + shouldPanic: true, + panicMsg: `caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission`, + }, + { + name: "can clear if caller is gov/taker", + callerRealm: govStakerRealm, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.callerRealm != (std.Realm{}) { + std.TestSetRealm(tt.callerRealm) + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + ClearDistributedToGovStaker() + }) + } else { + ClearDistributedToGovStaker() + uassert.Equal(t, uint64(0), distributedToGovStaker) + } + }) + } + +} + +func sliceToFourInt(t *testing.T, slice []int) (int, int, int, int) { + t.Helper() + + return slice[0], slice[1], slice[2], slice[3] +} + +func sliceToFourUint64(t *testing.T, slice []uint64) (uint64, uint64, uint64, uint64) { + t.Helper() + + return slice[0], slice[1], slice[2], slice[3] +} diff --git a/contract/r/gnoswap/emission/emission.gno b/contract/r/gnoswap/emission/emission.gno new file mode 100644 index 000000000..29f19e9da --- /dev/null +++ b/contract/r/gnoswap/emission/emission.gno @@ -0,0 +1,88 @@ +package emission + +import ( + "std" + "time" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +var ( + // leftGNSAmount tracks undistributed GNS tokens from previous distributions + leftGNSAmount uint64 + // lastExecutedHeight stores the last block height when distribution was executed + lastExecutedHeight int64 +) + +// MintAndDistributeGns mints new GNS tokens and distributes them to targets +// Returns the total amount of GNS distributed +func MintAndDistributeGns() uint64 { + assertOnlyNotHalted() + + currentHeight := std.GetHeight() + lastMintedHeight := gns.GetLastMintedHeight() + if lastMintedHeight >= currentHeight { + // Skip if we've already minted tokens at this height + return 0 + } + // Mint new tokens and add any leftover amounts from previous distribution + mintedEmissionRewardAmount := gns.MintGns(a2u(consts.EMISSION_ADDR)) + + distributableAmount := mintedEmissionRewardAmount + prevLeftAmount := GetLeftGNSAmount() + if hasLeftGNSAmount() { + distributableAmount += prevLeftAmount + setLeftGNSAmount(0) + } + // Distribute tokens and track any undistributed amount + distributedGNSAmount := distributeToTarget(distributableAmount) + if distributableAmount != distributedGNSAmount { + setLeftGNSAmount(distributableAmount - distributedGNSAmount) + } + + // Emit event with distribution details + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "MintAndDistributeGns", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "lastHeight", formatInt(lastExecutedHeight), + "currentHeight", formatInt(currentHeight), + "currentTimestamp", formatInt(time.Now().Unix()), + "mintedAmount", formatUint(mintedEmissionRewardAmount), + "prevLeftAmount", formatUint(prevLeftAmount), + "distributedAmount", formatUint(distributedGNSAmount), + "currentLeftAmount", formatUint(GetLeftGNSAmount()), + "gnsTotalSupply", formatUint(gns.TotalSupply()), + ) + + setLastExecutedHeight(currentHeight) + + return distributedGNSAmount +} + +// GetLeftGNSAmount returns the amount of undistributed GNS tokens +func GetLeftGNSAmount() uint64 { + return leftGNSAmount +} + +// setLeftGNSAmount updates the undistributed GNS token amount +func setLeftGNSAmount(amount uint64) { + leftGNSAmount = amount +} + +// GetLastExecutedHeight returns the last block height when distribution was executed +func GetLastExecutedHeight() int64 { + return lastExecutedHeight +} + +// setLastExecutedHeight updates the last executed block height +func setLastExecutedHeight(height int64) { + lastExecutedHeight = height +} + +// hasLeftGNSAmount checks if there are any undistributed GNS tokens +func hasLeftGNSAmount() bool { + return leftGNSAmount > 0 +} diff --git a/contract/r/gnoswap/emission/emission_test.gno b/contract/r/gnoswap/emission/emission_test.gno new file mode 100644 index 000000000..f18fb6644 --- /dev/null +++ b/contract/r/gnoswap/emission/emission_test.gno @@ -0,0 +1,59 @@ +package emission + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" +) + +func TestMintAndDistributeGns(t *testing.T) { + tests := []struct { + name string + setup func() + shouldPanic bool + panicMsg string + verify func(distributed uint64) + }{ + { + name: "no block passed", + verify: func(distributed uint64) { + uassert.Equal(t, uint64(0), distributed) + }, + }, + { + name: "block passed", + setup: func() { + std.TestSkipHeights(123) + }, + verify: func(distributed uint64) { + uassert.True(t, distributed > 0) + }, + }, + { + name: "same block", + verify: func(distributed uint64) { + uassert.Equal(t, uint64(0), distributed) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + MintAndDistributeGns() + }) + } else { + distributed := MintAndDistributeGns() + if tt.verify != nil { + tt.verify(distributed) + } + } + }) + } +} diff --git a/contract/r/gnoswap/emission/errors.gno b/contract/r/gnoswap/emission/errors.gno new file mode 100644 index 000000000..c657d0916 --- /dev/null +++ b/contract/r/gnoswap/emission/errors.gno @@ -0,0 +1,18 @@ +package emission + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errCallbackIsNil = errors.New("[GNOSWAP-EMISSION-001] callback func is nil") + errInvalidEmissionTarget = errors.New("[GNOSWAP-EMISSION-002] invalid emission target") + errInvalidEmissionPct = errors.New("[GNOSWAP-EMISSION-003] invalid emission percentage") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/emission/gno.mod b/contract/r/gnoswap/emission/gno.mod new file mode 100644 index 000000000..fab19611b --- /dev/null +++ b/contract/r/gnoswap/emission/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/emission diff --git a/contract/r/gnoswap/emission/utils.gno b/contract/r/gnoswap/emission/utils.gno new file mode 100644 index 000000000..b69946535 --- /dev/null +++ b/contract/r/gnoswap/emission/utils.gno @@ -0,0 +1,113 @@ +package emission + +import ( + "std" + "strconv" + + "gno.land/p/demo/ufmt" + + "gno.land/r/gnoswap/v1/common" +) + +const MAX_BPS_PCT uint64 = 10000 + +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrev returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { + prev := std.PrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// assertOnlyAdmin panics if the caller is not the admin. +func assertOnlyAdmin() { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err.Error()) + } +} + +// assertStakerOnly panics if the caller is not the staker. +func assertStakerOnly() { + caller := getPrevAddr() + if err := common.StakerOnly(caller); err != nil { + panic(err.Error()) + } +} + +// assertOnlyGovernance panics if the caller is not the governance. +func assertOnlyGovernance() { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err.Error()) + } +} + +// assertOnlyGovStaker panics if the caller is not the gov staker. +func assertOnlyGovStaker() { + caller := getPrevAddr() + if err := common.GovStakerOnly(caller); err != nil { + panic(err.Error()) + } +} + +// assertDistributionTarget panics if the target is invalid. +func assertDistributionTarget(target int) { + if target != LIQUIDITY_STAKER && target != DEVOPS && target != COMMUNITY_POOL && target != GOV_STAKER { + panic(addDetailToError( + errInvalidEmissionTarget, + ufmt.Sprintf("invalid target(%d)", target), + )) + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertSumDistributionPct ensures the sum of all distribution percentages is 100% +func assertSumDistributionPct(pct01, pct02, pct03, pct04 uint64) { + sum := pct01 + pct02 + pct03 + pct04 + + if sum != MAX_BPS_PCT { + panic(addDetailToError( + errInvalidEmissionPct, + ufmt.Sprintf("sum of all pct should be 100%% (10000 bps), got %d", sum), + )) + } +} + +func formatUint(v interface{}) string { + switch v := v.(type) { + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatInt(v interface{}) string { + switch v := v.(type) { + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case int: + return strconv.Itoa(v) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} diff --git a/contract/r/gnoswap/gnft/errors.gno b/contract/r/gnoswap/gnft/errors.gno new file mode 100644 index 000000000..6565d9c65 --- /dev/null +++ b/contract/r/gnoswap/gnft/errors.gno @@ -0,0 +1,19 @@ +package gnft + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-GNFT-001] caller has no permission") + errCannotSetURI = errors.New("[GNOSWAP-GNFT-002] cannot set URI") + errNoTokenForCaller = errors.New("[GNOSWAP-GNFT-003] no token for caller") + errInvalidAddress = errors.New("[GNOSWAP-GNFT-004] invalid addresss") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/gnft/gnft.gno b/contract/r/gnoswap/gnft/gnft.gno new file mode 100644 index 000000000..31d515cb7 --- /dev/null +++ b/contract/r/gnoswap/gnft/gnft.gno @@ -0,0 +1,507 @@ +package gnft + +import ( + "math/rand" + "std" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +var ( + gnft = grc721.NewBasicNFT("GNOSWAP NFT", "GNFT") +) + +var ( + owner *ownable.Ownable + tokenList *avl.Tree // addr -> []grc721.TokenID +) + +func init() { + owner = ownable.NewWithAddress(consts.POSITION_ADDR) // deployed position contract + tokenList = avl.NewTree() +} + +// Name returns the full name of the NFT +// Returns: +// - string: The name of the NFT collection +func Name() string { + return gnft.Name() +} + +// Symbol returns the token symbol of the NFT +// Returns: +// - string: The symbol of the NFT collection +func Symbol() string { + return gnft.Symbol() +} + +// TotalSupply returns the total number of NFTs minted +// Returns: +// - uint64: The total number of tokens that have been minted +func TotalSupply() uint64 { + return gnft.TokenCount() +} + +// TokenURI retrieves the metadata URI for a specific token ID +// Parameters: +// - tid: The unique identifier of the token +// +// Returns: +// - string: The metadata URI associated with the token +func TokenURI(tid grc721.TokenID) string { + uri, err := gnft.TokenURI(tid) + if err != nil { + panic(err.Error()) + } + + return string(uri) +} + +// BalanceOf returns the number of NFTs owned by the specified address. +// Parameters: +// - owner (std.Address): The address to check the NFT balance for. +// +// Returns: +// - uint64: The number of NFTs owned by the address. +// - error: Returns an error if the balance retrieval fails. +func BalanceOf(owner std.Address) (uint64, error) { + balance, err := gnft.BalanceOf(owner) + if err != nil { + panic(err.Error()) + } + return balance, nil +} + +// OwnerOf returns the current owner's address of a specific token ID +// Parameters: +// - tid: The token ID to check ownership of +// +// Returns: +// - std.Address: The address of the token owner +func OwnerOf(tid grc721.TokenID) (std.Address, error) { + ownerAddr, err := gnft.OwnerOf(tid) + if err != nil { + return "", err + } + + return ownerAddr, nil +} + +// MustOwnerOf returns the current owner's address of a specific token ID +// Parameters: +// - tid: The token ID to check ownership of +// +// Returns: +// - std.Address: The address of the token owner +// +// Panics: +// - If the token ID is invalid +func MustOwnerOf(tid grc721.TokenID) std.Address { + ownerAddr, err := OwnerOf(tid) + if err != nil { + panic(err.Error()) + } + + return ownerAddr +} + +// SetTokenURI sets the metadata URI using a randomly generated SVG image +// Parameters: +// - tid (grc721.TokenID): The token ID for which the URI will be updated. +// - tURI (grc721.TokenURI): The new metadata URI to associate with the token. +// +// Returns: +// - bool: Returns `true` if the operation is successful. +// - error: Returns an error if the operation fails or the caller is not authorized. +// +// Panics: +// - If the caller is not the token owner, the function panics. +// - If the URI update fails, the function panics with the associated error. +func SetTokenURI(tid grc721.TokenID, tURI grc721.TokenURI) (bool, error) { + assertOnlyNotHalted() + assertCallerIsOwnerOfToken(tid) + + err := setTokenURI(tid, tURI) + if err != nil { + panic(addDetailToError( + errCannotSetURI, + ufmt.Sprintf("token id (%s)", tid), + )) + } + return true, nil +} + +// SafeTransferFrom securely transfers ownership of a token from one address to another. +// +// This function enforces several checks to ensure the transfer is valid and authorized: +// - Ensures the contract is not halted. +// - Validates the addresses involved in the transfer. +// - Checks that the caller is the token owner or has been approved to transfer the token. +// +// After validation, the function updates the internal token lists by removing the token from the sender's list +// and appending it to the recipient's list. It then calls the underlying transfer logic through `gnft.TransferFrom`. +// +// Parameters: +// - from (std.Address): The current owner's address of the token being transferred. +// - to (std.Address): The recipient's address to receive the token. +// - tid (grc721.TokenID): The ID of the token to be transferred. +// +// Returns: +// - error: Returns `nil` if the transfer is successful; otherwise, it raises an error. +// +// Panics: +// - If the contract is halted. +// - If either `from` or `to` addresses are invalid. +// - If the caller is not the owner or approved operator of the token. +// - If the internal transfer (`gnft.TransferFrom`) fails. +func SafeTransferFrom(from, to std.Address, tid grc721.TokenID) error { + assertOnlyNotHalted() + + assertValidAddr(from) + assertValidAddr(to) + + caller := getPrevAddr() + ownerAddr, _ := OwnerOf(tid) + approved, _ := GetApproved(tid) + if (caller != ownerAddr) && (caller != approved) { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("caller (%s) is not the owner or operator of token (%s)", caller, string(tid)), + )) + } + + removeTokenList(from, tid) + appendTokenList(to, tid) + + checkErr(gnft.TransferFrom(from, to, tid)) + return nil +} + +// TransferFrom transfers a token from one address to another +// This function is a direct wrapper around `SafeTransferFrom`, which performs the actual transfer. +// +// Parameters: +// - from (std.Address): The current owner's address of the token being transferred. +// - to (std.Address): The recipient's address to receive the token. +// - tid (grc721.TokenID): The ID of the token to be transferred. +// +// Returns: +// - error: Returns `nil` if the transfer is successful; otherwise, returns an error. +func TransferFrom(from, to std.Address, tid grc721.TokenID) error { + return SafeTransferFrom(from, to, tid) +} + +// Approve grants permission to transfer a specific token ID to another address. +// +// Parameters: +// - approved (std.Address): The address to grant transfer approval to. +// - tid (grc721.TokenID): The token ID to approve for transfer. +// +// Returns: +// - error: Returns `nil` if the approval is successful, otherwise returns an error. +// +// Panics: +// - If the contract is halted. +// - If the caller is not the token owner. +// - If the `Approve` call fails. +func Approve(approved std.Address, tid grc721.TokenID) error { + assertOnlyNotHalted() + assertCallerIsOwnerOfToken(tid) + + err := gnft.Approve(approved, tid) + if err != nil { + panic(err.Error()) + } + return nil +} + +// SetApprovalForAll enables or disables approval for a third party (`operator`) to manage all tokens owned by the caller. +// +// Parameters: +// - operator (std.Address): The address to grant or revoke operator permissions for. +// - approved (bool): `true` to enable approval, `false` to revoke approval. +// +// Returns: +// - error: Returns `nil` if the operation is successful, otherwise returns an error. +// +// Panics: +// - If the contract is halted. +// - If the `SetApprovalForAll` operation fails. +func SetApprovalForAll(operator std.Address, approved bool) error { + assertOnlyNotHalted() + checkErr(gnft.SetApprovalForAll(operator, approved)) + return nil +} + +// GetApproved returns the approved address for a specific token ID. +// +// Parameters: +// - tid (grc721.TokenID): The token ID to check for approval. +// +// Returns: +// - std.Address: The address approved to manage the token. Returns an empty address if no approval exists. +// - error: Returns an error if the lookup fails or the token ID is invalid. +func GetApproved(tid grc721.TokenID) (std.Address, error) { + addr, err := gnft.GetApproved(tid) + if err != nil { + return "", err + } + + return addr, nil +} + +// IsApprovedForAll checks if an operator is approved to manage all tokens of an owner. +// +// Parameters: +// - owner (std.Address): The address of the token owner. +// - operator (std.Address): The address to check if it has approval to manage the owner's tokens. +// +// Returns: +// - bool: true if the operator is approved to manage all tokens of the owner, false otherwise. +func IsApprovedForAll(owner, operator std.Address) bool { + return gnft.IsApprovedForAll(owner, operator) +} + +// SetTokenURIByImageURI generates and sets a new token URI for a specified token ID using a random image URI. +// +// Parameters: +// - tid (grc721.TokenID): The ID of the token for which the URI will be set. +// +// Panics: +// - If the contract is halted. +// - If the caller is not the owner of the token. +// - If the token URI cannot be set. +func SetTokenURIByImageURI(tid grc721.TokenID) { + assertOnlyNotHalted() + assertCallerIsOwnerOfToken(tid) + + tokenURI := genImageURI(generateRandInstance()) + + err := setTokenURI(tid, grc721.TokenURI(tokenURI)) + if err != nil { + panic(addDetailToError( + errCannotSetURI, + ufmt.Sprintf("%s (%s)", err.Error(), string(tid)), + )) + } +} + +// SetTokenURILast sets the token URI for the last token owned by the caller using a randomly generated image URI. +// +// This function ensures the contract is active and the caller owns at least one token. +// It retrieves the list of tokens owned by the caller and applies a new token URI to the most recently minted token. +// +// Panics: +// - If the contract is halted. +// - If the caller does not own any tokens (empty token list). +// - If URI generation or assignment fails. +func SetTokenURILast() { + assertOnlyNotHalted() + + caller := getPrevAddr() + tokenListByCaller, _ := getTokenList(caller) + lenTokenListByCaller := len(tokenListByCaller) + if lenTokenListByCaller == 0 { + panic(addDetailToError( + errNoTokenForCaller, + ufmt.Sprintf("caller (%s)", caller), + )) + } + + lastTokenId := tokenListByCaller[lenTokenListByCaller-1] + SetTokenURIByImageURI(lastTokenId) +} + +// Mint creates a new NFT and assigns it to the specified address (only callable by owner) +// Parameters: +// - to: The address or username to mint the token to +// - tid: The token ID to assign to the new NFT +// +// Returns: +// - grc721.TokenID: The ID of the newly minted token +func Mint(to std.Address, tid grc721.TokenID) grc721.TokenID { + owner.AssertCallerIsOwner() + assertOnlyNotHalted() + + checkErr(gnft.Mint(to, tid)) + + appendTokenList(to, tid) + return tid +} + +// Burn removes a specific token ID (only callable by owner) +// Parameters: +// - tid: The token ID to burn +func Burn(tid grc721.TokenID) { + owner.AssertCallerIsOwner() + assertOnlyNotHalted() + + ownerAddr, err := OwnerOf(tid) + if err != nil { + panic(err.Error()) + } + removeTokenList(ownerAddr, tid) + + checkErr(gnft.Burn(tid)) +} + +// Render returns the HTML representation of the NFT +// Parameters: +// - path: The path to render +// +// Returns: +// - string: HTML representation of the NFT or 404 if path is invalid +func Render(path string) string { + switch { + case path == "": + return gnft.RenderHome() + default: + return "404\n" + } +} + +// setTokenURI sets the metadata URI for a specific token ID +func setTokenURI(tid grc721.TokenID, tURI grc721.TokenURI) error { + assertOnlyEmptyTokenURI(tid) + _, err := gnft.SetTokenURI(tid, tURI) + if err != nil { + return err + } + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "SetTokenURI", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "lpTokenId", string(tid), + "tokenURI", string(tURI), + ) + + return nil +} + +// generateRandInstnace generates a new random instance +// Returns: +// - *rand.Rand: A new random instance +func generateRandInstance() *rand.Rand { + seed1 := uint64(time.Now().Unix()) + TotalSupply() + seed2 := uint64(time.Now().UnixNano()) + TotalSupply() + pcg := rand.NewPCG(seed1, seed2) + return rand.New(pcg) +} + +// getTokenList retrieves the list of nft tokens for an address +// Parameters: +// - addr: The address to check for nft tokens +// +// Returns: +// - []grc721.TokenID: Array of token IDs +// - bool: true if tokens exist for the address, false otherwise +func getTokenList(addr std.Address) ([]grc721.TokenID, bool) { + iTokens, exists := tokenList.Get(addr.String()) + if !exists { + return []grc721.TokenID{}, false + } + + return iTokens.([]grc721.TokenID), true +} + +// mustGetTokenList same as getTokenList but panics if tokens don't exist +// Parameters: +// - addr: The address to check for nft tokens +// +// Returns: +// - []grc721.TokenID: Array of token IDs +func mustGetTokenList(addr std.Address) []grc721.TokenID { + tokens, exists := getTokenList(addr) + if !exists { + panic(ufmt.Sprintf("user %s has no minted nft tokens", addr.String())) + } + + return tokens +} + +// appendTokenList adds a token ID to the list of nft tokens +// Parameters: +// - addr: The address to append the token for +// - tid: The token ID to append +func appendTokenList(addr std.Address, tid grc721.TokenID) { + prevTokenList, _ := getTokenList(addr) + prevTokenList = append(prevTokenList, tid) + tokenList.Set(addr.String(), prevTokenList) +} + +// removeTokenList removes a token ID from the list of nft tokens +// Parameters: +// - addr: The address to remove the token for +// - tid: The token ID to remove +func removeTokenList(addr std.Address, tid grc721.TokenID) { + prevTokenList, exist := getTokenList(addr) + if !exist { + return + } + + for i, token := range prevTokenList { + if token == tid { + prevTokenList = append(prevTokenList[:i], prevTokenList[i+1:]...) + break + } + } + + tokenList.Set(addr.String(), prevTokenList) +} + +// checkErr helper function to panic if an error occurs +// Parameters: +// - err: The error to check +func checkErr(err error) { + if err != nil { + panic(err.Error()) + } +} + +// assertCallerIsOwnerOfToken asserts that the caller is the owner of the token +// Parameters: +// - tid: The token ID to check ownership of +func assertCallerIsOwnerOfToken(tid grc721.TokenID) { + caller := getPrevAddr() + owner, _ := OwnerOf(tid) + if caller != owner { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("caller (%s) is not the owner of token (%s)", caller, string(tid)), + )) + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertValidAddr panics if the address is invalid.s +func assertValidAddr(addr std.Address) { + if !addr.IsValid() { + panic(addDetailToError( + errInvalidAddress, + addr.String(), + )) + } +} + +// assertOnlyEmptyTokenURI panics if the token URI is not empty. +func assertOnlyEmptyTokenURI(tid grc721.TokenID) { + uri, _ := gnft.TokenURI(tid) + if string(uri) != "" { + panic(addDetailToError( + errCannotSetURI, + ufmt.Sprintf("token id (%s) has already set URI", string(tid)), + )) + } +} diff --git a/contract/r/gnoswap/gnft/gnft_test.gno b/contract/r/gnoswap/gnft/gnft_test.gno new file mode 100644 index 000000000..f754ed2fe --- /dev/null +++ b/contract/r/gnoswap/gnft/gnft_test.gno @@ -0,0 +1,529 @@ +package gnft + +import ( + "std" + "testing" + + "gno.land/p/demo/avl" + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" +) + +const ( + errInvalidTokenId = "invalid token id" +) + +var ( + positionAddr = consts.POSITION_ADDR + positionRealm = std.NewCodeRealm(consts.POSITION_PATH) + + addr01 = testutils.TestAddress("addr01") + addr01Realm = std.NewUserRealm(addr01) + + addr02 = testutils.TestAddress("addr02") + addr02Realm = std.NewUserRealm(addr02) +) + +func TestMetadata(t *testing.T) { + tests := []struct { + name string + fn func() string + expected string + }{ + {"Name()", Name, "GNOSWAP NFT"}, + {"Symbol()", Symbol, "GNFT"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uassert.Equal(t, tt.expected, tt.fn()) + }) + } +} + +func TestTotalSupply(t *testing.T) { + tests := []struct { + name string + setup func() + expected uint64 + }{ + { + name: "initial total supply", + expected: uint64(0), + }, + { + name: "total supply after minting", + setup: func() { + std.TestSetRealm(positionRealm) + Mint(addr01, tid(1)) + Mint(addr01, tid(2)) + }, + expected: uint64(2), + }, + { + name: "total supply after burning", + setup: func() { + std.TestSetRealm(positionRealm) + Burn(tid(2)) + }, + expected: uint64(1), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + uassert.Equal(t, tt.expected, TotalSupply()) + }) + } +} + +func TestBalanceOf(t *testing.T) { + tests := []struct { + name string + addr std.Address + expected uint64 + }{ + {"BalanceOf(addr01)", addr01, uint64(1)}, + {"BalanceOf(addr02)", addr02, uint64(0)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + balance, _ := BalanceOf(tt.addr) + uassert.Equal(t, tt.expected, balance) + }) + } +} + +func TestOwnerOf(t *testing.T) { + tests := []struct { + name string + tokenId uint64 + shouldPanic bool + panicMsg string + expected std.Address + }{ + {"OwnerOf(1)", 1, false, "", addr01}, + {"OwnerOf(500)", 500, false, errInvalidTokenId, addr01}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + OwnerOf(tid(tt.tokenId)) + }) + } else { + ownerAddr, err := OwnerOf(tid(tt.tokenId)) + if err != nil { + uassert.Equal(t, tt.panicMsg, err.Error()) + } else { + uassert.Equal(t, tt.expected, ownerAddr) + } + } + }) + } +} + +func TestIsApprovedForAll(t *testing.T) { + tests := []struct { + name string + setup func() + expected bool + }{ + { + name: "IsApprovedForAll(addr01, addr02)", + expected: false, + }, + { + name: "IsApprovedForAll(addr01, addr02) after setting approval", + setup: func() { + std.TestSetRealm(addr01Realm) + SetApprovalForAll((addr02), true) + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + uassert.Equal(t, tt.expected, IsApprovedForAll((addr01), (addr02))) + }) + } +} + +func TestGetApproved(t *testing.T) { + tests := []struct { + name string + setup func() + expectedAddr std.Address + }{ + { + name: "GetApproved(1)", + expectedAddr: std.Address(""), + }, + { + name: "GetApproved(1) after approving", + setup: func() { + std.TestSetRealm(addr01Realm) + Approve(addr02, tid(1)) + }, + expectedAddr: addr02, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + addr, _ := GetApproved(tid(1)) + uassert.Equal(t, tt.expectedAddr, addr) + }) + } +} + +func TestTransferFrom(t *testing.T) { + resetObject(t) + std.TestSetRealm(positionRealm) + Mint(addr01, tid(1)) + + tests := []struct { + name string + setup func() + callerRealm std.Realm + fromAddr std.Address + toAddr std.Address + tokenIdToTransfer uint64 + shouldPanic bool + panicMsg string + expected std.Address + verifyTokenList func() + }{ + { + name: "transfer non-existent token id", + callerRealm: std.NewUserRealm(addr01), + fromAddr: addr01, + toAddr: addr02, + tokenIdToTransfer: 99, + shouldPanic: true, + panicMsg: "[GNOSWAP-GNFT-001] caller has no permission || caller (g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5) is not the owner or operator of token (99)", + }, + { + name: "transfer token owned by other user without approval", + callerRealm: std.NewUserRealm(addr02), + fromAddr: addr01, + toAddr: addr02, + tokenIdToTransfer: 1, + shouldPanic: true, + panicMsg: "[GNOSWAP-GNFT-001] caller has no permission || caller (g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5) is not the owner or operator of token (1)", + }, + { + name: "transfer token owned by other user with approval", + setup: func() { + std.TestSetRealm(addr01Realm) + Approve((addr02), tid(1)) + }, + callerRealm: std.NewUserRealm(addr02), + fromAddr: addr01, + toAddr: addr02, + tokenIdToTransfer: 1, + verifyTokenList: func() { + uassert.Equal(t, 0, len(mustGetTokenList(addr01))) + uassert.Equal(t, 1, len(mustGetTokenList(addr02))) + }, + }, + { + name: "transfer token owned by caller", + callerRealm: std.NewUserRealm(addr02), + fromAddr: addr02, + toAddr: addr01, + tokenIdToTransfer: 1, + verifyTokenList: func() { + uassert.Equal(t, 1, len(mustGetTokenList(addr01))) + uassert.Equal(t, 0, len(mustGetTokenList(addr02))) + }, + }, + { + name: "transfer from is invalid address", + callerRealm: std.NewUserRealm(addr01), + fromAddr: std.Address(""), + toAddr: addr02, + tokenIdToTransfer: 1, + shouldPanic: true, + panicMsg: "[GNOSWAP-GNFT-004] invalid addresss || ", + }, + { + name: "transfer to is invalid address", + callerRealm: std.NewUserRealm(addr01), + fromAddr: addr01, + toAddr: std.Address("this_is_invalid_address"), + tokenIdToTransfer: 1, + shouldPanic: true, + panicMsg: "[GNOSWAP-GNFT-004] invalid addresss || this_is_invalid_address", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + TransferFrom((tt.fromAddr), (tt.toAddr), tid(tt.tokenIdToTransfer)) + }) + } else { + std.TestSetRealm(tt.callerRealm) + TransferFrom((tt.fromAddr), (tt.toAddr), tid(tt.tokenIdToTransfer)) + tt.verifyTokenList() + } + }) + } +} + +func TestMint(t *testing.T) { + resetObject(t) + + tests := []struct { + name string + callerRealm std.Realm + tokenIdToMint uint64 + addressToMint std.Address + shouldPanic bool + panicMsg string + expected string + verifyTokenList func() + }{ + { + name: "mint without permission", + shouldPanic: true, + panicMsg: "ownable: caller is not owner", + }, + { + name: "mint first nft to addr01", + callerRealm: std.NewCodeRealm(consts.POSITION_PATH), + tokenIdToMint: 1, + addressToMint: addr01, + expected: "1", + verifyTokenList: func() { + uassert.Equal(t, 1, len(mustGetTokenList(addr01))) + }, + }, + { + name: "mint second nft to addr02", + callerRealm: std.NewCodeRealm(consts.POSITION_PATH), + tokenIdToMint: 2, + addressToMint: addr02, + expected: "2", + verifyTokenList: func() { + uassert.Equal(t, 1, len(mustGetTokenList(addr02))) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + Mint((tt.addressToMint), tid(tt.tokenIdToMint)) + }) + + } else { + std.TestSetRealm(tt.callerRealm) + mintedTokenId := Mint((tt.addressToMint), tid(tt.tokenIdToMint)) + uassert.Equal(t, tt.expected, string(mintedTokenId)) + tt.verifyTokenList() + } + }) + } +} + +func TestBurn(t *testing.T) { + tests := []struct { + name string + callerRealm std.Realm + tokenIdToBurn uint64 + shouldPanic bool + panicMsg string + verifyTokenList func() + }{ + { + name: "burn without permission", + tokenIdToBurn: 1, + shouldPanic: true, + panicMsg: "ownable: caller is not owner", + }, + { + name: "burn non-existent token id", + callerRealm: std.NewCodeRealm(consts.POSITION_PATH), + tokenIdToBurn: 99, + shouldPanic: true, + panicMsg: errInvalidTokenId, + }, + { + name: "burn token id(2)", + callerRealm: std.NewCodeRealm(consts.POSITION_PATH), + tokenIdToBurn: 2, + shouldPanic: false, + verifyTokenList: func() { + uassert.Equal(t, 0, len(mustGetTokenList(addr02))) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetRealm(tt.callerRealm) + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + Burn(tid(tt.tokenIdToBurn)) + }) + } else { + uassert.NotPanics(t, func() { + Burn(tid(tt.tokenIdToBurn)) + }) + tt.verifyTokenList() + } + }) + } +} + +func TestSetTokenURI(t *testing.T) { + tests := []struct { + name string + callerRealm std.Realm + tokenId uint64 + shouldPanic bool + panicMsg string + }{ + { + name: "set token uri without permission", + tokenId: 1, + shouldPanic: true, + panicMsg: `[GNOSWAP-GNFT-001] caller has no permission || caller () is not the owner of token (1)`, + }, + { + name: "set token uri of non-minted token id", + tokenId: 99, + shouldPanic: true, + panicMsg: `[GNOSWAP-GNFT-002] cannot set URI || invalid token id (99)`, + }, + { + name: "set token uri of token id(1)", + callerRealm: addr01Realm, + tokenId: 1, + }, + { + name: "set token uri of token id(1) - twice", + callerRealm: addr01Realm, + tokenId: 1, + shouldPanic: true, + panicMsg: "[GNOSWAP-GNFT-002] cannot set URI || token id (1) has already set URI", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetRealm(tt.callerRealm) + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + SetTokenURIByImageURI(tid(tt.tokenId)) + }) + } else { + uassert.NotPanics(t, func() { + SetTokenURIByImageURI(tid(tt.tokenId)) + }) + } + }) + } +} + +func TestTokenURI(t *testing.T) { + resetObject(t) + + tests := []struct { + name string + setup func() + tokenId uint64 + shouldPanic bool + panicMsg string + }{ + { + name: "get token uri of non-minted token id", + tokenId: 99, + shouldPanic: true, + panicMsg: errInvalidTokenId, + }, + { + name: "get token uri of minted token but not set token uri", + setup: func() { + std.TestSetRealm(positionRealm) + Mint((addr01), tid(1)) + }, + tokenId: 1, + shouldPanic: true, + panicMsg: errInvalidTokenId, + }, + { + name: "get token uri of minted token after setting token uri", + setup: func() { + std.TestSetRealm(addr01Realm) + SetTokenURIByImageURI(tid(1)) + }, + tokenId: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + TokenURI(tid(tt.tokenId)) + }) + } else { + uassert.NotEmpty(t, TokenURI(tid(tt.tokenId))) + } + }) + } +} + +func TestSetTokenURILast(t *testing.T) { + resetObject(t) + std.TestSetRealm(positionRealm) + Mint(addr01, tid(1)) + Mint(addr01, tid(2)) // last minted + + t.Run("set token uri last", func(t *testing.T) { + std.TestSetRealm(addr01Realm) + SetTokenURILast() + }) + + t.Run("token uri(2)", func(t *testing.T) { + uassert.NotEmpty(t, TokenURI(tid(2))) + }) +} + +func resetObject(t *testing.T) { + t.Helper() + + gnft = grc721.NewBasicNFT("GNOSWAP NFT", "GNFT") + tokenList = avl.NewTree() +} diff --git a/contract/r/gnoswap/gnft/gno.mod b/contract/r/gnoswap/gnft/gno.mod new file mode 100644 index 000000000..3af907ab9 --- /dev/null +++ b/contract/r/gnoswap/gnft/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/gnft diff --git a/contract/r/gnoswap/gnft/svg_generator.gno b/contract/r/gnoswap/gnft/svg_generator.gno new file mode 100644 index 000000000..49d6e3186 --- /dev/null +++ b/contract/r/gnoswap/gnft/svg_generator.gno @@ -0,0 +1,67 @@ +package gnft + +import ( + b64 "encoding/base64" + "math/rand" + "strings" + + "gno.land/p/demo/ufmt" +) + +var baseTempalte = ` + + + + + + + + + + + + + + + + + + + + + + + +` + +// range for hex color +const charset = "0123456789ABCDEF" + +func genImageURI(r *rand.Rand) string { + imageRaw := genImageRaw(r) + sEnc := b64.StdEncoding.EncodeToString([]byte(imageRaw)) + + return "data:image/svg+xml;base64," + sEnc +} + +func genImageRaw(r *rand.Rand) string { + x1 := 7 + r.Uint64N(7) + y1 := 7 + r.Uint64N(7) + + x2 := 121 + r.Uint64N(6) + y2 := 121 + r.Uint64N(6) + + var color1, color2 strings.Builder + color1.Grow(7) + color2.Grow(7) + color1.WriteByte('#') + color2.WriteByte('#') + + for i := 0; i < 6; i++ { + color1.WriteByte(charset[r.IntN(16)]) + color2.WriteByte(charset[r.IntN(16)]) + } + + randImage := ufmt.Sprintf(baseTempalte, x1, y1, x2, y2, color1.String(), color2.String()) + return randImage +} diff --git a/contract/r/gnoswap/gnft/svg_generator_test.gno b/contract/r/gnoswap/gnft/svg_generator_test.gno new file mode 100644 index 000000000..4ef09cc7d --- /dev/null +++ b/contract/r/gnoswap/gnft/svg_generator_test.gno @@ -0,0 +1,25 @@ +package gnft + +import ( + "math/rand" + "testing" + "time" + + "gno.land/p/demo/uassert" +) + +func TestGenImageURI(t *testing.T) { + seed1 := uint64(time.Now().Unix()) + seed2 := uint64(time.Now().UnixNano()) + pcg := rand.NewPCG(seed1, seed2) + r := rand.New(pcg) + + var uri string + uassert.NotPanics(t, func() { + uri = genImageURI(r) + }) + + expectedUri := `` + + uassert.Equal(t, expectedUri, uri) +} diff --git a/contract/r/gnoswap/gnft/utils.gno b/contract/r/gnoswap/gnft/utils.gno new file mode 100644 index 000000000..d30832e3a --- /dev/null +++ b/contract/r/gnoswap/gnft/utils.gno @@ -0,0 +1,46 @@ +package gnft + +import ( + "std" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/ufmt" +) + +// getPrevAsString returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { + prev := std.PrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// tid converts uint64 to grc721.TokenID. +// +// Input: +// - id: the uint64 to convert +// +// Output: +// - grc721.TokenID: the converted token ID +func tid(id uint64) grc721.TokenID { + return grc721.TokenID(ufmt.Sprintf("%d", id)) +} + +// Exists checks if a token ID exists. +// +// Input: +// - tid: the token ID to check +// +// Output: +// - bool: true if the token ID exists, false otherwise +func Exists(tid grc721.TokenID) bool { + _, err := gnft.OwnerOf(tid) + if err != nil { + return false + } + + return true +} diff --git a/contract/r/gnoswap/gns/_helper_test.gno b/contract/r/gnoswap/gns/_helper_test.gno new file mode 100644 index 000000000..f55abb47c --- /dev/null +++ b/contract/r/gnoswap/gns/_helper_test.gno @@ -0,0 +1,43 @@ +package gns + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + + "gno.land/p/gnoswap/consts" +) + +func resetObject(t *testing.T) { + t.Helper() + + resetGnsTokenObject(t) + resetHalvingRelatedObject(t) + + height := std.GetHeight() + lastMintedHeight = height + startHeight = height + startTimestamp = time.Now().Unix() +} + +func resetGnsTokenObject(t *testing.T) { + t.Helper() + + Token, privateLedger = grc20.NewToken("Gnoswap", "GNS", 6) + UserTeller = Token.CallerTeller() + owner = ownable.NewWithAddress(consts.ADMIN) + privateLedger.Mint(owner.Owner(), INITIAL_MINT_AMOUNT) +} + +func resetHalvingRelatedObject(t *testing.T) { + t.Helper() + + startHeight = std.GetHeight() + startTimestamp = time.Now().Unix() + + emissionState = GetEmissionState() + setEndTimestamp(startTimestamp + consts.TIMESTAMP_YEAR*HALVING_END_YEAR) +} diff --git a/contract/r/gnoswap/gns/errors.gno b/contract/r/gnoswap/gns/errors.gno new file mode 100644 index 000000000..de2dc57a4 --- /dev/null +++ b/contract/r/gnoswap/gns/errors.gno @@ -0,0 +1,18 @@ +package gns + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errInvalidYear = errors.New("[GNOSWAP-GNS-001] invalid year") + errTooManyEmission = errors.New("[GNOSWAP-GNS-002] too many emission reward") + errCallbackEmissionChangeIsNil = errors.New("[GNOSWAP-GNS-003] callback emission change is nil") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/gns/gno.mod b/contract/r/gnoswap/gns/gno.mod new file mode 100644 index 000000000..67209d1d3 --- /dev/null +++ b/contract/r/gnoswap/gns/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/gns diff --git a/contract/r/gnoswap/gns/gns.gno b/contract/r/gnoswap/gns/gns.gno new file mode 100644 index 000000000..4b73ec783 --- /dev/null +++ b/contract/r/gnoswap/gns/gns.gno @@ -0,0 +1,312 @@ +package gns + +import ( + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/grc20reg" + + "gno.land/p/gnoswap/consts" +) + +const ( + MAXIMUM_SUPPLY = uint64(1_000_000_000_000_000) + INITIAL_MINT_AMOUNT = uint64(100_000_000_000_000) + MAX_EMISSION_AMOUNT = uint64(900_000_000_000_000) // MAXIMUM_SUPPLY - INITIAL_MINT_AMOUNT +) + +var ( + owner *ownable.Ownable + Token *grc20.Token + privateLedger *grc20.PrivateLedger + UserTeller grc20.Teller + + leftEmissionAmount uint64 + mintedEmissionAmount uint64 + lastMintedHeight int64 + + burnAmount uint64 +) + +func init() { + owner = ownable.NewWithAddress(consts.ADMIN) + Token, privateLedger = grc20.NewToken("Gnoswap", "GNS", 6) + UserTeller = Token.CallerTeller() + + privateLedger.Mint(owner.Owner(), INITIAL_MINT_AMOUNT) + getter := func() *grc20.Token { return Token } + grc20reg.Register(getter, "") + + // Initial amount set to 900_000_000_000_000 (MAXIMUM_SUPPLY - INITIAL_MINT_AMOUNT). + // leftEmissionAmount will decrease as tokens are minted. + setLeftEmissionAmount(MAX_EMISSION_AMOUNT) + setMintedEmissionAmount(uint64(0)) + setLastMintedHeight(std.GetHeight()) + burnAmount = uint64(0) +} + +func TotalSupply() uint64 { + return Token.TotalSupply() +} + +func GetName() string { + return Token.GetName() +} + +func GetSymbol() string { + return Token.GetSymbol() +} + +func GetDecimals() uint { + return Token.GetDecimals() +} + +func KnownAccounts() int { + return Token.KnownAccounts() +} + +func BalanceOf(owner std.Address) uint64 { + return Token.BalanceOf(owner) +} + +func Allowance(owner, spender std.Address) uint64 { + return Token.Allowance(owner, spender) +} + +func MintGns(address std.Address) uint64 { + lastGNSMintedHeight := GetLastMintedHeight() + currentHeight := std.GetHeight() + + // skip minting process if following conditions are met + // - if gns for current block is already minted + // - if last minted height is same or later than emission end height + if lastGNSMintedHeight == currentHeight || lastGNSMintedHeight >= GetEndHeight() { + return 0 + } + + assertShouldNotBeHalted() + assertCallerIsEmission() + + // calculate gns amount to mint + amountToMint := calculateAmountToMint(lastGNSMintedHeight+1, currentHeight) + + // update + setLastMintedHeight(currentHeight) + setMintedEmissionAmount(GetMintedEmissionAmount() + amountToMint) + setLeftEmissionAmount(GetLeftEmissionAmount() - amountToMint) + + // mint calculated amount to address + err := privateLedger.Mint(address, amountToMint) + if err != nil { + panic(err.Error()) + } + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "MintGNS", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "mintedBlockHeight", formatInt(currentHeight), + "mintedGNSAmount", formatUint(amountToMint), + "accumMintedGNSAmount", formatUint(GetMintedEmissionAmount()), + "accumLeftMintGNSAmount", formatUint(GetLeftEmissionAmount()), + ) + + return amountToMint +} + +func Burn(from std.Address, amount uint64) { + owner.AssertCallerIsOwner() + checkErr(privateLedger.Burn(from, amount)) + + burnAmount += amount + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "Burn", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "burnedBlockHeight", formatInt(std.GetHeight()), + "burnFrom", from.String(), + "burnedGNSAmount", formatUint(amount), + "accumBurnedGNSAmount", formatUint(GetBurnAmount()), + ) +} + +func Transfer(to std.Address, amount uint64) { + checkErr(UserTeller.Transfer(to, amount)) +} + +func Approve(spender std.Address, amount uint64) { + checkErr(UserTeller.Approve(spender, amount)) +} + +func TransferFrom(from, to std.Address, amount uint64) { + checkErr(UserTeller.TransferFrom(from, to, amount)) +} + +func Render(path string) string { + parts := strings.Split(path, "/") + c := len(parts) + + switch { + case path == "": + return Token.RenderHome() + case c == 2 && parts[0] == "balance": + owner := std.Address(parts[1]) + balance := Token.BalanceOf(owner) + return ufmt.Sprintf("%d\n", balance) + default: + return "404\n" + } +} + +func checkErr(err error) { + if err != nil { + panic(err.Error()) + } +} + +// calculateAmountToMint calculates the amount of gns to mint +// It calculates the amount of gns to mint for each halving year for block range. +// It also handles the left emission amount if the current block range includes halving year end block. +func calculateAmountToMint(fromHeight, toHeight int64) uint64 { + // if toHeight is greater than emission end height, set toHeight to emission end height + endH := GetEndHeight() + if toHeight > endH { + toHeight = endH + } + + if fromHeight > toHeight { + return 0 + } + + fromYear := GetHalvingYearByHeight(fromHeight) + toYear := GetHalvingYearByHeight(toHeight) + + totalAmountToMint := uint64(0) + + curFrom := fromHeight + + for year := fromYear; year <= toYear; year++ { + yearEndHeight := GetHalvingYearEndBlock(year) + mintUntilHeight := i64Min(yearEndHeight, toHeight) + + // how many blocks to calculate + blocks := uint64(mintUntilHeight - curFrom + 1) + if blocks <= 0 { + break + } + + // amount of gns to mint for each block for current year + singleBlockAmount := GetAmountPerBlockPerHalvingYear(year) + + // amount of gns to mint for current year + yearAmountToMint := singleBlockAmount * blocks + + // if last block of halving year, handle left emission amount + if mintUntilHeight >= yearEndHeight { + leftover := handleLeftEmissionAmount(year, yearAmountToMint) + yearAmountToMint += leftover + } + + totalAmountToMint += yearAmountToMint + + setHalvingYearMintAmount(year, GetHalvingYearMintAmount(year)+yearAmountToMint) + setHalvingYearLeftAmount(year, GetHalvingYearLeftAmount(year)-yearAmountToMint) + + std.Emit( + "CalculateAmountToMint", + "fromHeight", formatInt(curFrom), + "toHeight", formatInt(mintUntilHeight), + "year", formatInt(year), + "singleBlockAmount", formatUint(singleBlockAmount), + ) + + // update fromHeight for next year (if necessary) + curFrom = mintUntilHeight + 1 + if curFrom > toHeight { + break + } + } + + assertTooManyEmission(totalAmountToMint) + return totalAmountToMint +} + +// isLastBlockOfHalvingYear returns true if the current block is the last block of a halving year. +func isLastBlockOfHalvingYear(height int64) bool { + year := GetHalvingYearByHeight(height) + lastBlock := GetHalvingYearEndBlock(year) + + return height == lastBlock +} + +// handleLeftEmissionAmount handles the left emission amount for a halving year. +// It calculates the left emission amount by subtracting the halving year mint amount from the halving year amount. +func handleLeftEmissionAmount(year int64, amount uint64) uint64 { + return GetHalvingYearLeftAmount(year) - amount +} + +// skipIfSameHeight returns true if the current block height is the same as the last minted height. +// This prevents multiple gns minting inside the same block. +func skipIfSameHeight(lastMintedHeight, currentHeight int64) bool { + return lastMintedHeight == currentHeight +} + +// skipIfEmissionEnded returns true if the emission has ended. +func skipIfEmissionEnded(height int64) bool { + if isEmissionEnded(height) { + return true + } + + return false +} + +// GetLastMintedHeight returns the last block height that gns was minted. +func GetLastMintedHeight() int64 { + return lastMintedHeight +} + +// GetLeftEmissionAmount returns the amount of GNS can be minted. +func GetLeftEmissionAmount() uint64 { + return leftEmissionAmount +} + +// GetBurnAmount returns the amount of GNS that has been burned. +func GetBurnAmount() uint64 { + return burnAmount +} + +// GetMintedEmissionAmount returns the amount of GNS that has been minted by the emission contract. +// It does not include initial minted amount. +func GetMintedEmissionAmount() uint64 { + return mintedEmissionAmount +} + +func setLastMintedHeight(height int64) { + lastMintedHeight = height +} + +func setLeftEmissionAmount(amount uint64) { + leftEmissionAmount = amount +} + +func setMintedEmissionAmount(amount uint64) { + mintedEmissionAmount = amount +} + +// assertTooManyEmission asserts if the amount of gns to mint is too many. +// It checks if the amount of gns to mint is greater than the left emission amount or the total emission amount. +func assertTooManyEmission(amount uint64) { + if (amount + GetMintedEmissionAmount()) > MAX_EMISSION_AMOUNT { + panic(addDetailToError( + errTooManyEmission, + ufmt.Sprintf("amount: %d", amount), + )) + } +} diff --git a/contract/r/gnoswap/gns/gns_test.gno b/contract/r/gnoswap/gns/gns_test.gno new file mode 100644 index 000000000..9f449447c --- /dev/null +++ b/contract/r/gnoswap/gns/gns_test.gno @@ -0,0 +1,315 @@ +package gns + +import ( + "fmt" + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" +) + +const ( + // gnoVM test context default height + // ref: https://github.com/gnolang/gno/blob/a85a53d5b38f0a21d66262a823a8b07f4f836b68/gnovm/pkg/test/test.go#L31-L32 + GNO_VM_DEFAULT_HEIGHT = int64(123) +) + +var ( + emissionRealm = std.NewCodeRealm(consts.EMISSION_PATH) + adminRealm = std.NewUserRealm(consts.ADMIN) +) + +var ( + alice = testutils.TestAddress("alice") + bob = testutils.TestAddress("bob") +) + +func TestGetName(t *testing.T) { + uassert.Equal(t, "Gnoswap", GetName()) +} + +func TestGetSymbol(t *testing.T) { + uassert.Equal(t, "GNS", GetSymbol()) +} + +func TestGetDecimals(t *testing.T) { + uassert.Equal(t, uint(6), GetDecimals()) +} + +func TestKnownAccounts(t *testing.T) { + uassert.Equal(t, int(1), KnownAccounts()) +} + +func TestTotalSupply(t *testing.T) { + uassert.Equal(t, INITIAL_MINT_AMOUNT, TotalSupply()) +} + +func TestBalanceOf(t *testing.T) { + uassert.Equal(t, INITIAL_MINT_AMOUNT, BalanceOf(consts.ADMIN)) +} + +func TestAssertTooManyEmission(t *testing.T) { + tests := []struct { + name string + amount uint64 + shouldPanic bool + panicMsg string + }{ + { + name: "should panic if emission amount is too large", + amount: MAXIMUM_SUPPLY, + shouldPanic: true, + panicMsg: "[GNOSWAP-GNS-002] too many emission reward || amount: 1000000000000000", + }, + { + name: "should not panic if emission amount is not too large", + amount: 123, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + assertTooManyEmission(tt.amount) + }) + } else { + uassert.NotPanics(t, func() { assertTooManyEmission(tt.amount) }) + } + }) + } +} + +func TestIsLastBlockOfHalvingYear(t *testing.T) { + tests := make([]struct { + name string + height int64 + want bool + }, 0, 24) + + for i := HALVING_START_YEAR; i <= HALVING_END_YEAR; i++ { + tests = append(tests, struct { + name string + height int64 + want bool + }{ + name: fmt.Sprintf("last block of halving year %d", i), + height: GetHalvingYearEndBlock(i), + want: true, + }) + + tests = append(tests, struct { + name string + height int64 + want bool + }{ + name: fmt.Sprintf("not last block of halving year %d", i), + height: GetHalvingYearEndBlock(i) - 1, + want: false, + }) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uassert.Equal(t, tt.want, isLastBlockOfHalvingYear(tt.height)) + }) + } +} + +func TestHandleLeftEmissionAmount(t *testing.T) { + tests := make([]struct { + name string + year int64 + amount uint64 + want uint64 + }, 0, 24) + + for i := int64(1); i <= 12; i++ { + tests = append(tests, struct { + name string + year int64 + amount uint64 + want uint64 + }{ + name: fmt.Sprintf("handle left emission amount for year %d, non minted", i), + year: i, + amount: 0, + want: GetHalvingYearMaxAmount(i), + }) + + tests = append(tests, struct { + name string + year int64 + amount uint64 + want uint64 + }{ + name: fmt.Sprintf("handle left emission amount for year %d, minted", i), + year: i, + amount: uint64(123456), + want: GetHalvingYearMaxAmount(i) - uint64(123456), + }) + } +} + +func TestSkipIfSameHeight(t *testing.T) { + t.Run("should skip if height is same", func(t *testing.T) { + uassert.True(t, skipIfSameHeight(1, 1)) + }) + + t.Run("should not skip if height is different", func(t *testing.T) { + uassert.False(t, skipIfSameHeight(1, 2)) + }) +} + +func TestSkipIfEmissionEnded(t *testing.T) { + t.Run("should skip if emission has ended", func(t *testing.T) { + uassert.True(t, skipIfEmissionEnded(GetEndHeight()+1)) + }) + + t.Run("should not skip if emission has not ended", func(t *testing.T) { + uassert.False(t, skipIfEmissionEnded(std.GetHeight())) + }) +} + +func TestGetterSetter(t *testing.T) { + t.Run("last minted height", func(t *testing.T) { + value := int64(123) + setLastMintedHeight(value) + uassert.Equal(t, value, GetLastMintedHeight()) + }) + + t.Run("left emission amount", func(t *testing.T) { + value := uint64(0) + setLeftEmissionAmount(value) + uassert.Equal(t, value, GetLeftEmissionAmount()) + }) +} + +func TestGrc20Methods(t *testing.T) { + tests := []struct { + name string + fn func() + shouldPanic bool + panicMsg string + }{ + { + name: "TotalSupply", + fn: func() { + uassert.Equal(t, INITIAL_MINT_AMOUNT, TotalSupply()) + }, + }, + { + name: "BalanceOf(admin)", + fn: func() { + uassert.Equal(t, INITIAL_MINT_AMOUNT, BalanceOf(consts.ADMIN)) + }, + }, + { + name: "BalanceOf(alice)", + fn: func() { + uassert.Equal(t, uint64(0), BalanceOf(alice)) + }, + }, + { + name: "Allowance(admin, alice)", + fn: func() { + uassert.Equal(t, uint64(0), Allowance(consts.ADMIN, alice)) + }, + }, + { + name: "MintGns success", + fn: func() { + std.TestSetRealm(emissionRealm) + MintGns(consts.ADMIN) + }, + }, + { + name: "MintGns without permission should panic", + fn: func() { + std.TestSkipHeights(1) + MintGns(consts.ADMIN) + }, + shouldPanic: true, + panicMsg: `caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission`, + }, + { + name: "Burn success", + fn: func() { + std.TestSetRealm(adminRealm) + Burn(consts.ADMIN, uint64(1)) + }, + }, + { + name: "Burn without permission should panic", + fn: func() { + Burn(consts.ADMIN, uint64(1)) + }, + shouldPanic: true, + panicMsg: `ownable: caller is not owner`, + }, + { + name: "Transfer success", + fn: func() { + std.TestSetRealm(adminRealm) + Transfer(alice, uint64(1)) + }, + }, + { + name: "Transfer without enough balance should panic", + fn: func() { + std.TestSetRealm(std.NewUserRealm(alice)) + Transfer(bob, uint64(1)) + }, + shouldPanic: true, + panicMsg: `insufficient balance`, + }, + { + name: "Transfer to self should panic", + fn: func() { + std.TestSetRealm(adminRealm) + Transfer(consts.ADMIN, uint64(1)) + }, + shouldPanic: true, + panicMsg: `cannot send transfer to self`, + }, + { + name: "TransferFrom success", + fn: func() { + // approve first + std.TestSetRealm(adminRealm) + Approve(alice, uint64(1)) + + // alice transfer admin's balance to bob + std.TestSetRealm(std.NewUserRealm(alice)) + TransferFrom(consts.ADMIN, bob, uint64(1)) + }, + }, + { + name: "TransferFrom without enough allowance should panic", + fn: func() { + std.TestSetRealm(adminRealm) + Approve(alice, uint64(1)) + + std.TestSetRealm(std.NewUserRealm(alice)) + TransferFrom(consts.ADMIN, bob, uint64(2)) + }, + shouldPanic: true, + panicMsg: `insufficient allowance`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetGnsTokenObject(t) + + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, tt.fn) + } else { + uassert.NotPanics(t, func() { tt.fn() }) + } + }) + } +} diff --git a/contract/r/gnoswap/gns/halving.gno b/contract/r/gnoswap/gns/halving.gno new file mode 100644 index 000000000..36171d0ad --- /dev/null +++ b/contract/r/gnoswap/gns/halving.gno @@ -0,0 +1,632 @@ +package gns + +import ( + "std" + "strconv" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/json" + "gno.land/p/gnoswap/consts" +) + +// init 12 years halving tier block +/* + NOTE: assume block will be created every 1 second by default + 1 second = 1 block + 1 minute = 60 block + 1 hour = 3600 block + 1 day = 86400 block + (365 days) 1 year = 31536000 block +*/ + +const ( + HALVING_START_YEAR = int64(1) + HALVING_END_YEAR = int64(12) +) + +var ( + HALVING_AMOUNTS_PER_YEAR = [HALVING_END_YEAR]uint64{ + 18_750_000_000_000 * 12, // Year 1: 225000000000000 + 18_750_000_000_000 * 12, // Year 2: 225000000000000 + 9_375_000_000_000 * 12, // Year 3: 112500000000000 + 9_375_000_000_000 * 12, // Year 4: 112500000000000 + 4_687_500_000_000 * 12, // Year 5: 56250000000000 + 4_687_500_000_000 * 12, // Year 6: 56250000000000 + 2_343_750_000_000 * 12, // Year 7: 28125000000000 + 2_343_750_000_000 * 12, // Year 8: 28125000000000 + 1_171_875_000_000 * 12, // Year 9: 14062500000000 + 1_171_875_000_000 * 12, // Year 10: 14062500000000 + 1_171_875_000_000 * 12, // Year 11: 14062500000000 + 1_171_875_000_000 * 12, // Year 12: 14062500000000 + } +) + +var ( + blockPerYear = consts.TIMESTAMP_YEAR / consts.BLOCK_GENERATION_INTERVAL + blockPerDay = consts.TIMESTAMP_DAY / consts.BLOCK_GENERATION_INTERVAL + + avgBlockTimeMs int64 = consts.SECOND_IN_MILLISECOND * consts.BLOCK_GENERATION_INTERVAL + perBlockMint = avl.NewTree() // height => uint64 +) + +type HalvingData struct { + startBlockHeight []int64 + endBlockHeight []int64 + startTimestamp []int64 + maxAmount []uint64 + mintedAmount []uint64 + leftAmount []uint64 + accumAmount []uint64 + amountPerBlock []uint64 +} + +func (h *HalvingData) getStartBlockHeight(year int64) int64 { + if year == 0 { + return 0 + } + return h.startBlockHeight[year-1] +} + +func (h *HalvingData) setStartBlockHeight(year int64, height int64) { + assertValidYear(year) + h.startBlockHeight[year-1] = height +} + +func (h *HalvingData) getEndBlockHeight(year int64) int64 { + if year == 0 { + return 0 + } + return h.endBlockHeight[year-1] +} + +func (h *HalvingData) setEndBlockHeight(year int64, height int64) { + assertValidYear(year) + h.endBlockHeight[year-1] = height +} + +func (h *HalvingData) getStartTimestamp(year int64) int64 { + if year == 0 { + return 0 + } + return h.startTimestamp[year-1] +} + +func (h *HalvingData) setStartTimestamp(year int64, timestamp int64) { + assertValidYear(year) + h.startTimestamp[year-1] = timestamp +} + +func (h *HalvingData) getMaxAmount(year int64) uint64 { + if year == 0 { + return 0 + } + return h.maxAmount[year-1] +} + +func (h *HalvingData) setMaxAmount(year int64, amount uint64) { + assertValidYear(year) + h.maxAmount[year-1] = amount +} + +func (h *HalvingData) getMintedAmount(year int64) uint64 { + if year == 0 { + return 0 + } + return h.mintedAmount[year-1] +} + +func (h *HalvingData) setMintedAmount(year int64, amount uint64) { + assertValidYear(year) + h.mintedAmount[year-1] = amount +} + +func (h *HalvingData) getLeftAmount(year int64) uint64 { + if year == 0 { + return 0 + } + return h.leftAmount[year-1] +} + +func (h *HalvingData) setLeftAmount(year int64, amount uint64) { + assertValidYear(year) + h.leftAmount[year-1] = amount +} + +func (h *HalvingData) getAccumAmount(year int64) uint64 { + if year == 0 { + return 0 + } + return h.accumAmount[year-1] +} + +func (h *HalvingData) setAccumAmount(year int64, amount uint64) { + assertValidYear(year) + h.accumAmount[year-1] = amount +} + +func (h *HalvingData) addAccumAmount(year int64, amount uint64) { + assertValidYear(year) + h.accumAmount[year-1] += amount +} + +func (h *HalvingData) getAmountPerBlock(year int64) uint64 { + if year == 0 { + return 0 + } + return h.amountPerBlock[year-1] +} + +func (h *HalvingData) setAmountPerBlock(year int64, amount uint64) { + assertValidYear(year) + h.amountPerBlock[year-1] = amount +} + +func (h *HalvingData) initialize(startHeight, startTimestamp int64) { + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + // set max emission amount per year + // each year can not mint more than this amount + currentYearMaxAmount := GetHalvingAmountsPerYear(year) + h.setMaxAmount(year, currentYearMaxAmount) + + if year == HALVING_START_YEAR { + h.setAccumAmount(year, currentYearMaxAmount) + h.setStartBlockHeight(year, startHeight) + h.setEndBlockHeight(year, startHeight+(blockPerYear*year)-1) + } else { + // accumulate amount until current year, is the sum of current year max amount and accumulated amount until previous year + h.setAccumAmount(year, currentYearMaxAmount+h.getAccumAmount(year-1)) + + // start block of current year, is the next block of previous year of end block + h.setStartBlockHeight(year, h.getEndBlockHeight(year-1)+1) + + // end block of current year, is sum of start block and block per year + h.setEndBlockHeight(year, h.getStartBlockHeight(year)+blockPerYear-1) + } + + h.setStartTimestamp(year, startTimestamp+(consts.TIMESTAMP_YEAR*(year-1))) + + amountPerDay := currentYearMaxAmount / consts.DAY_PER_YEAR + amountPerBlock := amountPerDay / uint64(blockPerDay) + h.setAmountPerBlock(year, uint64(amountPerBlock)) + h.setMintedAmount(year, uint64(0)) + h.setLeftAmount(year, currentYearMaxAmount) + } +} + +var callbackEmissionChange func(amount uint64) + +func SetCallbackEmissionChange(callback func(amount uint64)) { + if callbackEmissionChange != nil { + panic("callbackEmissionChange already set") + } + callbackEmissionChange = callback +} + +// endHeight MUST be smaller than the current height(we don't read future halving blocks with this function) +func GetHalvingBlocksInRange(startHeight, endHeight int64) ([]int64, []uint64) { + halvingData := GetEmissionState().getHalvingData() + halvingBlocks := []int64{} + halvingEmissions := []uint64{} + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + startBlock := halvingData.getStartBlockHeight(year) + if startBlock < startHeight { + continue + } + if endHeight < startBlock { + break + } + halvingBlocks = append(halvingBlocks, startBlock) + halvingEmissions = append(halvingEmissions, halvingData.getAmountPerBlock(year)) + } + return halvingBlocks, halvingEmissions +} + +// height MUST be smaller than the current height(we don't read future emissions) +func GetEmission() uint64 { + halvingData := GetEmissionState().getHalvingData() + height := std.GetHeight() + + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + startBlock := halvingData.getStartBlockHeight(year) + endBlock := halvingData.getEndBlockHeight(year) + if startBlock <= height && height < endBlock { + return halvingData.getAmountPerBlock(year) + } + } + + return 0 +} + +type EmissionState struct { + startHeight int64 + startTimestamp int64 + endTimestamp int64 + halvingData HalvingData +} + +func GetEmissionState() *EmissionState { + if emissionState == nil { + emissionState = NewEmissionState() + emissionState.initializeHalvingData() + } + return emissionState +} + +func NewEmissionState() *EmissionState { + now := time.Now().Unix() + consts.BLOCK_GENERATION_INTERVAL + emissionEndTime := now + consts.TIMESTAMP_YEAR*HALVING_END_YEAR + + return &EmissionState{ + startHeight: std.GetHeight() + 1, + startTimestamp: now, + endTimestamp: emissionEndTime, + halvingData: HalvingData{ + startBlockHeight: make([]int64, HALVING_END_YEAR), + endBlockHeight: make([]int64, HALVING_END_YEAR), + startTimestamp: make([]int64, HALVING_END_YEAR), + maxAmount: make([]uint64, HALVING_END_YEAR), + mintedAmount: make([]uint64, HALVING_END_YEAR), + leftAmount: make([]uint64, HALVING_END_YEAR), + accumAmount: make([]uint64, HALVING_END_YEAR), + amountPerBlock: make([]uint64, HALVING_END_YEAR), + }, + } +} + +func (e *EmissionState) getStartHeight() int64 { + return e.startHeight +} + +func (e *EmissionState) setStartHeight(height int64) { + e.startHeight = height +} + +func (e *EmissionState) getStartTimestamp() int64 { + return e.startTimestamp +} + +func (e *EmissionState) setStartTimestamp(timestamp int64) { + e.startTimestamp = timestamp +} + +func (e *EmissionState) getEndTimestamp() int64 { + return e.endTimestamp +} + +func (e *EmissionState) setEndTimestamp(timestamp int64) { + e.endTimestamp = timestamp +} + +func (e *EmissionState) getHalvingData() HalvingData { + return e.halvingData +} + +func (e *EmissionState) setHalvingData(data HalvingData) { + e.halvingData = data +} + +// initializeHalvingData initializes the halving data +// it should be called only once, so we call this in init() +func (e *EmissionState) initializeHalvingData() { + halvingData := e.getHalvingData() + halvingData.initialize(e.getStartHeight(), e.getStartTimestamp()) + e.setHalvingData(halvingData) +} + +var emissionState *EmissionState + +func init() { + emissionState = GetEmissionState() +} + +func GetAvgBlockTimeInMs() int64 { + return avgBlockTimeMs +} + +// SetAvgBlockTimeInMsByAdmin sets the average block time in millisecond. +func SetAvgBlockTimeInMsByAdmin(ms int64) { + assertCallerIsAdmin() + + prevAvgBlockTimeInMs := GetAvgBlockTimeInMs() + setAvgBlockTimeInMs(ms) + + halvingData := GetEmissionState().getHalvingData() + prevAddr, prevPkgPath := getPrev() + std.Emit( + "SetAvgBlockTimeInMsByAdmin", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevAvgBlockTimeMs", formatInt(prevAvgBlockTimeInMs), + "newAvgBlockTimeMs", formatInt(ms), + "endBlock1Year", formatInt(halvingData.getEndBlockHeight(1)), + "endBlock2Year", formatInt(halvingData.getEndBlockHeight(2)), + "endBlock3Year", formatInt(halvingData.getEndBlockHeight(3)), + "endBlock4Year", formatInt(halvingData.getEndBlockHeight(4)), + "endBlock5Year", formatInt(halvingData.getEndBlockHeight(4)), + "endBlock6Year", formatInt(halvingData.getEndBlockHeight(5)), + "endBlock7Year", formatInt(halvingData.getEndBlockHeight(6)), + "endBlock8Year", formatInt(halvingData.getEndBlockHeight(7)), + "endBlock9Year", formatInt(halvingData.getEndBlockHeight(8)), + "endBlock10Year", formatInt(halvingData.getEndBlockHeight(9)), + "endBlock11Year", formatInt(halvingData.getEndBlockHeight(10)), + "endBlock12Year", formatInt(halvingData.getEndBlockHeight(11)), + ) +} + +// SetAvgBlockTimeInMs sets the average block time in millisecond. +// Only governance contract can execute this function via proposal +func SetAvgBlockTimeInMs(ms int64) { + assertCallerIsGovernance() + + prevAvgBlockTimeInMs := GetAvgBlockTimeInMs() + setAvgBlockTimeInMs(ms) + + halvingData := GetEmissionState().getHalvingData() + prevAddr, prevPkgPath := getPrev() + std.Emit( + "SetAvgBlockTimeInMs", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevAvgBlockTimeMs", formatInt(prevAvgBlockTimeInMs), + "newAvgBlockTimeMs", formatInt(ms), + "endBlock1Year", formatInt(halvingData.getEndBlockHeight(1)), + "endBlock2Year", formatInt(halvingData.getEndBlockHeight(2)), + "endBlock3Year", formatInt(halvingData.getEndBlockHeight(3)), + "endBlock4Year", formatInt(halvingData.getEndBlockHeight(4)), + "endBlock5Year", formatInt(halvingData.getEndBlockHeight(4)), + "endBlock6Year", formatInt(halvingData.getEndBlockHeight(5)), + "endBlock7Year", formatInt(halvingData.getEndBlockHeight(6)), + "endBlock8Year", formatInt(halvingData.getEndBlockHeight(7)), + "endBlock9Year", formatInt(halvingData.getEndBlockHeight(8)), + "endBlock10Year", formatInt(halvingData.getEndBlockHeight(9)), + "endBlock11Year", formatInt(halvingData.getEndBlockHeight(10)), + "endBlock12Year", formatInt(halvingData.getEndBlockHeight(11)), + ) +} + +func setAvgBlockTimeInMs(ms int64) { + assertShouldNotBeHalted() + + now := time.Now().Unix() + height := std.GetHeight() + + // update block per year + yearByMilliSec := secToMs(consts.TIMESTAMP_YEAR) + blockPerYear = yearByMilliSec / ms + + // get the halving year and end timestamp of current time + currentYear, currentYearEndTimestamp := getHalvingYearAndEndTimestamp(now) + + // how much time left for current halving year + timeLeft := currentYearEndTimestamp - now + timeLeftMs := secToMs(timeLeft) + + // how many block left for current halving year + blockLeft := timeLeftMs / ms + // how many reward left for current halving year + minted := GetMintedEmissionAmount() + amountLeft := GetHalvingYearAccuAmount(currentYear) - minted + + // how much reward should be minted per block for current halving year + adjustedAmountPerBlock := amountLeft / uint64(blockLeft) + // update it + setAmountPerBlockPerHalvingYear(currentYear, adjustedAmountPerBlock) + if callbackEmissionChange == nil { + panic(errCallbackEmissionChangeIsNil.Error()) + } + callbackEmissionChange(adjustedAmountPerBlock) + + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + if year < currentYear { + // pass past halving years + continue + } + + yearEndTimestamp := GetHalvingYearTimestamp(year) + consts.TIMESTAMP_YEAR + timeLeftForYear := yearEndTimestamp - now + numBlock := (timeLeftForYear * consts.SECOND_IN_MILLISECOND) / ms + yearEndHeight := height + numBlock + + if year == currentYear { + // for current year, update only end block + setHalvingYearEndBlock(year, yearEndHeight) + } else { + // update start block + prevYearEnd := GetHalvingYearEndBlock(year - 1) + nextYearStart := prevYearEnd + 1 + nextYearEnd := nextYearStart + blockPerYear + + setHalvingYearStartBlock(year, nextYearStart) + setHalvingYearEndBlock(year, nextYearEnd) + } + } + + avgBlockTimeMs = ms + +} + +// GetAmountByHeight returns the amount of gns to mint by height +func GetAmountByHeight(height int64) uint64 { + if isEmissionEnded(height) { + return 0 + } + + halvingYear := GetHalvingYearByHeight(height) + return GetAmountPerBlockPerHalvingYear(halvingYear) +} + +// GetHalvingYearByHeight returns the halving year by height +func GetHalvingYearByHeight(height int64) int64 { + if isEmissionEnded(height) { + return 0 + } + + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + endBlock := GetHalvingYearEndBlock(year) + if height <= endBlock { + return year + } + } + + return 0 +} + +func GetHalvingYearStartBlock(year int64) int64 { + halvingData := GetEmissionState().getHalvingData() + return halvingData.getStartBlockHeight(year) +} + +func setHalvingYearStartBlock(year int64, block int64) { + halvingData := GetEmissionState().getHalvingData() + halvingData.setStartBlockHeight(year, block) + GetEmissionState().setHalvingData(halvingData) +} + +func GetHalvingYearEndBlock(year int64) int64 { + halvingData := GetEmissionState().getHalvingData() + return halvingData.getEndBlockHeight(year) +} + +func setHalvingYearEndBlock(year int64, block int64) { + halvingData := GetEmissionState().getHalvingData() + halvingData.setEndBlockHeight(year, block) + GetEmissionState().setHalvingData(halvingData) +} + +func GetHalvingYearTimestamp(year int64) int64 { + halvingData := GetEmissionState().getHalvingData() + return halvingData.getStartTimestamp(year) +} + +func setHalvingYearTimestamp(year int64, timestamp int64) { + halvingData := GetEmissionState().getHalvingData() + halvingData.setStartTimestamp(year, timestamp) + GetEmissionState().setHalvingData(halvingData) +} + +func GetHalvingYearMaxAmount(year int64) uint64 { + halvingData := GetEmissionState().getHalvingData() + return halvingData.getMaxAmount(year) +} + +func setHalvingYearMaxAmount(year int64, amount uint64) { + halvingData := GetEmissionState().getHalvingData() + halvingData.setMaxAmount(year, amount) + GetEmissionState().setHalvingData(halvingData) +} + +func GetHalvingYearMintAmount(year int64) uint64 { + halvingData := GetEmissionState().getHalvingData() + return halvingData.getMintedAmount(year) +} + +func setHalvingYearMintAmount(year int64, amount uint64) { + halvingData := GetEmissionState().getHalvingData() + halvingData.setMintedAmount(year, amount) + GetEmissionState().setHalvingData(halvingData) +} + +func GetHalvingYearLeftAmount(year int64) uint64 { + halvingData := GetEmissionState().getHalvingData() + return halvingData.getLeftAmount(year) +} + +func setHalvingYearLeftAmount(year int64, amount uint64) { + halvingData := GetEmissionState().getHalvingData() + halvingData.setLeftAmount(year, amount) + GetEmissionState().setHalvingData(halvingData) +} + +func GetHalvingYearAccuAmount(year int64) uint64 { + halvingData := GetEmissionState().getHalvingData() + return halvingData.getAccumAmount(year) +} + +func setHalvingYearAccuAmount(year int64, amount uint64) { + halvingData := GetEmissionState().getHalvingData() + halvingData.setAccumAmount(year, amount) + GetEmissionState().setHalvingData(halvingData) +} + +func GetAmountPerBlockPerHalvingYear(year int64) uint64 { + halvingData := GetEmissionState().getHalvingData() + return halvingData.getAmountPerBlock(year) +} + +func setAmountPerBlockPerHalvingYear(year int64, amount uint64) { + halvingData := GetEmissionState().getHalvingData() + halvingData.setAmountPerBlock(year, amount) + GetEmissionState().setHalvingData(halvingData) +} + +func GetHalvingAmountsPerYear(year int64) uint64 { + return HALVING_AMOUNTS_PER_YEAR[year-1] +} + +func GetEndHeight() int64 { + // last block of last halving year(12) is last block of emission + // later than this block, no more gns will be minted + return GetHalvingYearEndBlock(HALVING_END_YEAR) +} + +func GetEndTimestamp() int64 { + return GetEmissionState().getEndTimestamp() +} + +func setEndTimestamp(timestamp int64) { + GetEmissionState().setEndTimestamp(timestamp) +} + +func GetHalvingInfo() string { + height := std.GetHeight() + now := time.Now().Unix() + + halvings := make([]*json.Node, 0) + + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + halvings = append(halvings, json.ObjectNode("", map[string]*json.Node{ + "year": json.StringNode("year", strconv.FormatInt(year, 10)), + "block": json.NumberNode("block", float64(GetHalvingYearStartBlock(year))), + "amount": json.NumberNode("amount", float64(GetAmountPerBlockPerHalvingYear(year))), + })) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(height)), + "timestamp": json.NumberNode("timestamp", float64(now)), + "avgBlockTimeMs": json.NumberNode("avgBlockTimeMs", float64(avgBlockTimeMs)), + "halvings": json.ArrayNode("", halvings), + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func isEmissionEnded(height int64) bool { + return height > GetEndHeight() +} + +// getHalvingYearAndEndTimestamp returns the halving year and end timestamp of the given timestamp +// if the timestamp is not in any halving year, it returns 0, 0 +func getHalvingYearAndEndTimestamp(timestamp int64) (int64, int64) { + state := GetEmissionState() + + endTimestamp := state.getEndTimestamp() + startTimestamp := state.getStartTimestamp() + + if timestamp > endTimestamp { // after 12 years + return 0, 0 + } + + timestamp -= startTimestamp + + year := timestamp / consts.TIMESTAMP_YEAR + year += 1 // since we subtract startTimestamp at line 215, we need to add 1 to get the correct year + + return year, startTimestamp + (consts.TIMESTAMP_YEAR * year) +} diff --git a/contract/r/gnoswap/gns/halving_test.gno b/contract/r/gnoswap/gns/halving_test.gno new file mode 100644 index 000000000..0cd736ebf --- /dev/null +++ b/contract/r/gnoswap/gns/halving_test.gno @@ -0,0 +1,203 @@ +package gns + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/json" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" +) + +var ( + govRealm = std.NewCodeRealm(consts.GOV_GOVERNANCE_PATH) + adminRealm = std.NewUserRealm(consts.ADMIN) + + startHeight int64 = std.GetHeight() + startTimestamp int64 = time.Now().Unix() + consts.BLOCK_GENERATION_INTERVAL +) + +var FIRST_BLOCK_OF_YEAR = []int64{ + startHeight + (blockPerYear * 0), + startHeight + (blockPerYear * 1) + 1, + startHeight + (blockPerYear * 2) + 2, + startHeight + (blockPerYear * 3) + 3, + startHeight + (blockPerYear * 4) + 4, + startHeight + (blockPerYear * 5) + 5, + startHeight + (blockPerYear * 6) + 6, + startHeight + (blockPerYear * 7) + 7, + startHeight + (blockPerYear * 8) + 8, + startHeight + (blockPerYear * 9) + 9, + startHeight + (blockPerYear * 10) + 10, + startHeight + (blockPerYear * 11) + 11, +} + +var FIRST_TIMESTAMP_OF_YEAR = []int64{ + startTimestamp + consts.TIMESTAMP_YEAR*0, + startTimestamp + consts.TIMESTAMP_YEAR*1, + startTimestamp + consts.TIMESTAMP_YEAR*2, + startTimestamp + consts.TIMESTAMP_YEAR*3, + startTimestamp + consts.TIMESTAMP_YEAR*4, + startTimestamp + consts.TIMESTAMP_YEAR*5, + startTimestamp + consts.TIMESTAMP_YEAR*6, + startTimestamp + consts.TIMESTAMP_YEAR*7, + startTimestamp + consts.TIMESTAMP_YEAR*8, + startTimestamp + consts.TIMESTAMP_YEAR*9, + startTimestamp + consts.TIMESTAMP_YEAR*10, + startTimestamp + consts.TIMESTAMP_YEAR*11, +} + +var END_TIMESTAMP_OF_YEAR = []int64{ + startTimestamp + consts.TIMESTAMP_YEAR*1, + startTimestamp + consts.TIMESTAMP_YEAR*2, + startTimestamp + consts.TIMESTAMP_YEAR*3, + startTimestamp + consts.TIMESTAMP_YEAR*4, + startTimestamp + consts.TIMESTAMP_YEAR*5, + startTimestamp + consts.TIMESTAMP_YEAR*6, + startTimestamp + consts.TIMESTAMP_YEAR*7, + startTimestamp + consts.TIMESTAMP_YEAR*8, + startTimestamp + consts.TIMESTAMP_YEAR*9, + startTimestamp + consts.TIMESTAMP_YEAR*10, + startTimestamp + consts.TIMESTAMP_YEAR*11, + startTimestamp + consts.TIMESTAMP_YEAR*12, +} + +func TestGetAmountByHeight(t *testing.T) { + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + firstBlockOfYear := FIRST_BLOCK_OF_YEAR[year-1] + uassert.Equal(t, GetAmountPerBlockPerHalvingYear(year), GetAmountByHeight(firstBlockOfYear)) + if year == HALVING_START_YEAR { + uassert.Equal(t, GetHalvingYearAccuAmount(year), GetHalvingYearMaxAmount(year)) + } else { + uassert.Equal(t, GetHalvingYearAccuAmount(year), GetHalvingYearAccuAmount(year-1)+GetHalvingYearMaxAmount(year)) + uassert.Equal(t, int64(1), GetHalvingYearStartBlock(year)-GetHalvingYearEndBlock(year-1)) + } + uassert.Equal(t, blockPerYear, (GetHalvingYearEndBlock(year) - GetHalvingYearStartBlock(year) + 1)) + + if year == HALVING_START_YEAR { + uassert.Equal(t, GetHalvingYearMaxAmount(year)+GetAmountPerBlockPerHalvingYear(year+1), calculateAmountToMint(GetHalvingYearStartBlock(year), GetHalvingYearStartBlock(year+1))) + } else if year == HALVING_END_YEAR { + uassert.Equal(t, GetHalvingYearMaxAmount(year)-GetAmountPerBlockPerHalvingYear(year), calculateAmountToMint(GetHalvingYearStartBlock(year)+1, GetHalvingYearEndBlock(year)+1)) + } else { + uassert.Equal(t, GetHalvingYearMaxAmount(year)-GetAmountPerBlockPerHalvingYear(year)+GetAmountPerBlockPerHalvingYear(year+1), calculateAmountToMint(GetHalvingYearStartBlock(year)+1, GetHalvingYearStartBlock(year+1))) + } + } +} + +func TestGetHalvingYearByHeight(t *testing.T) { + t.Run("during halving years", func(t *testing.T) { + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + firstBlockOfYear := FIRST_BLOCK_OF_YEAR[year-1] + uassert.Equal(t, year, GetHalvingYearByHeight(firstBlockOfYear)) + } + }) + + t.Run("no year after 12 years", func(t *testing.T) { + uassert.Equal(t, int64(0), GetHalvingYearByHeight(GetEndHeight()+1)) + }) +} + +func TestGetHalvingYearAndEndTimestamp(t *testing.T) { + t.Run("bit of extra timestamp for each year", func(t *testing.T) { + for year := HALVING_START_YEAR; year <= HALVING_END_YEAR; year++ { + firstTimestampOfYear := FIRST_TIMESTAMP_OF_YEAR[year-1] + gotYear, gotEndTimestamp := getHalvingYearAndEndTimestamp(firstTimestampOfYear + 5) // after 5s + uassert.Equal(t, year, gotYear) + uassert.Equal(t, gotEndTimestamp, END_TIMESTAMP_OF_YEAR[year-1]) + } + }) + + t.Run("after 12 years", func(t *testing.T) { + year, endTimestamp := getHalvingYearAndEndTimestamp(GetEndTimestamp() + 1) + uassert.Equal(t, int64(0), year) + uassert.Equal(t, int64(0), endTimestamp) + }) +} + +func TestHalvingYearStartBlock(t *testing.T) { + setHalvingYearStartBlock(1, 100) + uassert.Equal(t, GetHalvingYearStartBlock(1), int64(100)) +} + +func TestHalvingYearTimestamp(t *testing.T) { + setHalvingYearTimestamp(2, 200) + uassert.Equal(t, GetHalvingYearTimestamp(2), int64(200)) +} + +func TestHalvingYearMaxAmount(t *testing.T) { + setHalvingYearMaxAmount(3, 300) + uassert.Equal(t, GetHalvingYearMaxAmount(3), uint64(300)) +} + +func TestHalvingYearMintAmount(t *testing.T) { + setHalvingYearMintAmount(4, 400) + uassert.Equal(t, GetHalvingYearMintAmount(4), uint64(400)) +} + +func TestHalvingYearAccuAmount(t *testing.T) { + setHalvingYearAccuAmount(5, 500) + uassert.Equal(t, GetHalvingYearAccuAmount(5), uint64(500)) +} + +func TestAmountPerBlockPerHalvingYear(t *testing.T) { + setAmountPerBlockPerHalvingYear(6, 600) + uassert.Equal(t, GetAmountPerBlockPerHalvingYear(6), uint64(600)) +} + +func TestGetHalvingInfo(t *testing.T) { + jsonStr, err := json.Unmarshal([]byte(GetHalvingInfo())) + uassert.NoError(t, err) + + halving := jsonStr.MustKey("halvings").MustArray() + uassert.Equal(t, len(halving), 12) +} + +func TestSetAvgBlockTimeInMsByAdmin(t *testing.T) { + t.Run("panic if caller is not admin", func(t *testing.T) { + oldCallback := callbackEmissionChange + callbackEmissionChange = func(amount uint64) {} + uassert.PanicsWithMessage(t, + "caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission", + func() { + SetAvgBlockTimeInMsByAdmin(1) + }, + ) + callbackEmissionChange = oldCallback + }) + + t.Run("success if caller is admin", func(t *testing.T) { + oldCallback := callbackEmissionChange + callbackEmissionChange = func(amount uint64) {} + std.TestSkipHeights(1) + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + SetAvgBlockTimeInMsByAdmin(2) + uassert.Equal(t, GetAvgBlockTimeInMs(), int64(2)) + callbackEmissionChange = oldCallback + }) +} + +func TestSetAvgBlockTimeInMs(t *testing.T) { + t.Run("panic if caller is not governance contract", func(t *testing.T) { + oldCallback := callbackEmissionChange + callbackEmissionChange = func(amount uint64) {} + uassert.PanicsWithMessage(t, + "caller(g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm) has no permission", + func() { + SetAvgBlockTimeInMs(3) + }, + ) + callbackEmissionChange = oldCallback + }) + + t.Run("success if caller is governance contract", func(t *testing.T) { + oldCallback := callbackEmissionChange + callbackEmissionChange = func(amount uint64) {} + std.TestSkipHeights(3) + std.TestSetRealm(govRealm) + SetAvgBlockTimeInMs(4) + uassert.Equal(t, GetAvgBlockTimeInMs(), int64(4)) + callbackEmissionChange = oldCallback + }) +} diff --git a/contract/r/gnoswap/gns/tests/z1_filetest.gno b/contract/r/gnoswap/gns/tests/z1_filetest.gno new file mode 100644 index 000000000..95b913eb0 --- /dev/null +++ b/contract/r/gnoswap/gns/tests/z1_filetest.gno @@ -0,0 +1,108 @@ +// change block avg time from 2s to 2.5s + +// PKGPATH: gno.land/r/gnoswap/v1/gns_test +package gns_test + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +var t *testing.T + +var ( + user01Addr = testutils.TestAddress("user01Addr") + user01Realm = std.NewUserRealm(user01Addr) +) + +func main() { + skip50Blocks() + blockTime2500ms() + reachAlmostFirstHalving() + reachExactFirstHalving() + startSecondHalving() +} + +func skip50Blocks() { + std.TestSkipHeights(50) + uassert.Equal(t, std.GetHeight(), int64(173)) + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + mintedAmount := gns.MintGns(user01Addr) + uassert.Equal(t, uint64(713470300), mintedAmount) // 14269406 * 50 + uassert.Equal(t, uint64(713470300), gns.GetMintedEmissionAmount()) +} + +func blockTime2500ms() { + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + gns.SetAvgBlockTimeInMsByAdmin(2500) + // for block time 2.5s, amount per block is 17836759 + + // first halving year end block = 15768122 + // first halving year end timestamp = 1234567990 + + // 50 block minted from L#38 + // > current height = 173 + // > current timestamp = 1234568140 + + // 1266103890 - 1234567990 = 31535900 // timestamp left for current halving year + // 31535900000 / 2500(new block avg time/ms) = 12614360 // number of block left for current halving year + + // 713470300 // already minted amount + + // first halving year should mint 225000000000000 + // 225000000000000 - 713470300 = 224999286529700 + // 224999286529700 / 12614360 = 17836759 + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(1) + mintedAmount := gns.MintGns(user01Addr) + uassert.Equal(t, uint64(17836757), mintedAmount) + uassert.Equal(t, uint64(713470300+17836757), gns.GetMintedEmissionAmount()) // 731307057 +} + +func reachAlmostFirstHalving() { + + firstEnds := gns.GetHalvingYearEndBlock(1) + blockLeftUntilEnd := firstEnds - std.GetHeight() + + std.TestSkipHeights(blockLeftUntilEnd - 1) // 1 block left until first halving year ends + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + + uassert.Equal(t, uint64(224999969664063), gns.GetMintedEmissionAmount()) + // 224999969664063 - 731307057 = 224999238357006 + // 224999238357006 / 12614358 = 17836757 +} + +func reachExactFirstHalving() { + std.TestSkipHeights(1) + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + + // minted all amount for first halving year + uassert.Equal(t, uint64(225000000000000), gns.GetMintedEmissionAmount()) + + year := gns.GetHalvingYearByHeight(std.GetHeight()) + uassert.Equal(t, int64(1), year) +} + +func startSecondHalving() { + std.TestSkipHeights(1) + + year := gns.GetHalvingYearByHeight(std.GetHeight()) + uassert.Equal(t, int64(2), year) + + amount := gns.GetAmountByHeight(std.GetHeight()) + uassert.Equal(t, uint64(14269406), amount) + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + uassert.Equal(t, uint64(225000000000000+14269406), gns.GetMintedEmissionAmount()) +} diff --git a/contract/r/gnoswap/gns/tests/z2_filetest.gno b/contract/r/gnoswap/gns/tests/z2_filetest.gno new file mode 100644 index 000000000..399d4c5f4 --- /dev/null +++ b/contract/r/gnoswap/gns/tests/z2_filetest.gno @@ -0,0 +1,131 @@ +// change block avg time from 2s to 4s + +// PKGPATH: gno.land/r/gnoswap/v1/gns_test +package gns_test + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +var t *testing.T + +var ( + user01Addr = testutils.TestAddress("user01Addr") + user01Realm = std.NewUserRealm(user01Addr) +) + +func main() { + skip50Blocks() + blockTime4000ms() + reachFirstHalving() + startSecondHalving() + reachSecondHalving() +} + +func skip50Blocks() { + std.TestSkipHeights(50) + uassert.Equal(t, std.GetHeight(), int64(173)) + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) // 14269406 * 50 = 713470300 + uassert.Equal(t, uint64(713470300), gns.GetMintedEmissionAmount()) +} + +func blockTime4000ms() { + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + gns.SetAvgBlockTimeInMsByAdmin(4000) + // for block time 4s, amount per block is 28538812 + + // first halving year end block = 15768122 + // first halving year end timestamp = 1234567990 + + // 50 block minted from L#38 + // > current height = 173 + // > current timestamp = 1234568140 + + // 1266103890 - 1234567990 = 31535900 // timestamp left for current halving year + // 31535900000 / 4000(new block avg time/ms) = 7883975 // number of block left for current halving year + + // 713470300 // already minted amount + + // first halving year should mint 225000000000000 + // 225000000000000 - 713470300 = 224999286529700 + // 224999286529700 / 7883975 ≈ 28538812 + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(1) + mintedAmount := gns.MintGns(user01Addr) // 28538812 + uassert.Equal(t, uint64(713470300+28538812), gns.GetMintedEmissionAmount()) // 742009112 + + firstYearAmountPerBlock := gns.GetAmountPerBlockPerHalvingYear(1) + uassert.Equal(t, uint64(28538812), firstYearAmountPerBlock) + + // next halving year start/end block + uassert.Equal(t, int64(7884149), gns.GetHalvingYearStartBlock(2)) + uassert.Equal(t, int64(15768149), gns.GetHalvingYearEndBlock(2)) + + // orig year01 start block = 123 + // 50 block mined from L#38 + // current block = 173 + // current timestamp = 1234567990 + // orig year01 end timestamp = 1266103890 + + // 1266103890 - 1234567990 = 31535900 // timestamp left for current halving year + // 31535900000(ms) / 4000(ms) = 7883975 // number of block left for current halving year + + // 173 + 7883975 = 7884148 // end block of current halving year + // 7884148 + 1 = 7884149 // start block of next halving year + + // 7884000 // based on 4s block, how many block in a year + // 7884149 + 7884000 = 15768149 // end block of next halving year +} + +func reachFirstHalving() { + // current := 174 + // nextHalving := 7884148 + // 7884148 - 174 = 7883974 + + std.TestSkipHeights(7883974) + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + + // minted all amount for first halving year + uassert.Equal(t, uint64(225000000000000), gns.GetMintedEmissionAmount()) + + year := gns.GetHalvingYearByHeight(std.GetHeight()) + uassert.Equal(t, int64(1), year) +} + +func startSecondHalving() { + std.TestSkipHeights(1) + + year := gns.GetHalvingYearByHeight(std.GetHeight()) + uassert.Equal(t, int64(2), year) + + amount := gns.GetAmountByHeight(std.GetHeight()) + uassert.Equal(t, uint64(14269406), amount) + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + uassert.Equal(t, uint64(225000000000000+14269406), gns.GetMintedEmissionAmount()) +} + +func reachSecondHalving() { + blockLeftUntilEndOfYear02 := gns.GetHalvingYearEndBlock(2) - std.GetHeight() + + std.TestSkipHeights(int64(blockLeftUntilEndOfYear02)) + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + + // minted all amount until second halving + year01Max := gns.GetHalvingYearMaxAmount(1) + year02Max := gns.GetHalvingYearMaxAmount(2) + uassert.Equal(t, year01Max+year02Max, gns.GetMintedEmissionAmount()) +} diff --git a/contract/r/gnoswap/gns/tests/z3_filetest.gno b/contract/r/gnoswap/gns/tests/z3_filetest.gno new file mode 100644 index 000000000..edcff7cf4 --- /dev/null +++ b/contract/r/gnoswap/gns/tests/z3_filetest.gno @@ -0,0 +1,103 @@ +// change block avg time from 2s > 4s> 3s + +// PKGPATH: gno.land/r/gnoswap/v1/gns_test +package gns_test + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +var t *testing.T + +var ( + user01Addr = testutils.TestAddress("user01Addr") + user01Realm = std.NewUserRealm(user01Addr) +) + +func main() { + skip50Blocks() + blockTime4000ms() // 4s + reachAboutHalfOfFirstHalving() + blockTime3000ms() // 3s + reachFirstHalving() + startSecondHalving() +} + +func skip50Blocks() { + std.TestSkipHeights(50) + uassert.Equal(t, std.GetHeight(), int64(173)) + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr)) // 14269406 * 50 = 71347030 + uassert.Equal(t, uint64(713470300), gns.GetMintedEmissionAmount()) +} + +func blockTime4000ms() { + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + gns.SetAvgBlockTimeInMsByAdmin(4000) + + std.TestSkipHeights(1) + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr)) // 2853881 + uassert.Equal(t, uint64(713470300+28538812), gns.GetMintedEmissionAmount()) // 742009112 + + // next halving year start/end block + uassert.Equal(t, int64(7884149), gns.GetHalvingYearStartBlock(2)) + uassert.Equal(t, int64(15768149), gns.GetHalvingYearEndBlock(2)) +} + +func reachAboutHalfOfFirstHalving() { + // current block = 174 + // end block of first halving year = 7884148 + // 7884148 - 174 = 7883974 // block left to next halving + + std.TestSkipHeights(3941987) + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + + // stil in first halving year + year := gns.GetHalvingYearByHeight(std.GetHeight()) + uassert.Equal(t, int64(1), year) +} + +func blockTime3000ms() { + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + gns.SetAvgBlockTimeInMsByAdmin(3000) + std.TestSkipHeights(1) +} + +func reachFirstHalving() { + blockLeftUntilEndOfYear01 := gns.GetHalvingYearEndBlock(1) - std.GetHeight() + std.TestSkipHeights(blockLeftUntilEndOfYear01) + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + + // minted all amount for first halving year + uassert.Equal(t, gns.GetHalvingYearMaxAmount(1), gns.GetMintedEmissionAmount()) + + // we're at the end of first halving year + year := gns.GetHalvingYearByHeight(std.GetHeight()) + uassert.Equal(t, int64(1), year) +} + +func startSecondHalving() { + std.TestSkipHeights(1) + + year := gns.GetHalvingYearByHeight(std.GetHeight()) + uassert.Equal(t, int64(2), year) + + amount := gns.GetAmountByHeight(std.GetHeight()) + uassert.Equal(t, uint64(14269406), amount) + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + gns.MintGns(user01Addr) + uassert.Equal(t, uint64(225000000000000+14269406), gns.GetMintedEmissionAmount()) +} diff --git a/contract/r/gnoswap/gns/tests/z4_filetest.gno b/contract/r/gnoswap/gns/tests/z4_filetest.gno new file mode 100644 index 000000000..3c262c1a9 --- /dev/null +++ b/contract/r/gnoswap/gns/tests/z4_filetest.gno @@ -0,0 +1,88 @@ +// reach all having years and minted all amount +// and then no gns will be minted + +// PKGPATH: gno.land/r/gnoswap/v1/gns_test +package gns_test + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +var t *testing.T + +var ( + user01Addr = testutils.TestAddress("user01Addr") + user01Realm = std.NewUserRealm(user01Addr) +) + +func main() { + reachAlmostEndOfEmission() + skip50Blocks() + moreBlocksAfterEmissionEnds() +} + +func reachAlmostEndOfEmission() { + endHeight := gns.GetEndHeight() + untilEnds := endHeight - std.GetHeight() + std.TestSkipHeights(untilEnds - 10) // 10 blocks before end of emission + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + minted := gns.MintGns(user01Addr) + + // minted all amount for year 01 ~ 11 + for year := int64(1); year <= 11; year++ { + uassert.Equal( + t, + gns.GetHalvingYearMaxAmount(year), + gns.GetHalvingYearMintAmount(year), + ufmt.Sprintf("shoud been minted max amount for year %d", year), + ) + } + + // we're at the end of last halving year + year := gns.GetHalvingYearByHeight(std.GetHeight()) + uassert.Equal(t, int64(12), year) +} + +func skip50Blocks() { + // 10 block left until emission ends + // skipping 50 blocks will reach the end of emission + // and it should mint for only 10 blocks + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(50) + gns.MintGns(user01Addr) + + // total supply + totalSupply := gns.TotalSupply() + uassert.Equal(t, gns.MAXIMUM_SUPPLY, totalSupply) + + // minted amount + mintedAmount := gns.GetMintedEmissionAmount() + uassert.Equal(t, gns.MAX_EMISSION_AMOUNT, mintedAmount) +} + +func moreBlocksAfterEmissionEnds() { + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(10) + + // no gns will be minted after emission ends + minted := gns.MintGns(user01Addr) + uassert.Equal(t, uint64(0), minted) + + // total supply + totalSupply := gns.TotalSupply() + uassert.Equal(t, gns.MAXIMUM_SUPPLY, totalSupply) + + // minted amount + mintedAmount := gns.GetMintedEmissionAmount() + uassert.Equal(t, gns.MAX_EMISSION_AMOUNT, mintedAmount) +} diff --git a/contract/r/gnoswap/gns/utils.gno b/contract/r/gnoswap/gns/utils.gno new file mode 100644 index 000000000..2ea700e81 --- /dev/null +++ b/contract/r/gnoswap/gns/utils.gno @@ -0,0 +1,88 @@ +package gns + +import ( + "std" + "strconv" + + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +func getPrev() (string, string) { + prev := std.PrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +func assertShouldNotBeHalted() { + common.IsHalted() +} + +func assertCallerIsEmission() { + caller := getPrevAddr() + if err := common.EmissionOnly(caller); err != nil { + panic(err) + } +} + +func assertCallerIsAdmin() { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } +} + +func assertCallerIsGovernance() { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err) + } +} + +func assertValidYear(year int64) { + if year < HALVING_START_YEAR || year > HALVING_END_YEAR { + panic(addDetailToError(errInvalidYear, ufmt.Sprintf("year: %d", year))) + } +} + +func i64Min(x, y int64) int64 { + if x < y { + return x + } + return y +} + +func secToMs(sec int64) int64 { + return sec * consts.SECOND_IN_MILLISECOND +} + +func formatUint(v interface{}) string { + switch v := v.(type) { + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatInt(v interface{}) string { + switch v := v.(type) { + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case int: + return strconv.Itoa(v) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} diff --git a/contract/r/gnoswap/gov/doc.gno b/contract/r/gnoswap/gov/doc.gno new file mode 100644 index 000000000..c3ea48e4b --- /dev/null +++ b/contract/r/gnoswap/gov/doc.gno @@ -0,0 +1,2 @@ +// Package gov manages the governance of the gnoswap protocol, including proposal creation, voting, and execution. +package gov diff --git a/contract/r/gnoswap/gov/governance/api.gno b/contract/r/gnoswap/gov/governance/api.gno new file mode 100644 index 000000000..063c92212 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/api.gno @@ -0,0 +1,260 @@ +package governance + +import ( + "std" + "strconv" + "strings" + "time" + + "gno.land/p/demo/json" + + en "gno.land/r/gnoswap/v1/emission" +) + +func createProposalJsonNode(id uint64, proposal ProposalInfo) *json.Node { + return json.Builder(). + WriteString("id", formatUint(id)). + WriteString("configVersion", formatUint(proposal.ConfigVersion)). + WriteString("proposer", proposal.Proposer.String()). + WriteString("status", b64Encode(getProposalStatus(id))). + WriteString("type", proposal.ProposalType.String()). + WriteString("title", proposal.Title). + WriteString("description", proposal.Description). + WriteString("vote", b64Encode(getProposalVotes(id))). + WriteString("extra", b64Encode(getProposalExtraData(id))). + Node() +} + +// GetProposals returns all proposals with necessary information. +func GetProposals() string { + en.MintAndDistributeGns() + updateProposalsState() + + if proposals.Size() == 0 { + return "" + } + + proposalsObj := metaNode() + proposalArr := json.ArrayNode("", nil) + + proposals.Iterate("", "", func(key string, value interface{}) bool { + proposalObj := getProposalById(proposalId) + proposalArr.AppendArray(proposalObj) + return true + }) + + proposalsObj.AppendObject("proposals", proposalArr) + + return marshal(proposalsObj) +} + +// GetProposalById returns a single proposal with necessary information. +func GetProposalById(id uint64) string { + en.MintAndDistributeGns() + updateProposalsState() + + _, exists := proposals.Get(formatUint(id)) + if !exists { + return "" + } + + proposalsObj := metaNode() + proposalArr := json.ArrayNode("", nil) + proposalObj := getProposalById(id) + proposalArr.AppendArray(proposalObj) + proposalsObj.AppendObject("proposals", proposalArr) + + return marshal(proposalsObj) +} + +// helper function for GetProposals and GetProposalById +func getProposalById(id uint64) *json.Node { + proposal := mustGetProposal(id) + return createProposalJsonNode(id, proposal) +} + +// GetVoteStatusFromProposalById returns the vote status(max, yes, no) of a proposal. +func GetVoteStatusFromProposalById(id uint64) string { + en.MintAndDistributeGns() + updateProposalsState() + + _, exists := proposals.Get(formatUint(id)) + if !exists { + return "" + } + + votesObj := metaNode() + votesObj.AppendObject("proposalId", json.StringNode("proposalId", formatUint(id))) + votesObj.AppendObject("votes", json.StringNode("votes", b64Encode(getProposalVotes(id)))) // max, yes, no + + return marshal(votesObj) +} + +// GetVotesByAddress returns all votes of an address. +// included information: +// - proposalId +// - vote (yes/no) +// - weight +// - height +// - timestamp +func GetVotesByAddress(addr std.Address) string { + en.MintAndDistributeGns() + updateProposalsState() + + if _, exist := userVotes.Get(addr.String()); !exist { + return "" + } + + votesObj := metaNode() + votesArr := json.ArrayNode("", nil) + + userVotes.Iterate(addr.String(), "", func(key string, value interface{}) bool { + voteObj := createVoteJsonNode(addr, proposalId, value.(voteWithWeight)) + votesArr.AppendArray(voteObj) + return true + }) + votesObj.AppendObject("votes", votesArr) + + return marshal(votesObj) +} + +// GetVoteByAddressFromProposalById returns the vote of an address from a certain proposal. +func GetVoteByAddressFromProposalById(addr std.Address, id uint64) string { + en.MintAndDistributeGns() + updateProposalsState() + + vote, exists := getUserVote(addr, id) + if !exists { + return "" + } + + votesObj := metaNode() + voteArr := json.ArrayNode("", nil) + voteObj := createVoteJsonNode(addr, id, vote) + voteArr.AppendArray(voteObj) + votesObj.AppendObject("votes", voteArr) + + return marshal(votesObj) +} + +func createVoteJsonNode(addr std.Address, id uint64, vote voteWithWeight) *json.Node { + return json.Builder(). + WriteString("proposalId", formatUint(id)). + WriteString("voteYes", formatBool(vote.Yes)). + WriteString("voteWeight", formatUint(vote.Weight)). + WriteString("voteHeight", formatUint(vote.VotedHeight)). + WriteString("voteTimestamp", formatUint(vote.VotedAt)). + Node() +} + +// getProposalExtraData returns the extra data of a proposal based on its type. +func getProposalExtraData(proposalId uint64) string { + proposal, exist := proposals.Get(formatUint(proposalId)) + if !exist { + return "" + } + + switch proposal.(ProposalInfo).ProposalType { + case Text: + return "" + case CommunityPoolSpend: + return getCommunityPoolSpendProposalData(proposalId) + case ParameterChange: + return getParameterChangeProposalData(proposalId) + } + + return "" +} + +// community pool has three extra data +// 1. to +// 2. tokenPath +// 3. amount +func getCommunityPoolSpendProposalData(proposalId uint64) string { + proposal := mustGetProposal(proposalId) + + proposalObj := json.Builder(). + WriteString("to", proposal.CommunityPoolSpend.To.String()). + WriteString("tokenPath", proposal.CommunityPoolSpend.TokenPath). + WriteString("amount", formatUint(proposal.CommunityPoolSpend.Amount)). + Node() + + return marshal(proposalObj) +} + +// parameter change proposal has three extra data +func getParameterChangeProposalData(proposalId uint64) string { + proposal := mustGetProposal(proposalId) + + msgs := proposal.Execution.Msgs + msgsStr := strings.Join(msgs, "*GOV*") + + return msgsStr +} + +// getProposalStatus returns the status of a proposal. +func getProposalStatus(id uint64) string { + prop, exist := proposals.Get(formatUint(id)) + if !exist { + return "" + } + proposal := prop.(ProposalInfo) + + config := GetConfigVersion(proposal.ConfigVersion) + + votingStart := proposal.State.CreatedAt + config.VotingStartDelay + votingEnd := votingStart + config.VotingPeriod + + node := createProposalStateNode(proposal.State, votingStart, votingEnd) + return marshal(node) +} + +func createProposalStateNode(state ProposalState, votingStart, votingEnd uint64) *json.Node { + return json.Builder(). + WriteString("createdAt", formatUint(state.CreatedAt)). + WriteString("upcoming", formatBool(state.Upcoming)). + WriteString("active", formatBool(state.Active)). + WriteString("votingStart", formatUint(votingStart)). + WriteString("votingEnd", formatUint(votingEnd)). + WriteString("passed", formatBool(state.Passed)). + WriteString("passedAt", formatUint(state.PassedAt)). + WriteString("rejected", formatBool(state.Rejected)). + WriteString("rejectedAt", formatUint(state.RejectedAt)). + WriteString("canceled", formatBool(state.Canceled)). + WriteString("canceledAt", formatUint(state.CanceledAt)). + WriteString("executed", formatBool(state.Executed)). + WriteString("executedAt", formatUint(state.ExecutedAt)). + WriteString("expired", formatBool(state.Expired)). + WriteString("expiredAt", formatUint(state.ExpiredAt)). + Node() +} + +// getProposalVotes returns the votes of a proposal. +func getProposalVotes(id uint64) string { + prop, exist := proposals.Get(formatUint(id)) + if !exist { + return "" + } + + proposal := prop.(ProposalInfo) + maxVoting := proposal.MaxVotingWeight.ToString() + + proposalObj := json.Builder(). + WriteString("quorum", formatUint(proposal.QuorumAmount)). + WriteString("max", maxVoting). + WriteString("yes", proposal.Yea.ToString()). + WriteString("no", proposal.Nay.ToString()). + Node() + + return marshal(proposalObj) +} + +func metaNode() *json.Node { + height := std.GetHeight() + now := time.Now().Unix() + + return json.Builder(). + WriteString("height", strconv.FormatInt(height, 10)). + WriteString("now", strconv.FormatInt(now, 10)). + Node() +} diff --git a/contract/r/gnoswap/gov/governance/api_test.gno b/contract/r/gnoswap/gov/governance/api_test.gno new file mode 100644 index 000000000..7c8507cc2 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/api_test.gno @@ -0,0 +1,186 @@ +package governance + +import ( + "strconv" + "testing" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/json" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" +) + +func TestCreateProposalJsonNode(t *testing.T) { + propAddr := testutils.TestAddress("proposal") + title := "Test Proposal" + desc := "This is a test proposal" + + testProposal := ProposalInfo{ + ConfigVersion: 1, + Proposer: propAddr, + ProposalType: Text, + Title: title, + Description: desc, + } + + result := createProposalJsonNode(123, testProposal) + + expectedFields := []struct { + key string + expected string + }{ + {"id", "123"}, + {"configVersion", "1"}, + {"proposer", propAddr.String()}, + {"type", Text.String()}, + {"title", title}, + {"description", desc}, + } + + for _, field := range expectedFields { + node, err := result.GetKey(field.key) + uassert.NoError(t, err) + + value, err := node.GetString() + uassert.NoError(t, err) + uassert.Equal(t, value, field.expected) + } + + encodedFields := []string{"status", "vote", "extra"} + for _, field := range encodedFields { + node, err := result.GetKey(field) + if err != nil { + t.Errorf("field not found: %s", field) + continue + } + + value, err := node.GetString() + if err != nil { + t.Errorf("failed to get value: %s", err) + continue + } + + decodedValue := b64Decode(value) + } +} + +func TestCreateProposalJsonNode_CheckRequiredFields(t *testing.T) { + emptyProposal := ProposalInfo{ + ConfigVersion: 0, + Proposer: testutils.TestAddress("proposal"), + ProposalType: Text, + Title: "", + Description: "", + } + + result := createProposalJsonNode(0, emptyProposal) + + requiredFields := []string{ + "id", "configVersion", "proposer", "status", + "type", "title", "description", "vote", "extra", + } + + for _, field := range requiredFields { + _, err := result.GetKey(field) + uassert.NoError(t, err) + } +} + +func TestGetProposalStatus(t *testing.T) { + proposals = avl.NewTree() + + // Test Case 1: Non-existent proposal + status := getProposalStatus(999) + uassert.Equal(t, status, "") + + // Test Case 2: Active proposal + now := uint64(time.Now().Unix()) + proposal := ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + CreatedAt: now, + Upcoming: false, + Active: true, + Passed: false, + PassedAt: 0, + Rejected: false, + RejectedAt: 0, + Canceled: false, + CanceledAt: 0, + Executed: false, + ExecutedAt: 0, + Expired: false, + ExpiredAt: 0, + }, + } + + proposals.Set("1", proposal) + + status = getProposalStatus(1) + + node, err := json.Unmarshal([]byte(status)) + uassert.NoError(t, err) + uassert.True(t, node.IsObject()) + + tests := []struct { + key string + expected string + }{ + {"createdAt", strconv.FormatUint(now, 10)}, + {"upcoming", "false"}, + {"active", "true"}, + {"passed", "false"}, + {"passedAt", "0"}, + {"rejected", "false"}, + {"rejectedAt", "0"}, + {"canceled", "false"}, + {"canceledAt", "0"}, + {"executed", "false"}, + {"executedAt", "0"}, + {"expired", "false"}, + {"expiredAt", "0"}, + } + + for _, tc := range tests { + uassert.True(t, node.HasKey(tc.key)) + + value, err := node.GetKey(tc.key) + uassert.NoError(t, err) + uassert.True(t, value.IsString()) + + str, err := value.GetString() + uassert.NoError(t, err) + uassert.Equal(t, str, tc.expected) + } + + // Test Case 3: State transition + proposal.State.Active = false + proposal.State.Passed = true + proposal.State.PassedAt = now + 500 + proposal.State.Executed = true + proposal.State.ExecutedAt = now + 1000 + + proposals.Set("2", proposal) + + status = getProposalStatus(2) + node2, err := json.Unmarshal([]byte(status)) + uassert.NoError(t, err) + + // Test node traversal using ObjectEach + expectedFields := map[string]string{ + "active": "false", + "passed": "true", + "passedAt": strconv.FormatUint(now+500, 10), + "executed": "true", + "executedAt": strconv.FormatUint(now+1000, 10), + } + + node2.ObjectEach(func(key string, value *json.Node) { + if expected, ok := expectedFields[key]; ok { + str, err := value.GetString() + uassert.NoError(t, err) + uassert.Equal(t, str, expected) + } + }) +} diff --git a/contract/r/gnoswap/gov/governance/config.gno b/contract/r/gnoswap/gov/governance/config.gno new file mode 100644 index 000000000..7119f3ed8 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/config.gno @@ -0,0 +1,166 @@ +package governance + +import ( + "std" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + + "gno.land/r/gnoswap/v1/common" + en "gno.land/r/gnoswap/v1/emission" +) + +var ( + config Config + configVersions = avl.NewTree() // version -> Config +) + +func init() { + // https://docs.gnoswap.io/core-concepts/governance + config = Config{ + VotingStartDelay: uint64(86400), // 1d + VotingPeriod: uint64(604800), // 7d + VotingWeightSmoothingDuration: uint64(86400), // 1d + Quorum: uint64(50), // 50% of total xGNS supply + ProposalCreationThreshold: uint64(1_000_000_000), // 1_000_000_000 + ExecutionDelay: uint64(86400), // 1d + ExecutionWindow: uint64(2592000), // 30d + } + + // config version 0 should return the current config + // therefore we set initial config version to 1 + setConfigVersion(1, config) +} + +func setConfigVersion(v uint64, cfg Config) { + configVersions.Set(formatUint(v), cfg) +} + +func getConfigByVersion(v uint64) (Config, bool) { + value, exists := configVersions.Get(formatUint(v)) + if !exists { + return Config{}, false + } + return value.(Config), true +} + +func getLatestVersion() uint64 { + var maxVersion uint64 + configVersions.ReverseIterate("", "", func(key string, value interface{}) bool { + maxVersion = parseUint64(key) + return true + }) + return maxVersion +} + +// ReconfigureByAdmin updates the proposal realted configuration. +// Returns the new configuration version number. +func ReconfigureByAdmin( + votingStartDelay uint64, + votingPeriod uint64, + votingWeightSmoothingDuration uint64, + quorum uint64, + proposalCreationThreshold uint64, + executionDelay uint64, + executionWindow uint64, +) uint64 { + caller := std.PrevRealm().Addr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } + + return reconfigure( + votingStartDelay, + votingPeriod, + votingWeightSmoothingDuration, + quorum, + proposalCreationThreshold, + executionDelay, + executionWindow, + ) +} + +// reconfigure updates the Governor's configuration. +// Only governance contract can execute this function via proposal +// Returns the new configuration version number. +func reconfigure( + votingStartDelay uint64, + votingPeriod uint64, + votingWeightSmoothingDuration uint64, + quorum uint64, + proposalCreationThreshold uint64, + executionDelay uint64, + executionWindow uint64, +) uint64 { + common.IsHalted() + + en.MintAndDistributeGns() + updateProposalsState() + + prevVersion := getLatestVersion() + newVersion := prevVersion + 1 + + config = Config{ + VotingStartDelay: votingStartDelay, + VotingPeriod: votingPeriod, + VotingWeightSmoothingDuration: votingWeightSmoothingDuration, + Quorum: quorum, + ProposalCreationThreshold: proposalCreationThreshold, + ExecutionDelay: executionDelay, + ExecutionWindow: executionWindow, + } + setConfigVersion(newVersion, config) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "Reconfigure", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "votingStartDelay", formatUint(config.VotingStartDelay), + "votingPeriod", formatUint(config.VotingPeriod), + "votingWeightSmoothingDuration", formatUint(config.VotingWeightSmoothingDuration), + "quorum", formatUint(config.Quorum), + "proposalCreationThreshold", formatUint(config.ProposalCreationThreshold), + "executionDelay", formatUint(config.ExecutionDelay), + "executionPeriod", formatUint(config.ExecutionWindow), + "newConfigVersion", formatUint(newVersion), + "prevConfigVersion", formatUint(prevVersion), + ) + + return newVersion +} + +// GetConfigVersion returns the configuration for a specific version. +// If version is 0, it returns the current configuration. +func GetConfigVersion(version uint64) Config { + en.MintAndDistributeGns() + + if version == 0 { + return config + } + + cfg, exists := getConfigByVersion(version) + if !exists { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("config version(%d) does not exist", version), + )) + } + + return cfg +} + +// GetLatestConfig() returns the latest configuration. +func GetLatestConfig() Config { + return config +} + +// GetLatestConfigVersion() returns the latest configuration version. +func GetLatestConfigVersion() uint64 { + return getLatestVersion() +} + +// GetProposalCreationThreshold() returns the current proposal creation threshold. +func GetProposalCreationThreshold() uint64 { + return config.ProposalCreationThreshold +} diff --git a/contract/r/gnoswap/gov/governance/config_test.gno b/contract/r/gnoswap/gov/governance/config_test.gno new file mode 100644 index 000000000..aec7bb451 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/config_test.gno @@ -0,0 +1,219 @@ +package governance + +import ( + "std" + "testing" + + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" +) + +var ( + adminAddr = testutils.TestAddress("admin") + notAdminAddr = testutils.TestAddress("notadmin") +) + +// we cannot call init() in test, so we mock it here. +// it's totally same as init() +func mockInit(t *testing.T) { + t.Helper() + config = Config{ + VotingStartDelay: uint64(86400), + VotingPeriod: uint64(604800), + VotingWeightSmoothingDuration: uint64(86400), + Quorum: uint64(50), + ProposalCreationThreshold: uint64(1_000_000_000), + ExecutionDelay: uint64(86400), + ExecutionWindow: uint64(2592000), + } + + // configVersions[uint64(len(configVersions)+1)] = config + configVersions = avl.NewTree() + setConfigVersion(1, config) +} + +func resetGlobalConfig(t *testing.T) { + t.Helper() + config = Config{} + // configVersions = make(map[uint64]Config) + configVersions = avl.NewTree() +} + +func newConfigObj(t *testing.T) Config { + return Config{ + VotingStartDelay: 100000, + VotingPeriod: 700000, + VotingWeightSmoothingDuration: 90000, + Quorum: 60, + ProposalCreationThreshold: 2_000_000_000, + ExecutionDelay: 90000, + ExecutionWindow: 3000000, + } +} + +func TestInitialConfig(t *testing.T) { + resetGlobalConfig(t) + mockInit(t) + + tests := []struct { + name string + got uint64 + expected uint64 + }{ + {"VotingStartDelay", config.VotingStartDelay, 86400}, + {"VotingPeriod", config.VotingPeriod, 604800}, + {"VotingWeightSmoothingDuration", config.VotingWeightSmoothingDuration, 86400}, + {"Quorum", config.Quorum, 50}, + {"ProposalCreationThreshold", config.ProposalCreationThreshold, 1_000_000_000}, + {"ExecutionDelay", config.ExecutionDelay, 86400}, + {"ExecutionWindow", config.ExecutionWindow, 2592000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uassert.Equal(t, tt.got, tt.expected) + }) + } + + uassert.Equal(t, configVersions.Size(), 1) +} + +func TestReconfigureByAdmin(t *testing.T) { + resetGlobalConfig(t) + mockInit(t) + + tests := []struct { + name string + caller std.Address + newConfig Config + expectError bool + errorContains string + }{ + { + name: "Valid admin reconfiguration", + caller: adminAddr, + newConfig: Config{ + VotingStartDelay: 86400, + VotingPeriod: 700000, + VotingWeightSmoothingDuration: 90000, + Quorum: 60, + ProposalCreationThreshold: 2_000_000_000, + ExecutionDelay: 90000, + ExecutionWindow: 3000000, + }, + expectError: false, + }, + { + name: "Non-admin caller", + caller: testutils.TestAddress("notadmin"), + newConfig: Config{ + VotingStartDelay: 86400, + }, + expectError: true, + errorContains: "admin", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var err error + func() { + defer func() { + if r := recover(); r != nil { + err = r.(error) + } + }() + + ReconfigureByAdmin( + tt.newConfig.VotingStartDelay, + tt.newConfig.VotingPeriod, + tt.newConfig.VotingWeightSmoothingDuration, + tt.newConfig.Quorum, + tt.newConfig.ProposalCreationThreshold, + tt.newConfig.ExecutionDelay, + tt.newConfig.ExecutionWindow, + ) + }() + + if tt.expectError { + uassert.Error(t, err) + } + uassert.Equal(t, config.VotingStartDelay, tt.newConfig.VotingStartDelay) + }) + } +} + +func TestGetConfigVersion(t *testing.T) { + // don't call mockInit here. + resetGlobalConfig(t) + + tests := []struct { + name string + version uint64 + expectError bool + }{ + { + name: "Get current config (version 0)", + version: 0, + expectError: false, + }, + { + name: "Get existing version", + version: 1, + expectError: false, + }, + { + name: "Get non-existent version", + version: 999, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result Config + + if tt.expectError { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic, got none") + } + }() + result = GetConfigVersion(tt.version) + } + + expectedConfig := config + if tt.version != 0 { + expectedConfig, _ = getConfigByVersion(tt.version) + } + uassert.Equal(t, result.VotingStartDelay, expectedConfig.VotingStartDelay) + }) + } +} + +func TestReconfigure(t *testing.T) { + resetGlobalConfig(t) + mockInit(t) + + newConfig := newConfigObj(t) + + initialVersion := GetLatestConfigVersion() + + version := reconfigure( + newConfig.VotingStartDelay, + newConfig.VotingPeriod, + newConfig.VotingWeightSmoothingDuration, + newConfig.Quorum, + newConfig.ProposalCreationThreshold, + newConfig.ExecutionDelay, + newConfig.ExecutionWindow, + ) + + uassert.Equal(t, version, initialVersion+1) + uassert.Equal(t, config.VotingStartDelay, newConfig.VotingStartDelay) + + storedConfig, _ := getConfigByVersion(version) + uassert.Equal(t, storedConfig.VotingStartDelay, newConfig.VotingStartDelay) +} diff --git a/contract/r/gnoswap/gov/governance/errors.gno b/contract/r/gnoswap/gov/governance/errors.gno new file mode 100644 index 000000000..4f55cbe4f --- /dev/null +++ b/contract/r/gnoswap/gov/governance/errors.gno @@ -0,0 +1,31 @@ +package governance + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errOutOfRange = errors.New("[GNOSWAP-GOVERNANCE-001] out of range for numeric value") + errInvalidInput = errors.New("[GNOSWAP-GOVERNANCE-002] invalid input") + errDataNotFound = errors.New("[GNOSWAP-GOVERNANCE-003] requested data not found") + errNotEnoughBalance = errors.New("[GNOSWAP-GOVERNANCE-004] not enough balance") + errUnableToVoteCanceledProposal = errors.New("[GNOSWAP-GOVERNANCE-005] unable to vote for canceled proposal") + errAlreadyVoted = errors.New("[GNOSWAP-GOVERNANCE-006] can not vote twice") + errNotEnoughVotingWeight = errors.New("[GNOSWAP-GOVERNANCE-007] not enough voting power") + errAlreadyCanceledProposal = errors.New("[GNOSWAP-GOVERNANCE-008] can not cancel already canceled proposal") + errUnableToCancleVotingProposal = errors.New("[GNOSWAP-GOVERNANCE-009] unable to cancel voting proposal") + errUnableToCancelProposalWithVoterEnoughDelegated = errors.New("[GNOSWAP-GOVERNANCE-010] unable to cancel proposal with voter has enough delegation") + errTextProposalNotExecutable = errors.New("[GNOSWAP-GOVERNANCE-011] can not execute text proposal") + errUnsupportedProposalType = errors.New("[GNOSWAP-GOVERNANCE-012] unsupported proposal type") + errInvalidProposalType = errors.New("[GNOSWAP-GOVERNANCE-013] invalid proposal type") + errUnableToVoteOutOfPeriod = errors.New("[GNOSWAP-GOVERNANCE-014] unable to vote out of voting period") + errInvalidMessageFormat = errors.New("[GNOSWAP-GOVERNANCE-015] invalid message format") + errProposalNotPassed = errors.New("[GNOSWAP-GOVERNANCE-016] proposal not passed") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/gov/governance/execute.gno b/contract/r/gnoswap/gov/governance/execute.gno new file mode 100644 index 000000000..7c0c5336d --- /dev/null +++ b/contract/r/gnoswap/gov/governance/execute.gno @@ -0,0 +1,317 @@ +package governance + +import ( + "errors" + "std" + "strconv" + "strings" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/common" + + en "gno.land/r/gnoswap/v1/emission" +) + +const ( + EXECUTE_SEPARATOR = "*EXE*" +) + +// Function signature for different parameter handlers +type ParameterHandler func([]string) error + +// Registry for parameter handlers +type ParameterRegistry struct { + handlers *avl.Tree +} + +func NewParameterRegistry() *ParameterRegistry { + return &ParameterRegistry{ + handlers: avl.NewTree(), + } +} + +func (r *ParameterRegistry) Register( + pkgPath, function string, + hdlr ParameterHandler, +) { + key := makeHandlerKey(pkgPath, function) + r.handlers.Set(key, hdlr) +} + +func (r *ParameterRegistry) Handler(pkgPath, function string) (ParameterHandler, error) { + key := makeHandlerKey(pkgPath, function) + hdlr, exists := r.handlers.Get(key) + if !exists { + return nil, ufmt.Errorf("handler not found for %s", key) + } + return hdlr.(ParameterHandler), nil +} + +func (r *ParameterRegistry) Get(pkgPath, function string) (ParameterHandler, error) { + key := makeHandlerKey(pkgPath, function) + hdlr, exists := r.handlers.Get(key) + if !exists { + return nil, ufmt.Errorf("handler not found for %s", key) + } + return hdlr.(ParameterHandler), nil +} + +func makeHandlerKey(pkgPath, function string) string { + return ufmt.Sprintf("%s:%s", pkgPath, function) +} + +///////////////////// EXECUTION ///////////////////// +// region: Execute + +type ExecutionContext struct { + ProposalId uint64 + Now uint64 + Config *Config + Proposal *ProposalInfo + WindowStart uint64 + WindowEnd uint64 +} + +func (e *ExecutionContext) String() string { + return ufmt.Sprintf( + "ProposalId: %d, Now: %d, Config: %v, Proposal: %v, WindowStart: %d, WindowEnd: %d", + e.ProposalId, e.Now, e.Config, e.Proposal, e.WindowStart, e.WindowEnd, + ) +} + +func Execute(proposalId uint64) error { + ctx, err := prepareExecution(proposalId) + if err != nil { + panic(err) + } + + if err := validateVotes(ctx.Proposal); err != nil { + panic(err) + } + + if err := validateCommunityPoolToken(ctx.Proposal); err != nil { + panic(err) + } + + registry := createParameterHandlers() + if err := executeProposal(ctx, registry); err != nil { + return err + } + + updateProposalState(ctx) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "Execute", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "proposalId", formatUint(proposalId), + ) + + return nil +} + +func executeProposal(ctx *ExecutionContext, registry *ParameterRegistry) error { + switch ctx.Proposal.ProposalType { + case ParameterChange, CommunityPoolSpend: + return executeParameterChange(ctx.Proposal.Execution.Msgs, registry) + default: + return errUnsupportedProposalType + } +} + +func executeParameterChange(msgs []string, registry *ParameterRegistry) error { + for _, msg := range msgs { + pkgPath, function, params, err := parseMessage(msg) + if err != nil { + return err + } + + handler, err := registry.Handler(pkgPath, function) + if err != nil { + return err + } + + if err := handler(params); err != nil { + return err + } + } + + return nil +} + +func parseMessage(msg string) (pkgPath string, function string, params []string, err error) { + parts := strings.Split(msg, EXECUTE_SEPARATOR) + if len(parts) != 3 { + return "", "", nil, errInvalidMessageFormat + } + + return parts[0], parts[1], strings.Split(parts[2], ","), nil +} + +///////////////////// VALIDATION ///////////////////// + +type ExecutionValidator struct { + isTextProposal bool + isAlreadyExecuted bool + isAlreadyCanceled bool + isAlreadyRejected bool + hasPassed bool +} + +func prepareExecution(proposalId uint64) (*ExecutionContext, error) { + validateInitialState() + + proposal, err := getProposal(proposalId) + if err != nil { + return nil, err + } + + validator := validateProposalState(proposal) + if err := checkProposalValidation(validator); err != nil { + return nil, err + } + + ctx, err := createExecutionContext(proposalId, proposal) + if err != nil { + return nil, err + } + + return ctx, nil +} + +func validateInitialState() { + common.IsHalted() + + en.MintAndDistributeGns() + updateProposalsState() +} + +func getProposal(proposalId uint64) (*ProposalInfo, error) { + result, exists := proposals.Get(strconv.Itoa(int(proposalId))) + if !exists { + return nil, ufmt.Errorf("proposal %d not found", proposalId) + } + proposal, exists := result.(ProposalInfo) + if !exists { + return nil, ufmt.Errorf("proposal %d not found", proposalId) + } + return &proposal, nil +} + +func validateProposalState(proposal *ProposalInfo) ExecutionValidator { + return ExecutionValidator{ + isTextProposal: proposal.ProposalType == Text, + isAlreadyExecuted: proposal.State.Executed, + isAlreadyCanceled: proposal.State.Canceled, + isAlreadyRejected: proposal.State.Rejected, + hasPassed: proposal.State.Passed, + } +} + +func checkProposalValidation(v ExecutionValidator) error { + if v.isTextProposal { + return errTextProposalNotExecutable + } + + if v.isAlreadyExecuted || v.isAlreadyCanceled || v.isAlreadyRejected { + return errors.New("proposal already executed, canceled, or rejected") + } + + if !v.hasPassed { + return errProposalNotPassed + } + + return nil +} + +func createExecutionContext(proposalId uint64, proposal *ProposalInfo) (*ExecutionContext, error) { + now := uint64(time.Now().Unix()) + config := GetConfigVersion(proposal.ConfigVersion) + + votingEnd := calculateVotingEnd(proposal, &config) + windowStart := calculateWindowStart(votingEnd, &config) + windowEnd := calculateWindowEnd(windowStart, &config) + + if err := validateExecutionWindow(now, windowStart, windowEnd); err != nil { + return nil, err + } + + return &ExecutionContext{ + ProposalId: proposalId, + Now: now, + Config: &config, + Proposal: proposal, + WindowStart: windowStart, + WindowEnd: windowEnd, + }, nil +} + +func calculateVotingEnd(proposal *ProposalInfo, config *Config) uint64 { + return proposal.State.CreatedAt + + config.VotingStartDelay + + config.VotingPeriod +} + +func calculateWindowStart(votingEnd uint64, config *Config) uint64 { + return votingEnd + config.ExecutionDelay +} + +func calculateWindowEnd(windowStart uint64, config *Config) uint64 { + return windowStart + config.ExecutionWindow +} + +func validateExecutionWindow(now, windowStart, windowEnd uint64) error { + if now < windowStart { + return ufmt.Errorf("execution window not started (now(%d) < windowStart(%d))", now, windowStart) + } + + if now >= windowEnd { + return ufmt.Errorf("execution window over (now(%d) >= windowEnd(%d))", now, windowEnd) + } + + return nil +} + +func validateVotes(pp *ProposalInfo) error { + yea := pp.Yea.Uint64() + nea := pp.Nay.Uint64() + quorum := pp.QuorumAmount + + if yea < quorum { + return ufmt.Errorf("quorum not met (yes(%d) < quorum(%d))", yea, quorum) + } + + if yea < nea { + return ufmt.Errorf("no majority (yes(%d) < no(%d))", yea, nea) + } + + return nil +} + +func updateProposalState(ctx *ExecutionContext) { + ctx.Proposal.State.Executed = true + ctx.Proposal.State.ExecutedAt = ctx.Now + ctx.Proposal.State.Upcoming = false + ctx.Proposal.State.Active = false + proposals.Set(strconv.Itoa(int(ctx.ProposalId)), *ctx.Proposal) +} + +func validateCommunityPoolToken(pp *ProposalInfo) error { + if pp.ProposalType != CommunityPoolSpend { + return nil + } + + common.MustRegistered(pp.CommunityPoolSpend.TokenPath) + + return nil +} + +func hasDesiredParams(params []string, expected int) error { + if len(params) != expected { + return ufmt.Errorf("invalid parameters for %s. expected %d but got %d", params, expected, len(params)) + } + return nil +} diff --git a/contract/r/gnoswap/gov/governance/execute_test.gno b/contract/r/gnoswap/gov/governance/execute_test.gno new file mode 100644 index 000000000..1b4c79838 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/execute_test.gno @@ -0,0 +1,368 @@ +package governance + +import ( + "errors" + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" + "testing" +) + +func TestParameterRegistry_Register(t *testing.T) { + registry := NewParameterRegistry() + + testHandler := func(params []string) error { + return nil + } + + tests := []struct { + name string + pkgPath string + function string + handler ParameterHandler + }{ + { + name: "pass", + pkgPath: "test/pkg", + function: "testFunc", + handler: testHandler, + }, + { + name: "empty pass", + pkgPath: "", + function: "testFunc", + handler: testHandler, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry.Register(tt.pkgPath, tt.function, tt.handler) + + handler, err := registry.Handler(tt.pkgPath, tt.function) + uassert.NoError(t, err) + + if handler == nil { + t.Errorf("handler is nil") + } + + expectedKey := makeHandlerKey(tt.pkgPath, tt.function) + if _, exists := registry.handlers.Get(expectedKey); !exists { + t.Errorf("expected key %s not found", expectedKey) + } + }) + } +} + +func TestParameterRegistry(t *testing.T) { + tests := []struct { + name string + pkgPath string + function string + handler ParameterHandler + wantErr bool + errMessage string + }{ + { + name: "should register and retrieve handler successfully", + pkgPath: "test/pkg", + function: "testFunc", + handler: func(params []string) error { + return nil + }, + wantErr: false, + errMessage: "", + }, + { + name: "should return error for non-existent handler", + pkgPath: "non/existent", + function: "missing", + handler: nil, + wantErr: true, + errMessage: "handler not found for non/existent:missing", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry := NewParameterRegistry() + + // Register handler if provided + if tt.handler != nil { + registry.Register(tt.pkgPath, tt.function, tt.handler) + } + + // Try to retrieve handler + handler, err := registry.Handler(tt.pkgPath, tt.function) + + if tt.wantErr { + uassert.Error(t, err, tt.errMessage) + uassert.Equal(t, err.Error(), tt.errMessage) + } else { + uassert.NoError(t, err) + if handler == nil { + t.Error("expected handler to be non-nil") + } + } + }) + } +} + +func TestExecuteParameterChange(t *testing.T) { + registry := NewParameterRegistry() + + registry.Register("test/pkg", "TestFunc", func(params []string) error { + if len(params) != 2 { + return errors.New("invalid params length") + } + return nil + }) + + tests := []struct { + name string + msgs []string + wantErr bool + }{ + { + name: "Pass: Valid message", + msgs: []string{ + "test/pkg*EXE*TestFunc*EXE*param1,param2", + }, + wantErr: false, + }, + { + name: "Fail: Missing separator", + msgs: []string{ + "test/pkg*EXE*TestFunc", + }, + wantErr: true, + }, + { + name: "Fail: Non-existent handler", + msgs: []string{ + "unknown/pkg*EXE*UnknownFunc*EXE*param1", + }, + wantErr: true, + }, + { + name: "Fail: Not enough parameters", + msgs: []string{ + "test/pkg*EXE*TestFunc*EXE*param1", + }, + wantErr: true, + }, + { + name: "handle multiple messages", + msgs: []string{ + "test/pkg*EXE*TestFunc*EXE*param1,param2", + "test/pkg*EXE*TestFunc*EXE*param2,param3", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := executeParameterChange(tt.msgs, registry) + if !tt.wantErr { + uassert.NoError(t, err) + } else { + uassert.Error(t, err) + } + }) + } +} + +func TestValidateProposalState(t *testing.T) { + tests := []struct { + name string + proposal *ProposalInfo + want ExecutionValidator + }{ + { + name: "Pass: Text proposal", + proposal: &ProposalInfo{ + ProposalType: Text, + State: ProposalState{ + Executed: false, + Canceled: false, + Rejected: false, + Passed: true, + }, + }, + want: ExecutionValidator{ + isTextProposal: true, + isAlreadyExecuted: false, + isAlreadyCanceled: false, + isAlreadyRejected: false, + hasPassed: true, + }, + }, + { + name: "Pass: Already executed", + proposal: &ProposalInfo{ + ProposalType: ParameterChange, + State: ProposalState{ + Executed: true, + Canceled: false, + Rejected: false, + Passed: true, + }, + }, + want: ExecutionValidator{ + isTextProposal: false, + isAlreadyExecuted: true, + isAlreadyCanceled: false, + isAlreadyRejected: false, + hasPassed: true, + }, + }, + { + name: "Pass: Canceled", + proposal: &ProposalInfo{ + ProposalType: ParameterChange, + State: ProposalState{ + Executed: false, + Canceled: true, + Rejected: false, + Passed: false, + }, + }, + want: ExecutionValidator{ + isTextProposal: false, + isAlreadyExecuted: false, + isAlreadyCanceled: true, + isAlreadyRejected: false, + hasPassed: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := validateProposalState(tt.proposal) + if got != tt.want { + t.Errorf("validateProposalState() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCheckProposalValidation(t *testing.T) { + // TODO: change error message after error code is defined + tests := []struct { + name string + validator ExecutionValidator + wantErr bool + errMsg string + }{ + { + name: "Pass: Valid proposal", + validator: ExecutionValidator{ + isTextProposal: false, + isAlreadyExecuted: false, + isAlreadyCanceled: false, + isAlreadyRejected: false, + hasPassed: true, + }, + wantErr: false, + }, + { + name: "Fail: Text proposal is not executable", + validator: ExecutionValidator{ + isTextProposal: true, + isAlreadyExecuted: false, + isAlreadyCanceled: false, + isAlreadyRejected: false, + hasPassed: true, + }, + wantErr: true, + errMsg: "[GNOSWAP-GOVERNANCE-011] can not execute text proposal", + }, + { + name: "Fail: Already executed, canceled, or rejected", + validator: ExecutionValidator{ + isTextProposal: false, + isAlreadyExecuted: true, + isAlreadyCanceled: false, + isAlreadyRejected: false, + hasPassed: true, + }, + wantErr: true, + errMsg: "proposal already executed, canceled, or rejected", + }, + { + name: "Fail: Proposal has not passed", + validator: ExecutionValidator{ + isTextProposal: false, + isAlreadyExecuted: false, + isAlreadyCanceled: false, + isAlreadyRejected: false, + hasPassed: false, + }, + wantErr: true, + errMsg: "[GNOSWAP-GOVERNANCE-016] proposal not passed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := checkProposalValidation(tt.validator) + if tt.wantErr { + uassert.Error(t, err, tt.errMsg) + uassert.Equal(t, err.Error(), tt.errMsg) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestParameterRegistry2(t *testing.T) { + tests := []struct { + name string + pkgPath string + function string + params []string + setupMock func(*ParameterRegistry) + expectError bool + }{ + { + name: "valid handler", + pkgPath: consts.POOL_PATH, + function: "SetFeeProtocol", + params: []string{"1", "2"}, + setupMock: func(r *ParameterRegistry) { + r.Register(consts.POOL_PATH, "SetFeeProtocol", func(p []string) error { + return nil + }) + }, + expectError: false, + }, + { + name: "invalid handler", + pkgPath: "invalid", + function: "invalid", + params: []string{}, + setupMock: func(r *ParameterRegistry) {}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry := NewParameterRegistry() + tt.setupMock(registry) + + handler, err := registry.Handler(tt.pkgPath, tt.function) + if tt.expectError { + uassert.Error(t, err) + return + } + + err = handler(tt.params) + if err != nil { + t.Errorf("handler returned error: %v", err) + } + }) + } +} diff --git a/contract/r/gnoswap/gov/governance/fn_registry.gno b/contract/r/gnoswap/gov/governance/fn_registry.gno new file mode 100644 index 000000000..1ec3a82ec --- /dev/null +++ b/contract/r/gnoswap/gov/governance/fn_registry.gno @@ -0,0 +1,193 @@ +package governance + +import ( + "std" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" + + cp "gno.land/r/gnoswap/v1/community_pool" + en "gno.land/r/gnoswap/v1/emission" + pl "gno.land/r/gnoswap/v1/pool" + pf "gno.land/r/gnoswap/v1/protocol_fee" + rr "gno.land/r/gnoswap/v1/router" + sr "gno.land/r/gnoswap/v1/staker" + + "gno.land/r/gnoswap/v1/common" +) + +func createParameterHandlers() *ParameterRegistry { + registry := NewParameterRegistry() + + // region: Common path + registry.Register(consts.COMMON_PATH, "SetHalt", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + common.SetHalt(parseBool(params[0])) // halt + return nil + }) + + // region: Community pool + registry.Register(consts.COMMUNITY_POOL_PATH, "TransferToken", func(params []string) error { + if err := hasDesiredParams(params, 3); err != nil { + return err + } + cp.TransferToken( + params[0], // pkgPath + std.Address(params[1]), // to + parseUint64(params[2]), // amount + ) + return nil + }) + + // region: Emission + registry.Register(consts.EMISSION_PATH, "ChangeDistributionPct", func(params []string) error { + if err := hasDesiredParams(params, 8); err != nil { + return err + } + en.ChangeDistributionPct( + parseInt(params[0]), // target01 + parseUint64(params[1]), // pct01 + parseInt(params[2]), // target02 + parseUint64(params[3]), // pct02 + parseInt(params[4]), // target03 + parseUint64(params[5]), // pct03 + parseInt(params[6]), // target04 + parseUint64(params[7]), // pct04 + ) + return nil + }) + + // region: GNS Path + registry.Register(consts.GNS_PATH, "SetAvgBlockTimeInMs", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + gns.SetAvgBlockTimeInMs(int64(parseInt(params[0]))) // ms + return nil + }) + + // region: Governance Path + registry.Register(consts.GOV_GOVERNANCE_PATH, "Reconfigure", func(params []string) error { + if err := hasDesiredParams(params, 7); err != nil { + return err + } + reconfigure( + parseUint64(params[0]), // votingStartDelay + parseUint64(params[1]), // votingPeriod + parseUint64(params[2]), // votingWeightSmoothingDuration + parseUint64(params[3]), // quorum + parseUint64(params[4]), // proposalCreationhold + parseUint64(params[5]), // executionDelay + parseUint64(params[6]), // executionWindow + ) + return nil + }) + + // region: Pool Path + registry.Register(consts.POOL_PATH, "SetFeeProtocol", func(params []string) error { + if err := hasDesiredParams(params, 2); err != nil { + return err + } + pl.SetFeeProtocol( + uint8(parseUint64(params[0])), // feeProtocol0 + uint8(parseUint64(params[1])), // feeProtocol1 + ) + return nil + }) + + registry.Register(consts.POOL_PATH, "SetPoolCreationFee", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + pl.SetPoolCreationFee(parseUint64(params[0])) // fee + return nil + }) + + registry.Register(consts.POOL_PATH, "SetWithdrawalFee", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + pl.SetWithdrawalFee(parseUint64(params[0])) // fee + return nil + }) + + // region: Protocol fee + registry.Register(consts.PROTOCOL_FEE_PATH, "SetDevOpsPct", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + pf.SetDevOpsPct(parseUint64(params[0])) // pct + return nil + }) + + // region: Router + registry.Register(consts.ROUTER_PATH, "SetSwapFee", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + rr.SetSwapFee(parseUint64(params[0])) // fee + return nil + }) + + // region: Staker + registry.Register(consts.STAKER_PATH, "SetDepositGnsAmount", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + sr.SetDepositGnsAmount(parseUint64(params[0])) // amount + return nil + }) + + registry.Register(consts.STAKER_PATH, "SetPoolTier", func(params []string) error { + if err := hasDesiredParams(params, 2); err != nil { + return err + } + sr.SetPoolTier( + params[0], // pool + parseUint64(params[1]), // tier + ) + return nil + }) + + registry.Register(consts.STAKER_PATH, "ChangePoolTier", func(params []string) error { + if err := hasDesiredParams(params, 2); err != nil { + return err + } + sr.ChangePoolTier( + params[0], // pool + parseUint64(params[1]), // tier + ) + return nil + }) + + registry.Register(consts.STAKER_PATH, "RemovePoolTier", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + sr.RemovePoolTier(params[0]) // pool + return nil + }) + + registry.Register(consts.STAKER_PATH, "SetUnstakingFee", func(params []string) error { + if err := hasDesiredParams(params, 1); err != nil { + return err + } + sr.SetUnstakingFee(parseUint64(params[0])) + return nil + }) + + registry.Register(consts.STAKER_PATH, "SetWarmUp", func(params []string) error { + if err := hasDesiredParams(params, 2); err != nil { + return err + } + sr.SetWarmUp( + int64(parseInt(params[0])), // percent + int64(parseInt(params[1])), // block + ) + return nil + }) + + return registry +} diff --git a/contract/r/gnoswap/gov/governance/fn_registry_test.gno b/contract/r/gnoswap/gov/governance/fn_registry_test.gno new file mode 100644 index 000000000..9274b3016 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/fn_registry_test.gno @@ -0,0 +1,224 @@ +package governance + +import ( + "testing" + + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + + en "gno.land/r/gnoswap/v1/emission" + pl "gno.land/r/gnoswap/v1/pool" + pf "gno.land/r/gnoswap/v1/protocol_fee" + rr "gno.land/r/gnoswap/v1/router" + sr "gno.land/r/gnoswap/v1/staker" +) + +func TestCreateParameterHandlers(t *testing.T) { + registry := createParameterHandlers() + + if registry == nil { + t.Fatal("registry is nil") + } + + tests := []struct { + name string + path string + function string + params []string + wantErr bool + validate func(t *testing.T) + }{ + { + name: "Emission_ChangeDistributionPct", + path: consts.EMISSION_PATH, + function: "ChangeDistributionPct", + params: []string{"1", "7000", "2", "1500", "3", "1000", "4", "500"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, en.GetDistributionBpsPct(1), uint64(7000)) + uassert.Equal(t, en.GetDistributionBpsPct(2), uint64(1500)) + uassert.Equal(t, en.GetDistributionBpsPct(3), uint64(1000)) + uassert.Equal(t, en.GetDistributionBpsPct(4), uint64(500)) + }, + }, + { + name: "Router SetSwapFee: Pass", + path: consts.ROUTER_PATH, + function: "SetSwapFee", + params: []string{"1000"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, rr.GetSwapFee(), uint64(1000)) + }, + }, + { + name: "set deposit gns amount", + path: consts.STAKER_PATH, + function: "SetDepositGnsAmount", + params: []string{"1000"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, sr.GetDepositGnsAmount(), uint64(1000)) + }, + }, + { + name: "set pool creation fee", + path: consts.POOL_PATH, + function: "SetPoolCreationFee", + params: []string{"500"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, pl.GetPoolCreationFee(), uint64(500)) + }, + }, + { + name: "SetWithdrawalFee", + path: consts.POOL_PATH, + function: "SetWithdrawalFee", + params: []string{"600"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, pl.GetWithdrawalFee(), uint64(600)) + }, + }, + { + name: "ProtocolFee_SetDevOpsPct", + path: consts.PROTOCOL_FEE_PATH, + function: "SetDevOpsPct", + params: []string{"900"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, pf.GetDevOpsPct(), uint64(900)) + }, + }, + { + name: "Router_SetSwapFee", + path: consts.ROUTER_PATH, + function: "SetSwapFee", + params: []string{"400"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, rr.GetSwapFee(), uint64(400)) + }, + }, + { + name: "Staker_SetDepositGnsAmount", + path: consts.STAKER_PATH, + function: "SetDepositGnsAmount", + params: []string{"400"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, sr.GetDepositGnsAmount(), uint64(400)) + }, + }, + { + name: "Staker_SetUnstakingFee", + path: consts.STAKER_PATH, + function: "SetUnstakingFee", + params: []string{"100"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, sr.GetUnstakingFee(), uint64(100)) + }, + }, + { + name: "ProtocolFee_SetDevOpsPct", + path: consts.PROTOCOL_FEE_PATH, + function: "SetDevOpsPct", + params: []string{"900"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, pf.GetDevOpsPct(), uint64(900)) + }, + }, + { + name: "Router_SetSwapFee", + path: consts.ROUTER_PATH, + function: "SetSwapFee", + params: []string{"400"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, rr.GetSwapFee(), uint64(400)) + }, + }, + { + name: "Staker_SetDepositGnsAmount", + path: consts.STAKER_PATH, + function: "SetDepositGnsAmount", + params: []string{"400"}, + wantErr: false, + validate: func(t *testing.T) { + uassert.Equal(t, sr.GetDepositGnsAmount(), uint64(400)) + }, + }, + { + name: "Emission_ChangeDistributionPct_InvalidTotal_Exceed100", + path: consts.EMISSION_PATH, + function: "ChangeDistributionPct", + params: []string{"1", "8000", "2", "3000"}, + wantErr: true, + validate: func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + }() + // existing state must be preserved in the event of a panic + uassert.Equal(t, en.GetDistributionBpsPct(1), uint64(7000)) + }, + }, + { + name: "Router_SetSwapFee_InvalidValue", + path: consts.ROUTER_PATH, + function: "SetSwapFee", + params: []string{"-100"}, + wantErr: true, + validate: func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + }() + uassert.Equal(t, rr.GetSwapFee(), uint64(400)) + }, + }, + { + name: "Staker_SetUnstakingFee_TooHigh", + path: consts.STAKER_PATH, + function: "SetUnstakingFee", + params: []string{"10001"}, + wantErr: true, + validate: func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + }() + uassert.Equal(t, sr.GetUnstakingFee(), uint64(100)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } else if r != nil { + t.Fatal("unexpected panic") + } + }() + handler, _ := registry.Get(tt.path, tt.function) + err := handler(tt.params) + + if (err != nil) != tt.wantErr { + t.Errorf("unexpected error: %v", err) + } + + if !tt.wantErr { + tt.validate(t) + } + }) + } +} diff --git a/contract/r/gnoswap/gov/governance/getter_proposal.gno b/contract/r/gnoswap/gov/governance/getter_proposal.gno new file mode 100644 index 000000000..064c3cb46 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/getter_proposal.gno @@ -0,0 +1,37 @@ +package governance + +func GetProposerByProposalId(proposalId uint64) string { + return mustGetProposal(proposalId).Proposer.String() +} + +func GetProposalTypeByProposalId(proposalId uint64) string { + return mustGetProposal(proposalId).ProposalType.String() +} + +func GetYeaByProposalId(proposalId uint64) string { + return mustGetProposal(proposalId).Yea.ToString() +} + +func GetNayByProposalId(proposalId uint64) string { + return mustGetProposal(proposalId).Nay.ToString() +} + +func GetConfigVersionByProposalId(proposalId uint64) uint64 { + return mustGetProposal(proposalId).ConfigVersion +} + +func GetQuorumAmountByProposalId(proposalId uint64) uint64 { + return mustGetProposal(proposalId).QuorumAmount +} + +func GetTitleByProposalId(proposalId uint64) string { + return mustGetProposal(proposalId).Title +} + +func GetDescriptionByProposalId(proposalId uint64) string { + return mustGetProposal(proposalId).Description +} + +func GetExecutionStateByProposalId(proposalId uint64) ProposalState { + return mustGetProposal(proposalId).State +} diff --git a/contract/r/gnoswap/gov/governance/getter_vote.gno b/contract/r/gnoswap/gov/governance/getter_vote.gno new file mode 100644 index 000000000..87735cf5f --- /dev/null +++ b/contract/r/gnoswap/gov/governance/getter_vote.gno @@ -0,0 +1,40 @@ +package governance + +import ( + "std" + + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/common" +) + +func GetVoteByVoteKey(voteKey string) bool { + return mustGetVote(voteKey) +} + +func GetVoteYesByVoteKey(voteKey string) bool { + return mustGetVoteInfo(voteKey).Yes +} + +func GetVoteWeightByVoteKey(voteKey string) uint64 { + return mustGetVoteInfo(voteKey).Weight +} + +func GetVotedHeightByVoteKey(voteKey string) uint64 { + return mustGetVoteInfo(voteKey).VotedHeight +} + +func GetVotedAtByVoteKey(voteKey string) uint64 { + return mustGetVoteInfo(voteKey).VotedAt +} + +func divideVoteKeyToProposalIdAndUser(voteKey string) (proposalId uint64, user std.Address) { + parts, err := common.Split(voteKey, ":", 2) + if err != nil { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("voteKey(%s) is invalid", voteKey), + )) + } + + return parseUint64(parts[0]), std.Address(parts[1]) +} diff --git a/contract/r/gnoswap/gov/governance/gno.mod b/contract/r/gnoswap/gov/governance/gno.mod new file mode 100644 index 000000000..2501e93b5 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/gov/governance diff --git a/contract/r/gnoswap/gov/governance/proposal.gno b/contract/r/gnoswap/gov/governance/proposal.gno new file mode 100644 index 000000000..cd2346abd --- /dev/null +++ b/contract/r/gnoswap/gov/governance/proposal.gno @@ -0,0 +1,403 @@ +package governance + +import ( + "std" + "strings" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/gov/xgns" + + en "gno.land/r/gnoswap/v1/emission" + gs "gno.land/r/gnoswap/v1/gov/staker" +) + +var ( + proposalId uint64 + proposals = avl.NewTree() // proposalId -> ProposalInfo + latestProposalByProposer = avl.NewTree() // proposer -> proposalId +) + +const ( + GOV_SPLIT = "*GOV*" +) + +// ProposeText creates a new text proposal with the given data +// It checks if the proposer is eligible to create a proposal and if they don't have an active proposal. +// Returns the proposal ID +// ref: https://docs.gnoswap.io/contracts/governance/proposal.gno#proposetext +func ProposeText( + title string, + description string, +) uint64 { // proposalId + common.IsHalted() + + en.MintAndDistributeGns() + updateProposalsState() + + proposer := std.PrevRealm().Addr() + + enough, balance, wanted := checkEnoughXGnsToPropose(proposer) + if !enough { + panic(addDetailToError( + errNotEnoughBalance, + ufmt.Sprintf("proposer(%s) has not enough xGNS, balance(%d), wanted(%d)", proposer.String(), balance, wanted), + )) + } + + now := uint64(time.Now().Unix()) + // votingMax does not include quantities delegated through Launchpad. + votingMax, possibleAddressWithWeight := gs.GetPossibleVotingAddressWithWeight(now - config.VotingWeightSmoothingDuration) + + maxVotingWeight := u256.NewUint(votingMax) + quorumAmount := maxVotingWeight.Uint64() * config.Quorum / 100 + + proposal := ProposalInfo{ + Proposer: proposer, + ProposalType: Text, + State: ProposalState{ + Created: true, + CreatedAt: now, + Upcoming: true, + }, + Yea: u256.Zero(), + Nay: u256.Zero(), + MaxVotingWeight: maxVotingWeight, + PossibleAddressWithWeight: possibleAddressWithWeight, + ConfigVersion: uint64(configVersions.Size()), // use latest config version + QuorumAmount: quorumAmount, + Title: title, + Description: description, + } + + proposalId++ + proposals.Set(formatUint(proposalId), proposal) + latestProposalByProposer.Set(proposer.String(), proposalId) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "ProposeText", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "title", title, + "description", description, + "proposalId", formatUint(proposalId), + "quorumAmount", formatUint(proposal.QuorumAmount), + "maxVotingWeight", proposal.MaxVotingWeight.ToString(), + "configVersion", formatUint(proposal.ConfigVersion), + "createdAt", formatUint(proposal.State.CreatedAt), + ) + return proposalId +} + +// ProposeCommunityPoolSpend creates a new community pool spend proposal with the given data +// It checks if the proposer is eligible to create a proposal and if they don't have an active proposal. +// Returns the proposal ID +// ref: https://docs.gnoswap.io/contracts/governance/proposal.gno#proposecommunitypoolspend +func ProposeCommunityPoolSpend( + title string, + description string, + to std.Address, + tokenPath string, + amount uint64, +) uint64 { // proposalId + common.IsHalted() + + en.MintAndDistributeGns() + updateProposalsState() + + proposer := std.PrevRealm().Addr() + + enough, balance, wanted := checkEnoughXGnsToPropose(proposer) + if !enough { + panic(addDetailToError( + errNotEnoughBalance, + ufmt.Sprintf("proposer(%s) has not enough xGNS, balance(%d), wanted(%d)", proposer.String(), balance, wanted), + )) + } + + now := uint64(time.Now().Unix()) + votingMax, possibleAddressWithWeight := gs.GetPossibleVotingAddressWithWeight(now - config.VotingWeightSmoothingDuration) + + maxVotingWeight := u256.NewUint(votingMax) + quorumAmount := maxVotingWeight.Uint64() * config.Quorum / 100 + + proposal := ProposalInfo{ + Proposer: proposer, + ProposalType: CommunityPoolSpend, + State: ProposalState{ + Created: true, + CreatedAt: now, + Upcoming: true, + }, + Yea: u256.Zero(), + Nay: u256.Zero(), + MaxVotingWeight: maxVotingWeight, + PossibleAddressWithWeight: possibleAddressWithWeight, + ConfigVersion: uint64(configVersions.Size()), + QuorumAmount: quorumAmount, + Title: title, + Description: description, + CommunityPoolSpend: CommunityPoolSpendInfo{ + To: to, + TokenPath: tokenPath, + Amount: amount, + }, + } + + proposalId++ + proposals.Set(formatUint(proposalId), proposal) + latestProposalByProposer.Set(proposer.String(), proposalId) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "ProposeCommunityPoolSpend", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "title", title, + "description", description, + "to", to.String(), + "tokenPath", tokenPath, + "amount", formatUint(amount), + "proposalId", formatUint(proposalId), + "quorumAmount", formatUint(proposal.QuorumAmount), + "maxVotingWeight", proposal.MaxVotingWeight.ToString(), + "configVersion", formatUint(proposal.ConfigVersion), + "createdAt", formatUint(proposal.State.CreatedAt), + ) + + return proposalId +} + +// ProposeParameterChange creates a new parameter change with the given data +// It checks if the proposer is eligible to create a proposal and if they don't have an active proposal. +// Returns the proposal ID +// ref: https://docs.gnoswap.io/contracts/governance/proposal.gno#proposeparameterchange +func ProposeParameterChange( + title string, + description string, + numToExecute uint64, + executions string, +) uint64 { // proposalId + common.IsHalted() + + en.MintAndDistributeGns() + updateProposalsState() + + proposer := std.PrevRealm().Addr() + + enough, balance, wanted := checkEnoughXGnsToPropose(proposer) + if !enough { + panic(addDetailToError( + errNotEnoughBalance, + ufmt.Sprintf("proposer(%s) has not enough xGNS, balance(%d), wanted(%d)", proposer.String(), balance, wanted), + )) + } + + if numToExecute == 0 { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("numToExecute is 0"), + )) + } + + // check if numToExecute is a valid number + splitGov := strings.Split(executions, GOV_SPLIT) + if uint64(len(splitGov)) != numToExecute { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("numToExecute(%d) does not match the number of executions(%d)", numToExecute, len(splitGov)), + )) + } + + now := uint64(time.Now().Unix()) + votingMax, possibleAddressWithWeight := gs.GetPossibleVotingAddressWithWeight(now - config.VotingWeightSmoothingDuration) + + maxVotingWeight := u256.NewUint(votingMax) + quorumAmount := maxVotingWeight.Uint64() * config.Quorum / 100 + + proposal := ProposalInfo{ + Proposer: proposer, + ProposalType: ParameterChange, + State: ProposalState{ + Created: true, + CreatedAt: now, + Upcoming: true, + }, + Yea: u256.Zero(), + Nay: u256.Zero(), + MaxVotingWeight: maxVotingWeight, + PossibleAddressWithWeight: possibleAddressWithWeight, + ConfigVersion: uint64(configVersions.Size()), + QuorumAmount: quorumAmount, + Title: title, + Description: description, + Execution: ExecutionInfo{ + Num: numToExecute, + Msgs: splitGov, + }, + } + + proposalId++ + proposals.Set(formatUint(proposalId), proposal) + latestProposalByProposer.Set(proposer.String(), proposalId) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "ProposeParameterChange", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "title", title, + "description", description, + "numToExecute", formatUint(numToExecute), + "executions", executions, + "proposalId", formatUint(proposalId), + "quorumAmount", formatUint(proposal.QuorumAmount), + "maxVotingWeight", proposal.MaxVotingWeight.ToString(), + "configVersion", formatUint(proposal.ConfigVersion), + "createdAt", formatUint(proposal.State.CreatedAt), + ) + + return proposalId +} + +// proposalStateUpdater handles the state transitions of a proposal. +type proposalStateUpdater struct { + proposal *ProposalInfo + config Config + now uint64 +} + +// updateProposalsState updates the state of all proposals based on current time. +// It processes voting periods, results, and execution windows for each proposal. +func updateProposalsState() { + now := uint64(time.Now().Unix()) + + proposals.Iterate("", "", func(key string, value interface{}) bool { + proposal := value.(ProposalInfo) + if proposal.State.Canceled || proposal.State.Expired || proposal.State.Executed { + return false + } + + updater := newProposalStateUpdater(&proposal, now) + updater.updateVotingState() + updater.updateVotingResult() + updater.updateExecutionState() + + proposals.Set(key, *updater.proposal) + return false + }) +} + +// newProposalStateUpdater creates a new proposalStateUpdater. +func newProposalStateUpdater(proposal *ProposalInfo, now uint64) *proposalStateUpdater { + return &proposalStateUpdater{ + proposal: proposal, + config: GetConfigVersion(proposal.ConfigVersion), + now: now, + } +} + +// shouldUpdate determines if the proposal state should be updated. +// Returns true if the proposal is created, not canceled, and not executed. +func (u *proposalStateUpdater) shouldUpdate() bool { + return u.proposal.State.Created && + !u.proposal.State.Canceled && + !u.proposal.State.Executed +} + +// getVotingTimes returns the start and end timestamps of the voting period. +// The start time is CreatedAt + VotingStartDelay. +// The end time is start time + VotingPeriod. +func (u *proposalStateUpdater) getVotingTimes() (start, end uint64) { + start = u.proposal.State.CreatedAt + u.config.VotingStartDelay + end = start + u.config.VotingPeriod + return +} + +// getExecutionTimes returns the start and end timestamps of the execution window. +// The start time is after voting end + ExecutionDelay. +// The end time is start time + ExecutionWindow. +func (u *proposalStateUpdater) getExecutionTimes() (start, end uint64) { + _, votingEnd := u.getVotingTimes() + start = votingEnd + u.config.ExecutionDelay + end = start + u.config.ExecutionWindow + return +} + +// updateVotingState updates the voting state of the proposal. +// It transitions from upcoming to active state when voting period starts. +func (u *proposalStateUpdater) updateVotingState() { + if !u.shouldUpdate() { + return + } + + votingStart, votingEnd := u.getVotingTimes() + isVotingPeriod := u.now >= votingStart && u.now <= votingEnd + + if u.proposal.State.Upcoming && isVotingPeriod { + u.proposal.State.Upcoming = false + u.proposal.State.Active = true + } +} + +// updateVotingResult determines the outcome of voting when voting period ends. +// It sets the proposal as passed if it meets quorum and has more yes votes than no votes. +// Otherwise, it marks the proposal as rejected. +func (u *proposalStateUpdater) updateVotingResult() { + _, votingEnd := u.getVotingTimes() + + hasNoResult := !u.proposal.State.Passed && + !u.proposal.State.Rejected && + !u.proposal.State.Canceled + + if u.now <= votingEnd || !hasNoResult { + return + } + + yeaUint := u.proposal.Yea.Uint64() + nayUint := u.proposal.Nay.Uint64() + + if yeaUint >= u.proposal.QuorumAmount && yeaUint > nayUint { + u.proposal.State.Passed = true + u.proposal.State.PassedAt = u.now + } else { + u.proposal.State.Rejected = true + u.proposal.State.RejectedAt = u.now + } + + u.proposal.State.Upcoming = false + u.proposal.State.Active = false +} + +// updateExecutionState checks if a non-text proposal should expire. +// It marks the proposal as expired if execution window has ended. +func (u *proposalStateUpdater) updateExecutionState() { + if u.proposal.ProposalType == Text || + !u.proposal.State.Passed || + u.proposal.State.Executed || + u.proposal.State.Expired { + return + } + + _, executionEnd := u.getExecutionTimes() + + if u.now >= executionEnd { + u.proposal.State.Expired = true + u.proposal.State.ExpiredAt = u.now + } +} + +func checkEnoughXGnsToPropose(proposer std.Address) (bool, uint64, uint64) { + xGNSBalance := xgns.BalanceOf(proposer) + + if xGNSBalance < config.ProposalCreationThreshold { + return false, xGNSBalance, config.ProposalCreationThreshold + } + + return true, xGNSBalance, config.ProposalCreationThreshold +} diff --git a/contract/r/gnoswap/gov/governance/proposal_test.gno b/contract/r/gnoswap/gov/governance/proposal_test.gno new file mode 100644 index 000000000..c7f0cb884 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/proposal_test.gno @@ -0,0 +1,360 @@ +package governance + +import ( + "std" + "testing" + + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + u256 "gno.land/p/gnoswap/uint256" +) + +var ( + proposalAddr = testutils.TestAddress("proposal") + toAddr = testutils.TestAddress("to") + + voter1 = testutils.TestAddress("voter1") + voter2 = testutils.TestAddress("voter2") + + insufficientProposer = testutils.TestAddress("insufficient") +) + +func resetGlobalStateProposal(t *testing.T) { + t.Helper() + proposalId = 0 + proposals = avl.NewTree() + latestProposalByProposer = avl.NewTree() +} + +func mockCheckEnoughXGnsToPropose(proposer std.Address) (bool, uint64, uint64) { + if proposer == insufficientProposer { + return false, 500, 1000 + } + return true, 1000, 1000 +} + +func mockGetPossibleVotingAddressWithWeight(t *testing.T, timestamp uint64) (uint64, map[std.Address]uint64) { + t.Helper() + weights := make(map[std.Address]uint64) + weights[voter1] = 100 + weights[voter2] = 200 + return 300, weights +} + +func TestProposeText(t *testing.T) { + checkEnoughXGnsToPropose = mockCheckEnoughXGnsToPropose + + resetGlobalStateProposal(t) + + tests := []struct { + name string + proposer std.Address + title string + description string + expectError bool + }{ + { + name: "Valid text proposal", + proposer: proposalAddr, + title: "Test Proposal", + description: "This is a test proposal", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var pid uint64 + var err error + func() { + defer func() { + if r := recover(); r != nil { + err = r.(error) + } + }() + pid = ProposeText(tt.title, tt.description) + }() + + uassert.NoError(t, err) + + prop, exists := proposals.Get(formatUint(pid)) + uassert.True(t, exists) + + proposal := prop.(ProposalInfo) + + uassert.Equal(t, proposal.Title, tt.title) + uassert.Equal(t, proposal.Description, tt.description) + uassert.Equal(t, proposal.ProposalType.String(), Text.String()) + uassert.True(t, proposal.State.Created) + uassert.True(t, proposal.State.Upcoming) + }) + } +} + +func TestProposeCommunityPoolSpend(t *testing.T) { + checkEnoughXGnsToPropose = mockCheckEnoughXGnsToPropose + + resetGlobalStateProposal(t) + + tests := []struct { + name string + proposer std.Address + title string + description string + to std.Address + tokenPath string + amount uint64 + expectError bool + }{ + { + name: "Valid community pool spend proposal", + proposer: proposalAddr, + title: "Community Spend", + description: "Fund community initiative", + to: toAddr, + tokenPath: "token/path", + amount: 1000, + expectError: false, + }, + { + name: "Insufficient balance for proposal", + proposer: insufficientProposer, + title: "Invalid Spend", + description: "Should fail", + to: toAddr, + tokenPath: "token/path", + amount: 1000, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var pid uint64 + var err error + func() { + defer func() { + if r := recover(); r != nil { + err = r.(error) + } + }() + pid = ProposeCommunityPoolSpend( + tt.title, tt.description, tt.to, tt.tokenPath, tt.amount) + }() + + uassert.NoError(t, err) + + prop, exists := proposals.Get(formatUint(pid)) + uassert.True(t, exists) + + proposal := prop.(ProposalInfo) + + uassert.Equal(t, proposal.ProposalType.String(), CommunityPoolSpend.String()) + uassert.Equal(t, proposal.CommunityPoolSpend.To, tt.to) + uassert.Equal(t, proposal.CommunityPoolSpend.Amount, tt.amount) + }) + } +} + +func TestUpdateProposalsState(t *testing.T) { + baseTime := uint64(1000) + newConfig := Config{ + VotingStartDelay: 50, + VotingPeriod: 100, + ExecutionDelay: 50, + ExecutionWindow: 100, + } + setConfigVersion(1, newConfig) + + tests := []struct { + name string + currentTime uint64 + setupProposal func(uint64) ProposalInfo + validate func(*testing.T, ProposalInfo) + }{ + { + name: "Should reject proposal when voting ends with insufficient votes", + currentTime: baseTime + 200, + setupProposal: func(now uint64) ProposalInfo { + return ProposalInfo{ + ConfigVersion: 1, + Yea: u256.NewUint(100), + Nay: u256.NewUint(200), + QuorumAmount: 300, + State: ProposalState{ + Created: true, + CreatedAt: baseTime, + Active: true, + }, + } + }, + validate: func(t *testing.T, proposal ProposalInfo) { + uassert.True(t, proposal.State.Rejected) + uassert.False(t, proposal.State.Active) + uassert.False(t, proposal.State.Upcoming) + uassert.NotEqual(t, proposal.State.RejectedAt, uint64(0)) + }, + }, + { + name: "Should pass proposal when voting ends with sufficient votes", + currentTime: baseTime + 200, + setupProposal: func(now uint64) ProposalInfo { + return ProposalInfo{ + ConfigVersion: 1, + Yea: u256.NewUint(400), + Nay: u256.NewUint(200), + QuorumAmount: 300, + State: ProposalState{ + Created: true, + CreatedAt: baseTime, + Active: true, + }, + } + }, + validate: func(t *testing.T, proposal ProposalInfo) { + uassert.True(t, proposal.State.Passed) + uassert.False(t, proposal.State.Active) + uassert.NotEqual(t, proposal.State.PassedAt, uint64(0)) + }, + }, + { + name: "Should expire non-text proposal when execution window ends", + currentTime: baseTime + 400, + setupProposal: func(now uint64) ProposalInfo { + return ProposalInfo{ + ConfigVersion: 1, + ProposalType: ParameterChange, + State: ProposalState{ + Created: true, + CreatedAt: baseTime, + Passed: true, + PassedAt: baseTime + 200, + }, + } + }, + validate: func(t *testing.T, proposal ProposalInfo) { + uassert.True(t, proposal.State.Expired) + uassert.NotEqual(t, proposal.State.ExpiredAt, uint64(0)) + }, + }, + { + name: "Should not update canceled proposal", + currentTime: baseTime + 60, + setupProposal: func(now uint64) ProposalInfo { + return ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + Created: true, + CreatedAt: baseTime, + Canceled: true, + Upcoming: true, + }, + } + }, + validate: func(t *testing.T, proposal ProposalInfo) { + uassert.True(t, proposal.State.Canceled) + uassert.True(t, proposal.State.Upcoming) + }, + }, + { + name: "Should not expire text proposal", + currentTime: baseTime + 400, + setupProposal: func(now uint64) ProposalInfo { + return ProposalInfo{ + ConfigVersion: 1, + ProposalType: Text, + State: ProposalState{ + Created: true, + CreatedAt: baseTime, + Passed: true, + PassedAt: baseTime + 200, + }, + } + }, + validate: func(t *testing.T, proposal ProposalInfo) { + uassert.False(t, proposal.State.Expired) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + proposals = avl.NewTree() + + proposal := tt.setupProposal(tt.currentTime) + proposals.Set(formatUint(uint32(1)), proposal) + + updateProposalsState() + + updatedProposal, exists := proposals.Get(formatUint(uint32(1))) + uassert.True(t, exists) + tt.validate(t, updatedProposal.(ProposalInfo)) + }) + } +} + +func TestProposeParameterChange(t *testing.T) { + resetGlobalStateProposal(t) + + tests := []struct { + name string + proposer std.Address + title string + description string + numToExecute uint64 + executions string + expectError bool + errorContains string + }{ + { + name: "Valid parameter change proposal", + proposer: proposalAddr, + title: "Change Voting Period", + description: "Update voting period to 14 days", + numToExecute: 2, + executions: "gno.land/r/gnoswap/v1/gns*EXE*SetAvgBlockTimeInMs*EXE*123*GOV*gno.land/r/gnoswap/v1/community_pool*EXE*TransferToken*EXE*gno.land/r/gnoswap/v1/gns,g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d,905", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetGlobalStateProposal(t) + + var pid uint64 + var err error + defer func() { + if r := recover(); r != nil { + t.Errorf("Unexpected error: %v", r) + } + }() + pid = ProposeParameterChange( + tt.title, + tt.description, + tt.numToExecute, + tt.executions, + ) + + uassert.NoError(t, err) + + prop, exists := proposals.Get(formatUint(pid)) + uassert.True(t, exists) + + proposal := prop.(ProposalInfo) + + uassert.Equal(t, proposal.Title, tt.title) + uassert.Equal(t, proposal.Description, tt.description) + uassert.Equal(t, proposal.ProposalType.String(), ParameterChange.String()) + + uassert.Equal(t, proposal.Execution.Num, tt.numToExecute) + uassert.Equal(t, len(proposal.Execution.Msgs), int(tt.numToExecute)) + + uassert.True(t, proposal.State.Created) + uassert.True(t, proposal.State.Upcoming) + uassert.False(t, proposal.State.Active) + uassert.True(t, proposal.Yea.IsZero()) + uassert.True(t, proposal.Nay.IsZero()) + }) + } +} diff --git a/contract/r/gnoswap/gov/governance/tests/_0_init_token_register_test.gno b/contract/r/gnoswap/gov/governance/tests/_0_init_token_register_test.gno new file mode 100644 index 000000000..c360ac658 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/tests/_0_init_token_register_test.gno @@ -0,0 +1,151 @@ +package governance + +import ( + "std" + + "gno.land/r/onbloc/foo" + + "gno.land/r/onbloc/bar" + + "gno.land/r/onbloc/baz" + + "gno.land/r/onbloc/qux" + + "gno.land/r/demo/wugnot" + + "gno.land/r/onbloc/obl" + + "gno.land/r/gnoswap/v1/gns" + + "gno.land/p/gnoswap/consts" +) + +type FooToken struct{} + +func (FooToken) Transfer() func(to std.Address, amount uint64) { + return foo.Transfer +} + +func (FooToken) TransferFrom() func(from, to std.Address, amount uint64) { + return foo.TransferFrom +} + +func (FooToken) BalanceOf() func(owner std.Address) uint64 { + return foo.BalanceOf +} + +func (FooToken) Approve() func(spender std.Address, amount uint64) { + return foo.Approve +} + +type BarToken struct{} + +func (BarToken) Transfer() func(to std.Address, amount uint64) { + return bar.Transfer +} + +func (BarToken) TransferFrom() func(from, to std.Address, amount uint64) { + return bar.TransferFrom +} + +func (BarToken) BalanceOf() func(owner std.Address) uint64 { + return bar.BalanceOf +} + +func (BarToken) Approve() func(spender std.Address, amount uint64) { + return bar.Approve +} + +type BazToken struct{} + +func (BazToken) Transfer() func(to std.Address, amount uint64) { + return baz.Transfer +} + +func (BazToken) TransferFrom() func(from, to std.Address, amount uint64) { + return baz.TransferFrom +} + +func (BazToken) BalanceOf() func(owner std.Address) uint64 { + return baz.BalanceOf +} + +func (BazToken) Approve() func(spender std.Address, amount uint64) { + return baz.Approve +} + +type QuxToken struct{} + +func (QuxToken) Transfer() func(to std.Address, amount uint64) { + return qux.Transfer +} + +func (QuxToken) TransferFrom() func(from, to std.Address, amount uint64) { + return qux.TransferFrom +} + +func (QuxToken) BalanceOf() func(owner std.Address) uint64 { + return qux.BalanceOf +} + +func (QuxToken) Approve() func(spender std.Address, amount uint64) { + return qux.Approve +} + +type WugnotToken struct{} + +func (WugnotToken) Transfer() func(to std.Address, amount uint64) { + return wugnot.Transfer +} + +func (WugnotToken) TransferFrom() func(from, to std.Address, amount uint64) { + return wugnot.TransferFrom +} + +func (WugnotToken) BalanceOf() func(owner std.Address) uint64 { + return wugnot.BalanceOf +} + +func (WugnotToken) Approve() func(spender std.Address, amount uint64) { + return wugnot.Approve +} + +type OBLToken struct{} + +func (OBLToken) Transfer() func(to std.Address, amount uint64) { + return obl.Transfer +} + +func (OBLToken) TransferFrom() func(from, to std.Address, amount uint64) { + return obl.TransferFrom +} + +func (OBLToken) BalanceOf() func(owner std.Address) uint64 { + return obl.BalanceOf +} + +func (OBLToken) Approve() func(spender std.Address, amount uint64) { + return obl.Approve +} + +type GNSToken struct{} + +func (GNSToken) Transfer() func(to std.Address, amount uint64) { + return gns.Transfer +} + +func (GNSToken) TransferFrom() func(from, to std.Address, amount uint64) { + return gns.TransferFrom +} + +func (GNSToken) BalanceOf() func(owner std.Address) uint64 { + return gns.BalanceOf +} + +func (GNSToken) Approve() func(spender std.Address, amount uint64) { + return gns.Approve +} + +func init() { + std.TestSetRealm(std.NewUserRealm(consts.TOKEN_REGISTER)) +} diff --git a/contract/r/gnoswap/gov/governance/tests/_helper_test.gno b/contract/r/gnoswap/gov/governance/tests/_helper_test.gno new file mode 100644 index 000000000..7fe3ed975 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/tests/_helper_test.gno @@ -0,0 +1,46 @@ +package governance + +import ( + "std" + + "gno.land/p/gnoswap/consts" +) + +var ( + admin std.Address = consts.ADMIN + + fooPath string = "gno.land/r/onbloc/foo" + barPath string = "gno.land/r/onbloc/bar" + bazPath string = "gno.land/r/onbloc/baz" + quxPath string = "gno.land/r/onbloc/qux" + + oblPath string = "gno.land/r/onbloc/obl" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 +) + +// Realms to mock frames +var ( + adminRealm = std.NewUserRealm(admin) + + posRealm = std.NewCodeRealm(consts.POSITION_PATH) + rouRealm = std.NewCodeRealm(consts.ROUTER_PATH) + stkRealm = std.NewCodeRealm(consts.STAKER_PATH) + govRealm = std.NewCodeRealm(consts.GOV_GOVERNANCE_PATH) +) + +/* HELPER */ +func ugnotBalanceOf(addr std.Address) uint64 { + testBanker := std.GetBanker(std.BankerTypeRealmIssue) + + coins := testBanker.GetCoins(addr) + if len(coins) == 0 { + return 0 + } + + return uint64(coins.AmountOf("ugnot")) +} diff --git a/contract/r/gnoswap/gov/governance/type.gno b/contract/r/gnoswap/gov/governance/type.gno new file mode 100644 index 000000000..3cfcb58d1 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/type.gno @@ -0,0 +1,139 @@ +package governance + +import ( + "std" + + u256 "gno.land/p/gnoswap/uint256" +) + +// Config represents the configuration of the governor contract +type Config struct { + // How long after a proposal is created does voting starrt + VotingStartDelay uint64 + + // The period during which votes are collected + VotingPeriod uint64 + + // Over how many seconds the voting weight is averaged for proposal voting as creation/cancellation threshold + VotingWeightSmoothingDuration uint64 + + // Percentags of x gns total supply at the time of proposal creation + Quorum uint64 + + // The munimum amount of average votes required to create a proposal + ProposalCreationThreshold uint64 + + // How much time must pass after the end of a voting period before a proposal can be executed + ExecutionDelay uint64 + + // The amount of time after the execution delay that the proposal can be executed + ExecutionWindow uint64 +} + +// ProposalState represents the state of a proposal's execution +type ProposalState struct { + Created bool + CreatedAt uint64 + + Upcoming bool + Active bool + + Passed bool + PassedAt uint64 + + Rejected bool + RejectedAt uint64 + + Canceled bool + CanceledAt uint64 + + // LABEL + Executed bool + ExecutedAt uint64 + + Expired bool + ExpiredAt uint64 +} + +type CommunityPoolSpendInfo struct { + To std.Address + TokenPath string + Amount uint64 +} + +type ExecutionInfo struct { + Num uint64 + Msgs []string // split by *GOV* +} + +type ParameterChangeInfo struct { + PkgPath string + Function string + Params string +} + +type ProposalType string + +const ( + Text ProposalType = "TEXT" + CommunityPoolSpend ProposalType = "COMMUNITY_POOL_SPEND" + ParameterChange ProposalType = "PARAMETER_CHANGE" +) + +func tryParseProposalType(v string) (ProposalType, error) { + switch v { + case "TEXT": + return Text, nil + case "COMMUNITY_POOL_SPEND": + return CommunityPoolSpend, nil + case "PARAMETER_CHANGE": + return ParameterChange, nil + default: + return "", errInvalidProposalType + } +} + +func (p ProposalType) String() string { + return string(p) +} + +// ProposalInfo represents all the information about a proposal +type ProposalInfo struct { + // The address of the proposer + Proposer std.Address + + // Text, CommunityPoolSpend, ParameterChange + ProposalType ProposalType + + // The execution state of the proposal + State ProposalState + + // How many yes votes have been collected + Yea *u256.Uint + + // How many no votes have been collected + Nay *u256.Uint + + // The max voting weight at the time of proposal creation + MaxVotingWeight *u256.Uint + + PossibleAddressWithWeight map[std.Address]uint64 + + // The version of the config that this proposal was created with + ConfigVersion uint64 + + // How many total votes must be collected for the proposal + QuorumAmount uint64 + + // The title of the proposal + Title string + + // The description of the proposal + Description string + + // CommunityPoolSpend + CommunityPoolSpend CommunityPoolSpendInfo + + // Execution (ParamterChange) + Execution ExecutionInfo +} diff --git a/contract/r/gnoswap/gov/governance/utils.gno b/contract/r/gnoswap/gov/governance/utils.gno new file mode 100644 index 000000000..986e4c885 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/utils.gno @@ -0,0 +1,151 @@ +package governance + +import ( + "encoding/base64" + "std" + "strconv" + + "gno.land/p/demo/avl" + + "gno.land/p/demo/json" + "gno.land/p/demo/ufmt" +) + +func mustGetProposal(proposalId uint64) ProposalInfo { + result, exists := proposals.Get(formatUint(proposalId)) + if !exists { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("proposalId(%d) not found", proposalId), + )) + } + + return result.(ProposalInfo) +} + +func mustGetVote(key string) bool { + vote, exists := votes.Get(key) + if !exists { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("voteKey(%s) not found", key), + )) + } + return vote.(bool) +} + +// Helper function to validate and get vote information +func getVoteInfoFromKey(voteKey string) (voteWithWeight, bool) { + mustGetVote(voteKey) + + pid, addr := divideVoteKeyToProposalIdAndUser(voteKey) + + voteInfo, exists := getUserVote(addr, pid) + if !exists { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("voteKey(%s) not found", voteKey), + )) + } + + return voteInfo, true +} + +func mustGetVoteInfo(voteKey string) voteWithWeight { + voteInfo, _ := getVoteInfoFromKey(voteKey) + return voteInfo +} + +func iterTree(tree *avl.Tree, cb func(key string, value interface{}) bool) { + tree.Iterate("", "", cb) +} + +func strToInt(str string) int { + res, err := strconv.Atoi(str) + if err != nil { + panic(err.Error()) + } + + return res +} + +func voteToString(b bool) string { + if b { + return "yes" + } + return "no" +} + +func marshal(data *json.Node) string { + b, err := json.Marshal(data) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func b64Encode(data string) string { + return string(base64.StdEncoding.EncodeToString([]byte(data))) +} + +func b64Decode(data string) string { + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + panic(err.Error()) + } + return string(decoded) +} + +func prevRealm() string { + return std.PrevRealm().PkgPath() +} + +func getPrev() (string, string) { + prev := std.PrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +func formatUint(v interface{}) string { + switch v := v.(type) { + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatBool(v bool) string { + return strconv.FormatBool(v) +} + +func parseInt(s string) int { + num, err := strconv.ParseInt(s, 10, 64) + if err != nil { + panic(ufmt.Sprintf("invalid int value: %s", s)) + } + return int(num) +} + +func parseUint64(s string) uint64 { + num, err := strconv.ParseUint(s, 10, 64) + if err != nil { + panic(ufmt.Sprintf("invalid uint64 value: %s", s)) + } + return num +} + +func parseBool(s string) bool { + switch s { + case "true": + return true + case "false": + return false + default: + panic(ufmt.Sprintf("invalid bool value: %s", s)) + } +} diff --git a/contract/r/gnoswap/gov/governance/vote.gno b/contract/r/gnoswap/gov/governance/vote.gno new file mode 100644 index 000000000..d1fc77915 --- /dev/null +++ b/contract/r/gnoswap/gov/governance/vote.gno @@ -0,0 +1,260 @@ +package governance + +import ( + "std" + "time" + + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/demo/avl" + + en "gno.land/r/gnoswap/v1/emission" + + "gno.land/r/gnoswap/v1/common" + + "gno.land/p/demo/ufmt" +) + +type voteWithWeight struct { + Yes bool + Weight uint64 + VotedHeight uint64 + VotedAt uint64 +} + +var ( + votes = avl.NewTree() // voteKey(proposalId:user) -> yes/no + userVotes = avl.NewTree() // user -> proposalId -> voteWithWeight + accumYesVotes = make(map[uint64]uint64) // proposalId -> yes + accumNoVotes = make(map[uint64]uint64) // proposalId -> no +) + +func createVoteKey(pid uint64, voter string) string { + return ufmt.Sprintf("%d:%s", pid, voter) +} + +func getVote(pid uint64, voter string) (bool, bool) { + value, exists := votes.Get(createVoteKey(pid, voter)) + if !exists { + return false, false + } + return value.(bool), true +} + +func setVote(pid uint64, voter string, vote bool) { + voteKey := createVoteKey(pid, voter) + votes.Set(voteKey, vote) +} + +func createUserVoteKey(voter std.Address, pid uint64) string { + return ufmt.Sprintf("%s:%d", voter.String(), pid) +} + +func getUserVote(voter std.Address, pid uint64) (voteWithWeight, bool) { + value, exists := userVotes.Get(createUserVoteKey(voter, pid)) + if !exists { + return voteWithWeight{}, false + } + return value.(voteWithWeight), true +} + +func setUserVote(voter std.Address, pid uint64, vote voteWithWeight) { + if vote.Yes { + accumYesVotes[pid] += vote.Weight + } else { + accumNoVotes[pid] += vote.Weight + } + userVotes.Set(createUserVoteKey(voter, pid), vote) +} + +///////////////////// Vote ///////////////////// + +// Vote allows a user to vote on a given proposal. +// The user's voting weight is determined by their accumulated delegated stake until proposal creation time. +// ref: https://docs.gnoswap.io/contracts/governance/vote.gno#vote +func Vote(pid uint64, yes bool) string { + common.IsHalted() + en.MintAndDistributeGns() + updateProposalsState() + + proposal := mustGetProposal(pid) + + voter := std.PrevRealm().Addr() + now := uint64(time.Now().Unix()) + + voteKey := createVoteKey(pid, voter.String()) + + state, err := newVoteState(proposal, voteKey, voter, now) + if err != nil { + panic(err) + } + + if err := state.validate(); err != nil { + panic(err) + } + + executor := newVoteExecutor(&proposal, voter, state.userWeight) + if err := executor.execute(yes); err != nil { + panic(err) + } + + proposals.Set(formatUint(pid), proposal) + + setVote(pid, voter.String(), yes) + setUserVote(voter, pid, voteWithWeight{ + Yes: yes, + Weight: state.userWeight, + VotedHeight: uint64(std.GetHeight()), + VotedAt: now, + }) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "Vote", + "prevAddr", prevAddr, + "prevPkgPath", prevPkgPath, + "proposalId", formatUint(pid), + "voter", voter.String(), + "yes", voteToString(yes), + "voteWeight", formatUint(state.userWeight), + "voteYes", formatUint(accumYesVotes[pid]), + "voteNo", formatUint(accumNoVotes[pid]), + ) + + return voteKey +} + +///////////////////// Vote State ///////////////////// + +type VoteState struct { + isVotingPeriod bool + isProposalValid bool + hasUserVoted bool + userWeight uint64 +} + +func newVoteState(proposal ProposalInfo, voteKey string, voter std.Address, now uint64) (*VoteState, error) { + cfg := GetConfigVersion(proposal.ConfigVersion) + votingStartTime := proposal.State.CreatedAt + cfg.VotingStartDelay + votingEndTime := votingStartTime + cfg.VotingPeriod + + hasUserVoted := false + voted, exists := votes.Get(voteKey) + if exists { + hasUserVoted = voted.(bool) + } + + return &VoteState{ + isVotingPeriod: now >= votingStartTime && now < votingEndTime, + isProposalValid: !proposal.State.Canceled, + hasUserVoted: hasUserVoted, + userWeight: proposal.PossibleAddressWithWeight[voter], + }, nil +} + +func (s *VoteState) validate() error { + if !s.isVotingPeriod { + return errUnableToVoteOutOfPeriod + } + if !s.isProposalValid { + return errUnableToVoteCanceledProposal + } + if s.hasUserVoted { + return errAlreadyVoted + } + if s.userWeight == 0 { + return errNotEnoughVotingWeight + } + return nil +} + +///////////////////// Vote Execution ///////////////////// + +type VoteExecutor struct { + proposal *ProposalInfo + voter std.Address + weight uint64 +} + +func newVoteExecutor(proposal *ProposalInfo, voter std.Address, weight uint64) *VoteExecutor { + return &VoteExecutor{ + proposal: proposal, + voter: voter, + weight: weight, + } +} + +func (e *VoteExecutor) execute(yes bool) error { + if yes { + e.proposal.Yea = safeVoteSum(e.proposal.Yea, e.weight) + } else { + e.proposal.Nay = safeVoteSum(e.proposal.Nay, e.weight) + } + return nil +} + +func safeVoteSum(collected *u256.Uint, weight uint64) *u256.Uint { + newSum, overflow := new(u256.Uint).AddOverflow(collected, u256.NewUint(weight)) + if overflow { + panic(errOutOfRange) + } + return newSum +} + +// Cancel cancels the proposal with the given ID. +// Only callable by the proposer or if the proposer's stake has fallen below the threshold others can call. +// ref: https://docs.gnoswap.io/contracts/governance/vote.gno#cancel +func Cancel(proposalId uint64) { + common.IsHalted() + + en.MintAndDistributeGns() + updateProposalsState() + + proposal := mustGetProposal(proposalId) + if proposal.State.Canceled { + panic(addDetailToError( + errAlreadyCanceledProposal, + ufmt.Sprintf("proposalId(%d) has already canceled", proposalId), + )) + } + + config := GetConfigVersion(proposal.ConfigVersion) + now := uint64(time.Now().Unix()) + if now >= (proposal.State.CreatedAt + config.VotingStartDelay) { + panic(addDetailToError( + errUnableToCancleVotingProposal, + ufmt.Sprintf("voting has already started for proposalId(%d)", proposalId), + )) + } + + caller := std.PrevRealm().Addr() + if caller != proposal.Proposer { + // If the caller is not the proposer, check if the proposer's stake has fallen below the threshold + enough, balance, wanted := checkEnoughXGnsToPropose(proposal.Proposer) + if enough { + panic(addDetailToError( + errUnableToCancelProposalWithVoterEnoughDelegated, + ufmt.Sprintf( + "caller(%s) is not the proposer(%s) and proposer's xgns balance(%d) is above the threshold(%d)", + caller, proposal.Proposer, + balance, wanted, + ), + )) + } + } + + proposal.State.Canceled = true + proposal.State.CanceledAt = now + proposal.State.Upcoming = false + proposal.State.Active = false + + proposals.Set(formatUint(proposalId), proposal) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "Cancel", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "proposalId", formatUint(proposalId), + ) +} diff --git a/contract/r/gnoswap/gov/governance/vote_test.gno b/contract/r/gnoswap/gov/governance/vote_test.gno new file mode 100644 index 000000000..23bbe6dcb --- /dev/null +++ b/contract/r/gnoswap/gov/governance/vote_test.gno @@ -0,0 +1,511 @@ +package governance + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + u256 "gno.land/p/gnoswap/uint256" +) + +func TestVote(t *testing.T) { + baseTime := uint64(time.Now().Unix()) + voter := testutils.TestAddress("voter") + + newConfig := Config{ + VotingStartDelay: 50, + VotingPeriod: 100, + } + setConfigVersion(1, newConfig) + + testCases := []struct { + name string + setup func() uint64 + pid uint64 + yes bool + expectError bool + validate func(t *testing.T, voteKey string) + }{ + { + name: "Successful YES vote", + setup: func() uint64 { + pid := uint64(1) + proposals.Set(formatUint(pid), ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + Created: true, + CreatedAt: baseTime - 60, // Voting period started + }, + Yea: u256.NewUint(0), + Nay: u256.NewUint(0), + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }) + return pid + }, + yes: true, + expectError: false, + validate: func(t *testing.T, voteKey string) { + prop, exists := proposals.Get(formatUint(1)) + if !exists { + t.Error("Proposal was not stored") + return + } + proposal := prop.(ProposalInfo) + if proposal.Yea.Cmp(u256.NewUint(100)) != 0 { + t.Errorf("Expected Yea votes to be 100, got %v", proposal.Yea) + } + if proposal.Nay.Cmp(u256.NewUint(0)) != 0 { + t.Errorf("Expected Nay votes to be 0, got %v", proposal.Nay) + } + + value, exists := votes.Get(voteKey) + if !exists || !value.(bool) { + t.Error("Vote record not properly stored") + } + + // Verify user vote record + userVote, exists := userVotes.Get(voter.String()) + if !exists { + t.Error("User vote record not properly stored") + } + + weight := userVote.(voteWithWeight) + uassert.True(t, weight.Yes) + uassert.Equal(t, weight.Weight, uint64(100)) + }, + }, + { + name: "Successful NO vote", + setup: func() uint64 { + pid := uint64(2) + proposals.Set(formatUint(pid), ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + Created: true, + CreatedAt: baseTime - 60, + }, + Yea: u256.NewUint(0), + Nay: u256.NewUint(0), + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }) + return pid + }, + yes: false, + expectError: false, + validate: func(t *testing.T, voteKey string) { + prop, exists := proposals.Get(formatUint(2)) + if !exists { + t.Error("Proposal was not stored") + return + } + proposal := prop.(ProposalInfo) + uassert.Equal(t, proposal.Yea.ToString(), "0") + uassert.Equal(t, proposal.Nay.ToString(), "100") + }, + }, + { + name: "Vote before voting period starts", + setup: func() uint64 { + pid := uint64(3) + proposals.Set(formatUint(pid), ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + Created: true, + CreatedAt: baseTime, // Just created, voting hasn't started + }, + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }) + return pid + }, + expectError: true, + }, + { + name: "Vote on canceled proposal", + setup: func() uint64 { + pid := uint64(5) + proposals.Set(formatUint(pid), ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + Created: true, + CreatedAt: baseTime - 60, + Canceled: true, + }, + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }) + return pid + }, + expectError: true, + }, + { + name: "Double voting attempt", + setup: func() uint64 { + pid := uint64(6) + proposals.Set(formatUint(pid), ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + Created: true, + CreatedAt: baseTime - 60, + }, + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }) + // Pre-existing vote + voteKey := createVoteKey(pid, voter.String()) + votes.Set(voteKey, true) + return pid + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resetVoteEnv(t) + + pid := tc.setup() + + var voteKey string + var err error + func() { + defer func() { + if r := recover(); r != nil { + err = r.(error) + } + }() + voteKey = Vote(pid, tc.yes) + }() + + if tc.expectError { + uassert.Error(t, err) + return + } + }) + } +} + +func TestNewVoteState(t *testing.T) { + voter := testutils.TestAddress("voter") + voteKey := createVoteKey(1, voter.String()) + + // test config + // configVersions = map[uint64]Config{ + // 1: { + // VotingStartDelay: 50, + // VotingPeriod: 100, + // }, + // } + newConfig := Config{ + VotingStartDelay: 50, + VotingPeriod: 100, + } + setConfigVersion(1, newConfig) + + tests := []struct { + name string + proposal ProposalInfo + now uint64 + expectedState *VoteState + expectError bool + }{ + { + name: "Valid propoal: in voting period", + proposal: ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + // voting start: 900 + 50 = 950 + // voting end: 950 + 100 = 1050 + CreatedAt: 900, + Canceled: false, + }, + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }, + now: 1000, + expectedState: &VoteState{ + isVotingPeriod: true, + isProposalValid: true, + hasUserVoted: false, + userWeight: 100, + }, + expectError: false, + }, + { + name: "Before voting start", + proposal: ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + CreatedAt: 900, + Canceled: false, + }, + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }, + now: 940, + expectedState: &VoteState{ + isVotingPeriod: false, + isProposalValid: true, + hasUserVoted: false, + userWeight: 100, + }, + expectError: false, + }, + { + name: "After voting end", + proposal: ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + CreatedAt: 900, + Canceled: false, + }, + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }, + now: 1060, + expectedState: &VoteState{ + isVotingPeriod: false, + isProposalValid: true, + hasUserVoted: false, + userWeight: 100, + }, + expectError: false, + }, + { + name: "Canceled proposal", + proposal: ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + CreatedAt: 900, + Canceled: true, + }, + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }, + now: 1000, + expectedState: &VoteState{ + isVotingPeriod: true, + isProposalValid: false, + hasUserVoted: false, + userWeight: 100, + }, + expectError: false, + }, + { + name: "No voting permmisions", + proposal: ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + CreatedAt: 900, + Canceled: false, + }, + PossibleAddressWithWeight: map[std.Address]uint64{}, + }, + now: 1000, + expectedState: &VoteState{ + isVotingPeriod: true, + isProposalValid: true, + hasUserVoted: false, + userWeight: 0, + }, + expectError: false, + }, + { + name: "Already voted", + proposal: ProposalInfo{ + ConfigVersion: 1, + State: ProposalState{ + CreatedAt: 900, + Canceled: false, + }, + PossibleAddressWithWeight: map[std.Address]uint64{ + voter: 100, + }, + }, + now: 1000, + expectedState: &VoteState{ + isVotingPeriod: true, + isProposalValid: true, + hasUserVoted: true, + userWeight: 100, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.expectedState.hasUserVoted { + votes.Set(voteKey, true) + } else { + votes.Remove(voteKey) + } + + state, err := newVoteState(tt.proposal, voteKey, voter, tt.now) + + if tt.expectError { + uassert.Error(t, err) + } + + if err != nil { + uassert.NoError(t, err) + } + + uassert.Equal(t, state.isVotingPeriod, tt.expectedState.isVotingPeriod) + uassert.Equal(t, state.isProposalValid, tt.expectedState.isProposalValid) + uassert.Equal(t, state.hasUserVoted, tt.expectedState.hasUserVoted) + uassert.Equal(t, state.userWeight, tt.expectedState.userWeight) + }) + } +} + +func TestNewVoteExecutor(t *testing.T) { + voter := testutils.TestAddress("voter") + proposal := &ProposalInfo{ + Yea: u256.NewUint(0), + Nay: u256.NewUint(0), + } + weight := uint64(100) + + executor := newVoteExecutor(proposal, voter, weight) + + if executor == nil { + t.Fatal("Expected non-nil VoteExecutor") + } + + if executor.proposal != proposal { + t.Error("Proposal reference mismatch") + } + + uassert.Equal(t, executor.voter, voter) + uassert.Equal(t, executor.weight, weight) +} + +func TestVoteExecutor_Execute(t *testing.T) { + tests := []struct { + name string + yes bool + initYea uint64 + initNay uint64 + weight uint64 + expectYea uint64 + expectNay uint64 + expectError bool + }{ + { + name: "Valid YES vote with zero initial votes", + yes: true, + initYea: 0, + initNay: 0, + weight: 100, + expectYea: 100, + expectNay: 0, + expectError: false, + }, + { + name: "Valid NO vote with zero initial votes", + yes: false, + initYea: 0, + initNay: 0, + weight: 100, + expectYea: 0, + expectNay: 100, + expectError: false, + }, + { + name: "YES vote with existing votes", + yes: true, + initYea: 150, + initNay: 50, + weight: 100, + expectYea: 250, + expectNay: 50, + expectError: false, + }, + { + name: "NO vote with existing votes", + yes: false, + initYea: 150, + initNay: 50, + weight: 100, + expectYea: 150, + expectNay: 150, + expectError: false, + }, + { + name: "YES vote with maximum safe weight", + yes: true, + initYea: 0, + initNay: 0, + weight: ^uint64(0), // max uint64 + expectYea: ^uint64(0), + expectNay: 0, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + voter := testutils.TestAddress("voter") + proposal := &ProposalInfo{ + Yea: u256.NewUint(tt.initYea), + Nay: u256.NewUint(tt.initNay), + } + + executor := newVoteExecutor(proposal, voter, tt.weight) + + var err error + func() { + defer func() { + if r := recover(); r != nil { + err = r.(error) + } + }() + err = executor.execute(tt.yes) + }() + + if tt.expectError { + if err == nil { + t.Error("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if proposal.Yea.Cmp(u256.NewUint(tt.expectYea)) != 0 { + t.Errorf("Yea votes mismatch: got %v, want %v", + proposal.Yea, tt.expectYea) + } + + if proposal.Nay.Cmp(u256.NewUint(tt.expectNay)) != 0 { + t.Errorf("Nay votes mismatch: got %v, want %v", + proposal.Nay, tt.expectNay) + } + }) + } +} + +func resetVoteEnv(t *testing.T) { + t.Helper() + votes = avl.NewTree() + userVotes = avl.NewTree() +} diff --git a/contract/r/gnoswap/gov/staker/api_delegation.gno b/contract/r/gnoswap/gov/staker/api_delegation.gno new file mode 100644 index 000000000..52c2a51e4 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/api_delegation.gno @@ -0,0 +1,68 @@ +package staker + +import ( + "std" + + "gno.land/p/demo/avl" + + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/gnoswap/v1/gov/xgns" +) + +// GetTotalxGnsSupply returns the total amount of xGNS supply. +func GetTotalxGnsSupply() uint64 { + return xgns.TotalSupply() +} + +// GetTotalVoteWeight returns the total amount of xGNS used for voting. +func GetTotalVoteWeight() uint64 { + return xgns.VotingSupply() +} + +// GetTotalDelegated returns the total amount of xGNS delegated. +func GetTotalDelegated() uint64 { + return totalDelegated +} + +// GetTotalLockedAmount returns the total amount of locked GNS. +func GetTotalLockedAmount() uint64 { + return lockedAmount +} + +// GetTotalDelegatedFrom returns the total amount of xGNS delegated by given address. +func GetTotalDelegatedFrom(from std.Address) uint64 { + amount, exist := delegatorAmount.Get(from.String()) + if !exist { + return 0 + } + return amount.(uint64) +} + +// GetTotalDelegatedTo returns the total amount of xGNS delegated to given address. +func GetTotalDelegatedTo(to std.Address) uint64 { + amount, exist := delegatedTo.Get(to.String()) + if !exist { + return 0 + } + return amount.(uint64) +} + +// GetDelegationAmountFromTo returns the amount of xGNS delegated by given address to given address. +func GetDelegationAmountFromTo(from, to std.Address) uint64 { + toAmount, exist := delegatedFromTo.Get(from.String()) + if !exist { + return 0 + } + + amount, exist := toAmount.(*avl.Tree).Get(to.String()) + if !exist { + return 0 + } + + return amount.(uint64) +} + +// GetRealmGnsBalance returns the amount of GNS in the current realm. +func GetRealmGnsBalance() uint64 { + return gns.BalanceOf(std.CurrentRealm().Addr()) +} diff --git a/contract/r/gnoswap/gov/staker/api_history.gno b/contract/r/gnoswap/gov/staker/api_history.gno new file mode 100644 index 000000000..fd33b7a79 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/api_history.gno @@ -0,0 +1,65 @@ +package staker + +import ( + "std" + "time" + + "gno.land/p/demo/json" + "gno.land/p/demo/ufmt" +) + +// GetPossibleVotingAddressWithWeight returns the max voting weight + and possible voting address with weight +func GetPossibleVotingAddressWithWeight(endTimestamp uint64) (uint64, map[std.Address]uint64) { + if endTimestamp > uint64(time.Now().Unix()) { + panic(addDetailToError( + errFutureTime, + ufmt.Sprintf("endTimestamp(%d) > now(%d)", endTimestamp, time.Now().Unix()), + )) + } + + totalWeight := uint64(0) + addressWithWeight := make(map[std.Address]uint64) + + delegationSnapShotHistory.Iterate("", "", func(key string, value interface{}) bool { + history := value.([]DelegationSnapShotHistory) + toAddr := std.Address(key) + + for i := len(history) - 1; i >= 0; i-- { + record := history[i] + + if record.updatedAt > endTimestamp { + continue + } + + addressWithWeight[toAddr] = record.amount + totalWeight += record.amount + + break + } + + return false + }) + + return totalWeight, addressWithWeight +} + +// GetPossibleVotingAddressWithWeightJSON returns the max voting weight + and possible voting address with weight(string json format) +func GetPossibleVotingAddressWithWeightJSON(endTimestamp uint64) (uint64, string) { + totalWeight, addressWithWeight := GetPossibleVotingAddressWithWeight(endTimestamp) + + possibleObj := json.ObjectNode("", nil) + possibleObj.AppendObject("height", json.StringNode("height", ufmt.Sprintf("%d", std.GetHeight()))) + possibleObj.AppendObject("now", json.StringNode("now", ufmt.Sprintf("%d", time.Now().Unix()))) + + possibleArr := json.ArrayNode("votingPower", nil) + for to, weight := range addressWithWeight { + addrWithWeightObj := json.ObjectNode("", nil) + addrWithWeightObj.AppendObject("address", json.StringNode("address", to.String())) + addrWithWeightObj.AppendObject("weight", json.StringNode("weight", ufmt.Sprintf("%d", weight))) + possibleArr.AppendArray(addrWithWeightObj) + } + + possibleObj.AppendObject("votingPower", possibleArr) + + return totalWeight, marshal(possibleObj) +} diff --git a/contract/r/gnoswap/gov/staker/api_staker.gno b/contract/r/gnoswap/gov/staker/api_staker.gno new file mode 100644 index 000000000..523dde3ee --- /dev/null +++ b/contract/r/gnoswap/gov/staker/api_staker.gno @@ -0,0 +1,89 @@ +package staker + +import ( + "std" + "time" + + "gno.land/p/demo/json" + + en "gno.land/r/gnoswap/v1/emission" +) + +// GetLockedAmount gets the total locked amount of GNS. +func GetLockedAmount() uint64 { + en.MintAndDistributeGns() + + return lockedAmount +} + +// GetLockedInfoByAddress gets the locked info of an address. +// - total locked amount +// - claimable amount +func GetLockedInfoByAddress(addr std.Address) string { + en.MintAndDistributeGns() + + lockeds, exist := addrLockedGns.Get(addr.String()) + if !exist { + return "" + } + + now := uint64(time.Now().Unix()) + + totalLocked := uint64(0) + claimableAmount := uint64(0) + + for _, locked := range lockeds.([]lockedGNS) { + amount := locked.amount + unlock := locked.unlock + + totalLocked += amount + + if now >= unlock { + claimableAmount += amount + } + } + + data := json.Builder(). + WriteString("height", formatInt(std.GetHeight())). + WriteString("now", formatInt(time.Now().Unix())). + WriteString("totalLocked", formatUint(totalLocked)). + WriteString("claimableAmount", formatUint(claimableAmount)). + Node() + + return marshal(data) +} + +func GetClaimableRewardByAddress(addr std.Address) string { + en.MintAndDistributeGns() + + rewardState.finalize(getCurrentBalance(), getCurrentProtocolFeeBalance()) + + emissionReward, protocolFeeRewards := rewardState.CalculateReward(addr) + + if emissionReward == 0 && len(protocolFeeRewards) == 0 { + return "" + } + + data := json.Builder(). + WriteString("height", formatInt(std.GetHeight())). + WriteString("now", formatInt(time.Now().Unix())). + WriteString("emissionReward", formatUint(emissionReward)). + Node() + + if len(protocolFeeRewards) > 0 { + pfArr := json.ArrayNode("", nil) + for tokenPath, protocolFeeReward := range protocolFeeRewards { + if protocolFeeReward > 0 { + pfObj := json.Builder(). + WriteString("tokenPath", tokenPath). + WriteString("amount", formatUint(protocolFeeReward)). + Node() + pfArr.AppendArray(pfObj) + } + } + + data.AppendObject("protocolFees", pfArr) + } + + return marshal(data) +} diff --git a/contract/r/gnoswap/gov/staker/api_staker_test.gno b/contract/r/gnoswap/gov/staker/api_staker_test.gno new file mode 100644 index 000000000..c94a099ab --- /dev/null +++ b/contract/r/gnoswap/gov/staker/api_staker_test.gno @@ -0,0 +1,125 @@ +package staker + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/json" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" +) + +func TestGetLockedInfoByAddress(t *testing.T) { + addr := testutils.TestAddress("locked_info_test") + + now := uint64(time.Now().Unix()) + lockedGNSList := []lockedGNS{ + {amount: 1000, unlock: now - 100}, // already unlocked + {amount: 2000, unlock: now + 100}, // still locked + {amount: 3000, unlock: now - 50}, // already unlocked + } + addrLockedGns.Set(addr.String(), lockedGNSList) + + result := GetLockedInfoByAddress(addr) + + node := json.Must(json.Unmarshal([]byte(result))) + + uassert.True(t, node.HasKey("height")) + uassert.True(t, node.HasKey("now")) + uassert.True(t, node.HasKey("totalLocked")) + uassert.True(t, node.HasKey("claimableAmount")) + + // summation of all the locked quantities + totalLocked, err := node.MustKey("totalLocked").GetString() + uassert.NoError(t, err) + uassert.Equal(t, totalLocked, "6000") // 1000 + 2000 + 3000 = 6000 + + claimableAmount, err := node.MustKey("claimableAmount").GetString() + uassert.NoError(t, err) + uassert.Equal(t, claimableAmount, "4000") // 1000 + 3000 = 4000 +} + +func TestGetLockedInfoByAddress_NoLocks(t *testing.T) { + addr := testutils.TestAddress("no_locks_test") + + // no quantities are locked here + result := GetLockedInfoByAddress(addr) + uassert.Equal(t, result, "") +} + +func TestGetLockedInfoByAddress_EmptyLocks(t *testing.T) { + addr := testutils.TestAddress("empty_locks_test") + + addrLockedGns.Set(addr.String(), []lockedGNS{}) + + result := GetLockedInfoByAddress(addr) + + node := json.Must(json.Unmarshal([]byte(result))) + uassert.True(t, node.HasKey("height")) + uassert.True(t, node.HasKey("now")) + uassert.True(t, node.HasKey("totalLocked")) + uassert.True(t, node.HasKey("claimableAmount")) + + totalLocked, err := node.MustKey("totalLocked").GetString() + uassert.NoError(t, err) + uassert.Equal(t, totalLocked, "0") + + claimableAmount, err := node.MustKey("claimableAmount").GetString() + uassert.NoError(t, err) + uassert.Equal(t, claimableAmount, "0") +} + +func TestGetClaimableRewardByAddress(t *testing.T) { + addr := testutils.TestAddress("claimable_test") + + rewardState.AddStake(uint64(std.GetHeight()), addr, 100, 0, nil) + + currentGNSBalance = 1000 + // userEmissionReward.Set(addr.String(), uint64(1000)) + + currentProtocolFeeBalance["token1:token2"] = 500 + currentProtocolFeeBalance["token2:token3"] = 300 + //pfTree := avl.NewTree() + //pfTree.Set("token1:token2", uint64(500)) + //pfTree.Set("token2:token3", uint64(300)) + //userProtocolFeeReward.Set(addr.String(), pfTree) + + result := GetClaimableRewardByAddress(addr) + + node := json.Must(json.Unmarshal([]byte(result))) + + uassert.True(t, node.HasKey("height")) + uassert.True(t, node.HasKey("now")) + + emissionReward, err := node.MustKey("emissionReward").GetString() + uassert.NoError(t, err) + uassert.Equal(t, emissionReward, "1000") + + protocolFees := node.MustKey("protocolFees") + uassert.True(t, protocolFees.IsArray()) + + protocolFees.ArrayEach(func(i int, fee *json.Node) { + tokenPath, err := fee.MustKey("tokenPath").GetString() + uassert.NoError(t, err) + + amount, err := fee.MustKey("amount").GetString() + uassert.NoError(t, err) + + switch tokenPath { + case "token1:token2": + uassert.Equal(t, amount, "500") + case "token2:token3": + uassert.Equal(t, amount, "300") + default: + t.Errorf("unexpected tokenPath: %s", tokenPath) + } + }) +} + +func TestGetClaimableRewardByAddress_NoRewards(t *testing.T) { + addr := testutils.TestAddress("no_reward_test") + + result := GetClaimableRewardByAddress(addr) + uassert.Equal(t, result, "") +} diff --git a/contract/r/gnoswap/gov/staker/clean_delegation_stat_history.gno b/contract/r/gnoswap/gov/staker/clean_delegation_stat_history.gno new file mode 100644 index 000000000..d85f30d75 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/clean_delegation_stat_history.gno @@ -0,0 +1,72 @@ +package staker + +import ( + "std" + + "gno.land/p/gnoswap/consts" +) + +// default one day +var thresholdVotingWeightBlockHeight = consts.TIMESTAMP_DAY / consts.BLOCK_GENERATION_INTERVAL + +var ( + lastCleanedHeight uint64 = 0 + running bool = true +) + +func CleanDelegationStatHistoryByAdmin() { + assertCallerIsAdmin() + cleanDelegationStatHistory() +} + +func GetRunning() bool { + return running +} + +func SetRunning(run bool) { + assertCallerIsAdmin() + + running = run +} + +func GetThresholdVotingWeightBlockHeight() uint64 { + return uint64(thresholdVotingWeightBlockHeight) +} + +func SetThresholdVotingWeightBlockHeightByAdmin(height uint64) { + assertCallerIsAdmin() + + thresholdVotingWeightBlockHeight = int64(height) +} + +func cleanDelegationStatHistory() { + height := uint64(std.GetHeight()) + sinceLast := height - lastCleanedHeight + + if sinceLast < uint64(thresholdVotingWeightBlockHeight) { + return + } + + lastCleanedHeight = height + + // delete history older than 1 day, but keep the latest one + keepFrom := height - uint64(thresholdVotingWeightBlockHeight) + + delegationSnapShotHistory.Iterate("", "", func(key string, value interface{}) bool { + history := value.([]DelegationSnapShotHistory) + + // reverse history + for i := len(history) - 1; i >= 0; i-- { + if history[i].updatedBlock > keepFrom { + continue + } + + // save truncated history + newHistory := history[i:] + delegationSnapShotHistory.Set(key, newHistory) + break + } + + return false + }) +} diff --git a/contract/r/gnoswap/gov/staker/clean_delegation_stat_history_test.gno b/contract/r/gnoswap/gov/staker/clean_delegation_stat_history_test.gno new file mode 100644 index 000000000..1d05132bd --- /dev/null +++ b/contract/r/gnoswap/gov/staker/clean_delegation_stat_history_test.gno @@ -0,0 +1,84 @@ +package staker + +import ( + "std" + "testing" + + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" +) + +var ( + testAddr1 = testutils.TestAddress("test1") + testAddr = testutils.TestAddress("test") +) + +type mockEnv struct { + height uint64 + isAdmin bool +} + +func (m *mockEnv) GetHeight() int64 { + return int64(m.height) +} + +func (m *mockEnv) IsAdmin() bool { + return m.isAdmin +} + +func TestCleanDelegationStatHistory(t *testing.T) { + mock := &mockEnv{height: 1000, isAdmin: true} + std.TestSetOrigCaller(testAddr1) + delegationSnapShotHistory = avl.NewTree() + + addr := testAddr.String() + history := []DelegationSnapShotHistory{ + {updatedBlock: 500}, // Old + {updatedBlock: 900}, // Within threshold + {updatedBlock: 950}, // Latest + } + delegationSnapShotHistory.Set(addr, history) + + tests := []struct { + name string + setupHeight uint64 + lastCleaned uint64 + threshold int64 + expectedLen int + }{ + { + name: "no clean needed", + setupHeight: 1000, + lastCleaned: 999, + threshold: 100, + expectedLen: 3, + }, + { + name: "clean old records", + setupHeight: 1000, + lastCleaned: 800, + threshold: 100, + expectedLen: 3, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mock.height = tc.setupHeight + lastCleanedHeight = tc.lastCleaned + thresholdVotingWeightBlockHeight = tc.threshold + + cleanDelegationStatHistory() + + value, exists := delegationSnapShotHistory.Get(addr) + if !exists { + t.Fatal("history should exist") + } + + history := value.([]DelegationSnapShotHistory) + if len(history) != tc.expectedLen { + t.Errorf("expected history length %d, got %d", tc.expectedLen, len(history)) + } + }) + } +} diff --git a/contract/r/gnoswap/gov/staker/delegate_undelegate.gno b/contract/r/gnoswap/gov/staker/delegate_undelegate.gno new file mode 100644 index 000000000..466546c2c --- /dev/null +++ b/contract/r/gnoswap/gov/staker/delegate_undelegate.gno @@ -0,0 +1,180 @@ +package staker + +import ( + "std" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" +) + +var ( + totalDelegated = uint64(0) + + delegatorAmount = avl.NewTree() // caller => amount + delegatedFromTo = avl.NewTree() // caller => to => amount + delegatedTo = avl.NewTree() // to => amount +) + +func delegate(to std.Address, amount uint64) { + caller := std.PrevRealm().Addr().String() + toStr := to.String() + + // initialize the internal tree for callers to `delegatedFromTo` + innerTree := getOrCreateInnerTree(delegatedFromTo, caller) + + totalDelegated += amount + + // update delegator amount + updateUint64InTree(delegatorAmount, caller, amount, true) + + // update delegatedFromTo's inner tree + updateUint64InTree(innerTree, toStr, amount, true) + + // update delegatedTo + updateUint64InTree(delegatedTo, toStr, amount, true) + + timeStamp := uint64(time.Now().Unix()) + // update delegation history + delegation := DelegationHistory{ + to: to, + amount: amount, + timestamp: timeStamp, + height: uint64(std.GetHeight()), + add: true, // if true, delegation + } + + history := make([]DelegationHistory, 0) + if value, exists := delegationHistory.Get(caller); exists { + history = value.([]DelegationHistory) + } + history = append(history, delegation) + delegationHistory.Set(caller, history) + + // update delegation stat history + updateAmount := uint64(0) + snapShotHistory := make([]DelegationSnapShotHistory, 0) + if value, exists := delegationSnapShotHistory.Get(toStr); exists { + snapShotHistory = value.([]DelegationSnapShotHistory) + lastStat := snapShotHistory[len(snapShotHistory)-1] + updateAmount = lastStat.amount + amount + } else { + updateAmount = amount + } + + snapShotHistory = append(snapShotHistory, DelegationSnapShotHistory{ + to: to, + amount: updateAmount, + updatedBlock: uint64(std.GetHeight()), + updatedAt: timeStamp, + }) + delegationSnapShotHistory.Set(toStr, snapShotHistory) +} + +func undelegate(to std.Address, amount uint64) { + caller := std.PrevRealm().Addr().String() + toStr := to.String() + + // check caller's delegatedFromTo + innerTree, exists := delegatedFromTo.Get(caller) + if !exists { + panic(addDetailToError( + errNoDelegatedAmount, + ufmt.Sprintf("caller(%s) has no delegated amount", caller), + )) + } + // check caller's delegatedFromTo's inner tree + delegatedAmountValue, exists := innerTree.(*avl.Tree).Get(toStr) + if !exists { + panic(addDetailToError( + errNoDelegatedTarget, + ufmt.Sprintf("caller(%s) has no delegated amount to %s", caller, to), + )) + } + delegatedAmount := delegatedAmountValue.(uint64) + if delegatedAmount < amount { + panic(addDetailToError( + errNotEnoughDelegated, + ufmt.Sprintf("caller(%s) has only %d delegated amount(request: %d) to %s", caller, delegatedAmount, amount, to), + )) + } + innerTree.(*avl.Tree).Set(toStr, delegatedAmount-amount) + delegatedFromTo.Set(caller, innerTree) + + // update total delegated amount + totalDelegated -= amount + + currentAmount := uint64(0) + if value, exists := delegatorAmount.Get(caller); exists { + currentAmount = value.(uint64) + } else { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("caller(%s) has no delegated amount", caller), + )) + } + if currentAmount < amount { + panic(addDetailToError( + errNotEnoughDelegated, + ufmt.Sprintf("caller(%s) has only %d delegated amount(request: %d)", caller, currentAmount, amount), + )) + } + delegatorAmount.Set(caller, currentAmount-amount) + + currentToAmount := uint64(0) + if value, exists := delegatedTo.Get(toStr); exists { + currentToAmount = value.(uint64) + } else { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("to(%s) has no delegated amount", toStr), + )) + } + if currentToAmount < amount { + panic(addDetailToError( + errNotEnoughDelegated, + ufmt.Sprintf("to(%s) has only %d delegated amount(request: %d)", toStr, currentToAmount, amount), + )) + } + delegatedTo.Set(toStr, currentToAmount-amount) + + // update delegation history + delegation := DelegationHistory{ + to: to, + amount: amount, + timestamp: uint64(time.Now().Unix()), + height: uint64(std.GetHeight()), + add: false, + } + var history []DelegationHistory + if value, exists := delegationHistory.Get(caller); exists { + history = value.([]DelegationHistory) + } + history = append(history, delegation) + delegationHistory.Set(caller, history) + + // update delegation stat history + statValue, exists := delegationSnapShotHistory.Get(toStr) + if !exists { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("caller(%s) has no delegation stat history", caller), + )) + } + + stat := statValue.([]DelegationSnapShotHistory) + remainingAmount := amount + for i := 0; i < len(stat); i++ { + if stat[i].amount > 0 { + if stat[i].amount < remainingAmount { + remainingAmount -= stat[i].amount + stat = append(stat[:i], stat[i+1:]...) + } else { + stat[i].amount -= remainingAmount + stat[i].updatedAt = uint64(time.Now().Unix()) + break + } + } + } + delegationSnapShotHistory.Set(to.String(), stat) +} diff --git a/contract/r/gnoswap/gov/staker/delegate_undelegate_test.gno b/contract/r/gnoswap/gov/staker/delegate_undelegate_test.gno new file mode 100644 index 000000000..9430786ee --- /dev/null +++ b/contract/r/gnoswap/gov/staker/delegate_undelegate_test.gno @@ -0,0 +1,110 @@ +package staker + +import ( + "std" + "testing" + + "gno.land/p/demo/avl" +) + +func TestDelegateInternal(t *testing.T) { + t.Skip("Must running separately. because this test depends on TestUndelegate's state") + std.TestSetOrigCaller(testAddr1) + resetState() + + addr1 := testAddr1 + + t.Run("success - first delegation", func(t *testing.T) { + delegate(addr1, 100) + + if totalDelegated != 100 { + t.Errorf("expected totalDelegated 100, got %d", totalDelegated) + } + + value, exists := delegatorAmount.Get(addr1.String()) + if !exists || value.(uint64) != 100 { + t.Error("delegator amount not updated correctly") + } + + innerTree, exists := delegatedFromTo.Get(addr1.String()) + if !exists { + t.Error("delegatedFromTo not updated") + } + + inner := innerTree.(*avl.Tree) + delegatedAmount, exists := inner.Get(addr1.String()) + if !exists { + t.Error("delegatedFromTo amount incorrect") + } + + if delegatedAmount.(uint64) != 100 { + t.Error("delegatedFromTo amount incorrect") + } + }) + + t.Run("success - additional delegation", func(t *testing.T) { + resetState() + delegate(addr1, 100) + delegate(addr1, 50) + + if totalDelegated != 150 { + t.Errorf("expected totalDelegated 150, got %d", totalDelegated) + } + }) +} + +func TestUndelegateInternal(t *testing.T) { + t.Skip("Must running separately. because this test depends on TestUndelegate's state") + addr1 := testAddr1 + std.TestSetOrigCaller(addr1) + resetState() + + t.Run("fail - no delegation", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for no delegation") + } + }() + undelegate(addr1, 100) + }) + + t.Run("fail - insufficient amount", func(t *testing.T) { + delegate(addr1, 50) + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for insufficient amount") + } + }() + undelegate(addr1, 100) + }) + + t.Run("success - partial undelegate", func(t *testing.T) { + resetState() + delegate(addr1, 100) + undelegate(addr1, 30) + + if totalDelegated != 70 { + t.Errorf("expected totalDelegated 70, got %d", totalDelegated) + } + }) + + t.Run("success - full undelegate", func(t *testing.T) { + resetState() + delegate(addr1, 100) + undelegate(addr1, 100) + + if totalDelegated != 0 { + t.Errorf("expected totalDelegated 0, got %d", totalDelegated) + } + }) +} + +func resetState() { + totalDelegated = 0 + delegatorAmount = avl.NewTree() + delegatedFromTo = avl.NewTree() + delegatedTo = avl.NewTree() + delegationHistory = avl.NewTree() + delegationSnapShotHistory = avl.NewTree() +} diff --git a/contract/r/gnoswap/gov/staker/errors.gno b/contract/r/gnoswap/gov/staker/errors.gno new file mode 100644 index 000000000..5d67cc9a8 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/errors.gno @@ -0,0 +1,36 @@ +package staker + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-GOV_STAKER-001] caller has no permission") + errDataNotFound = errors.New("[GNOSWAP-GOV_STAKER-002] requested data not found") + errTransferFailed = errors.New("[GNOSWAP-GOV_STAKER-003] transfer failed") + errInvalidAmount = errors.New("[GNOSWAP-GOV_STAKER-004] invalid amount") + errNoDelegatedAmount = errors.New("[GNOSWAP-GOV_STAKER-005] zero delegated amount") + errNoDelegatedTarget = errors.New("[GNOSWAP-GOV_STAKER-006] did not delegated to that address") + errNotEnoughDelegated = errors.New("[GNOSWAP-GOV_STAKER-007] not enough delegated") + errInvalidAddress = errors.New("[GNOSWAP-GOV_STAKER-008] invalid address") + errFutureTime = errors.New("[GNOSWAP-GOV_STAKER-009] can not use future time") + errNotEnoughBalance = errors.New("[GNOSWAP-GOV_STAKER-010] not enough balance") + errLessThanMinimum = errors.New("[GNOSWAP-GOV_STAKER-011] can not delegate less than minimum amount") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} + +// checkTransferError checks transfer error. +func checkTransferError(err error) { + if err != nil { + panic(addDetailToError( + errTransferFailed, + err.Error(), + )) + } +} diff --git a/contract/r/gnoswap/gov/staker/gno.mod b/contract/r/gnoswap/gov/staker/gno.mod new file mode 100644 index 000000000..7caca4df0 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/gov/staker diff --git a/contract/r/gnoswap/gov/staker/history.gno b/contract/r/gnoswap/gov/staker/history.gno new file mode 100644 index 000000000..53adc7c55 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/history.gno @@ -0,0 +1,72 @@ +package staker + +import ( + "std" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + + en "gno.land/r/gnoswap/v1/emission" +) + +// DelegationHistory represents a single delegation event +type DelegationHistory struct { + to std.Address + amount uint64 + timestamp uint64 + height uint64 + add bool +} + +// DelegationSnapShotHistory represents delegation stat for to address +type DelegationSnapShotHistory struct { + to std.Address + amount uint64 + updatedBlock uint64 + updatedAt uint64 +} + +var ( + delegationHistory = avl.NewTree() // addr => []delegationHistory + delegationSnapShotHistory = avl.NewTree() // addr => []delegationSnapShotHistory +) + +// GetDelegatedCumulative gets the cumulative delegated amount for an address at a certain timestamp. +func GetDelegatedCumulative(delegator std.Address, endTimestamp uint64) uint64 { + en.MintAndDistributeGns() + + if !delegator.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("invalid delegator address: %s", delegator.String()), + )) + } + + if endTimestamp > uint64(time.Now().Unix()) { + panic(addDetailToError( + errFutureTime, + ufmt.Sprintf("endTimestamp(%d) > now(%d)", endTimestamp, time.Now().Unix()), + )) + } + + history := make([]DelegationSnapShotHistory, 0) + if value, exists := delegationSnapShotHistory.Get(delegator.String()); exists { + history = value.([]DelegationSnapShotHistory) + } else { + return 0 + } + + // reverse history + for i := len(history) - 1; i >= 0; i-- { + record := history[i] + + if record.updatedAt > endTimestamp { + continue + } + + return record.amount // return last accu amount + } + + return 0 +} diff --git a/contract/r/gnoswap/gov/staker/history_test.gno b/contract/r/gnoswap/gov/staker/history_test.gno new file mode 100644 index 000000000..3498b5cb5 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/history_test.gno @@ -0,0 +1,103 @@ +package staker + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" +) + +func TestGetDelegatedCumulative(t *testing.T) { + delegationSnapShotHistory = avl.NewTree() + + addr1 := testutils.TestAddress("test1") + now := uint64(time.Now().Unix()) + + tests := []struct { + name string + setupHistory []DelegationSnapShotHistory + delegator std.Address + endTimestamp uint64 + expectAmount uint64 + expectPanic bool + }{ + { + name: "no history returns zero", + delegator: addr1, + endTimestamp: now, + expectAmount: 0, + }, + { + name: "single history before timestamp", + setupHistory: []DelegationSnapShotHistory{ + { + to: addr1, + amount: 100, + updatedBlock: 1, + updatedAt: now - 100, + }, + }, + delegator: addr1, + endTimestamp: now, + expectAmount: 100, + }, + { + name: "multiple histories returns latest before timestamp", + setupHistory: []DelegationSnapShotHistory{ + { + to: addr1, + amount: 100, + updatedBlock: 1, + updatedAt: now - 200, + }, + { + to: addr1, + amount: 150, + updatedBlock: 2, + updatedAt: now - 100, + }, + { + to: addr1, + amount: 200, + updatedBlock: 3, + updatedAt: now + 100, // Future update + }, + }, + delegator: addr1, + endTimestamp: now, + expectAmount: 150, + }, + { + name: "future timestamp panics", + delegator: addr1, + endTimestamp: now + 1000, + expectPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + delegationSnapShotHistory = avl.NewTree() + + if len(tt.setupHistory) > 0 { + delegationSnapShotHistory.Set(tt.delegator.String(), tt.setupHistory) + } + + if tt.expectPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but got none") + } + }() + } + + result := GetDelegatedCumulative(tt.delegator, tt.endTimestamp) + + if !tt.expectPanic && result != tt.expectAmount { + t.Errorf("expected amount %d but got %d", tt.expectAmount, result) + } + }) + } +} \ No newline at end of file diff --git a/contract/r/gnoswap/gov/staker/reward_calculation.gno b/contract/r/gnoswap/gov/staker/reward_calculation.gno new file mode 100644 index 000000000..b3f152d13 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/reward_calculation.gno @@ -0,0 +1,334 @@ +package staker + +import ( + "std" + + en "gno.land/r/gnoswap/v1/emission" + pf "gno.land/r/gnoswap/v1/protocol_fee" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/p/demo/avl" + ufmt "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" +) + +var ( + currentGNSBalance uint64 = 0 + currentProtocolFeeBalance map[string]uint64 = make(map[string]uint64) +) + +func getCurrentBalance() uint64 { + gotGnsForEmission = en.GetDistributedToGovStaker() + en.ClearDistributedToGovStaker() + + currentGNSBalance += gotGnsForEmission + + return currentGNSBalance +} + +func getCurrentProtocolFeeBalance() map[string]uint64 { + pf.DistributeProtocolFee() + gotAccuProtocolFee := pf.GetAccuTransferToGovStaker() + + gotAccuProtocolFee.Iterate("", "", func(key string, value interface{}) bool { + amount := value.(uint64) + currentProtocolFeeBalance[key] += amount + return false + }) + + pf.ClearAccuTransferToGovStaker() + + return currentProtocolFeeBalance +} + +type StakerRewardInfo struct { + StartHeight uint64 // height when staker started staking + PriceDebt *u256.Uint // price debt per xGNS stake, Q128 + ProtocolFeePriceDebt map[string]*u256.Uint // protocol fee debt per xGNS stake, Q128 + Amount uint64 // amount of xGNS staked + Claimed uint64 // amount of GNS reward claimed so far + ProtocolFeeClaimed map[string]uint64 // protocol fee amount claimed per token +} + +func (self *StakerRewardInfo) Debug() string { + return ufmt.Sprintf("{ StartHeight: %d, PriceDebt: %d, Amount: %d, Claimed: %d }", self.StartHeight, self.PriceDebtUint64(), self.Amount, self.Claimed) +} + +func (self *StakerRewardInfo) PriceDebtUint64() uint64 { + return u256.Zero().Rsh(self.PriceDebt, 128).Uint64() +} + +type RewardState struct { + // CurrentBalance is sum of all the previous balances, including the reward distribution. + CurrentBalance uint64 // current balance of gov_staker, used to calculate RewardAccumulation + CurrentProtocolFeeBalance map[string]uint64 // current protocol fee balance per token + PriceAccumulation *u256.Uint // claimable GNS per xGNS stake, Q128 + ProtocolFeePriceAccumulation map[string]*u256.Uint // protocol fee debt per xGNS stake, Q128 + TotalStake uint64 // total xGNS staked + + info *avl.Tree // address -> StakerRewardInfo +} + +func NewRewardState() *RewardState { + return &RewardState{ + info: avl.NewTree(), + CurrentBalance: 0, + CurrentProtocolFeeBalance: make(map[string]uint64), + PriceAccumulation: u256.Zero(), + ProtocolFeePriceAccumulation: make(map[string]*u256.Uint), + TotalStake: 0, + } +} + +var rewardState = NewRewardState() + +func (self *RewardState) Debug() string { + return ufmt.Sprintf("{ CurrentBalance: %d, PriceAccumulation: %d, TotalStake: %d, info: len(%d) }", self.CurrentBalance, self.PriceAccumulationUint64(), self.TotalStake, self.info.Size()) +} + +func (self *RewardState) Info(staker std.Address) StakerRewardInfo { + infoI, exists := self.info.Get(staker.String()) + if !exists { + return StakerRewardInfo{ + StartHeight: uint64(std.GetHeight()), + PriceDebt: u256.Zero(), + ProtocolFeePriceDebt: make(map[string]*u256.Uint), + Amount: 0, + Claimed: 0, + ProtocolFeeClaimed: make(map[string]uint64), + } + } + return infoI.(StakerRewardInfo) +} + +func (self *RewardState) CalculateReward(staker std.Address) (uint64, map[string]uint64) { + info := self.Info(staker) + stakerPrice := u256.Zero().Sub(self.PriceAccumulation, info.PriceDebt) + reward := stakerPrice.Mul(stakerPrice, u256.NewUint(info.Amount)) + reward = reward.Rsh(reward, 128) + + protocolFeeRewards := make(map[string]uint64) + for tokenPath, protocolFeePriceAccumulation := range self.ProtocolFeePriceAccumulation { + protocolFeePriceDebt, ok := info.ProtocolFeePriceDebt[tokenPath] + if !ok { + protocolFeePriceDebt = u256.Zero() + } + protocolFeePrice := u256.Zero().Sub(protocolFeePriceAccumulation, protocolFeePriceDebt) + protocolFeeReward := protocolFeePrice.Mul(protocolFeePrice, u256.NewUint(info.Amount)) + protocolFeeReward = protocolFeeReward.Rsh(protocolFeeReward, 128) + protocolFeeReward64 := protocolFeeReward.Uint64() + if protocolFeeReward64 > 0 { + protocolFeeRewards[tokenPath] = protocolFeeReward64 + } + } + return reward.Uint64(), protocolFeeRewards +} + +func (self *RewardState) PriceAccumulationUint64() uint64 { + return u256.Zero().Rsh(self.PriceAccumulation, 128).Uint64() +} + +// amount MUST be less than or equal to the amount of xGNS staked +// This function does not check it +func (self *RewardState) deductReward(staker std.Address, currentBalance uint64) (uint64, map[string]uint64) { + info := self.Info(staker) + stakerPrice := u256.Zero().Sub(self.PriceAccumulation, info.PriceDebt) + reward := stakerPrice.Mul(stakerPrice, u256.NewUint(info.Amount)) + reward = reward.Rsh(reward, 128) + reward64 := reward.Uint64() + info.Claimed += reward64 + + protocolFeeRewards := make(map[string]uint64) + println("protocolfeeaccumulation", self.ProtocolFeePriceAccumulation) + println("protocolfeepricedebt", info.ProtocolFeePriceDebt) + println("protocolfeeclaimed", info.ProtocolFeeClaimed) + for tokenPath, protocolFeePriceAccumulation := range self.ProtocolFeePriceAccumulation { + protocolFeePriceDebt, ok := info.ProtocolFeePriceDebt[tokenPath] + if !ok { + protocolFeePriceDebt = u256.Zero() + } + protocolFeePrice := u256.Zero().Sub(protocolFeePriceAccumulation, protocolFeePriceDebt) + protocolFeeReward := protocolFeePrice.Mul(protocolFeePrice, u256.NewUint(info.Amount)) + protocolFeeReward = protocolFeeReward.Rsh(protocolFeeReward, 128) + protocolFeeRewards[tokenPath] = protocolFeeReward.Uint64() + info.ProtocolFeeClaimed[tokenPath] += protocolFeeReward.Uint64() + } + + self.info.Set(staker.String(), info) + + self.CurrentBalance = currentBalance - reward64 + for tokenPath, amount := range protocolFeeRewards { + self.CurrentProtocolFeeBalance[tokenPath] -= amount + } + + return reward64, protocolFeeRewards +} + +// This function MUST be called as a part of AddStake or RemoveStake +// CurrentBalance / StakeChange / IsRemoveStake will be updated in those functions +func (self *RewardState) finalize(currentBalance uint64, currentProtocolFeeBalances map[string]uint64) { + if self.TotalStake == uint64(0) { + // no staker + return + } + + delta := currentBalance - self.CurrentBalance + price := u256.NewUint(delta) + price = price.Lsh(price, 128) + price = price.Div(price, u256.NewUint(self.TotalStake)) + self.PriceAccumulation.Add(self.PriceAccumulation, price) + self.CurrentBalance = currentBalance + + for tokenPath, currentProtocolFeeBalance := range currentProtocolFeeBalances { + protocolFeeDelta := currentProtocolFeeBalance - self.CurrentProtocolFeeBalance[tokenPath] + protocolFeePrice := u256.NewUint(protocolFeeDelta) + protocolFeePrice = protocolFeePrice.Lsh(protocolFeePrice, 128) + protocolFeePrice = protocolFeePrice.Div(protocolFeePrice, u256.NewUint(self.TotalStake)) + protocolFeePriceAccumulation, ok := self.ProtocolFeePriceAccumulation[tokenPath] + if !ok { + protocolFeePriceAccumulation = u256.Zero() + } + protocolFeePriceAccumulation.Add(protocolFeePriceAccumulation, protocolFeePrice) + self.ProtocolFeePriceAccumulation[tokenPath] = protocolFeePriceAccumulation + self.CurrentProtocolFeeBalance[tokenPath] = currentProtocolFeeBalance + } +} + +func (self *RewardState) AddStake(currentHeight uint64, staker std.Address, amount uint64, currentBalance uint64, currentProtocolFeeBalances map[string]uint64) { + self.finalize(currentBalance, currentProtocolFeeBalances) + + self.TotalStake += amount + + if self.info.Has(staker.String()) { + info := self.Info(staker) + info.PriceDebt.Add(info.PriceDebt, u256.NewUint(info.Amount)) + info.PriceDebt.Add(info.PriceDebt, u256.Zero().Mul(self.PriceAccumulation, u256.NewUint(amount))) + info.PriceDebt.Div(info.PriceDebt, u256.NewUint(self.TotalStake)) + for tokenPath, amount := range currentProtocolFeeBalances { + protocolFeePriceDebt, ok := info.ProtocolFeePriceDebt[tokenPath] + if !ok { + info.ProtocolFeePriceDebt[tokenPath] = self.ProtocolFeePriceAccumulation[tokenPath].Clone() + continue + } + protocolFeePriceDebt.Add(protocolFeePriceDebt, u256.NewUint(amount)) + protocolFeePriceDebt.Add(protocolFeePriceDebt, u256.Zero().Mul(self.ProtocolFeePriceAccumulation[tokenPath], u256.NewUint(amount))) + protocolFeePriceDebt.Div(protocolFeePriceDebt, u256.NewUint(self.TotalStake)) + info.ProtocolFeePriceDebt[tokenPath] = protocolFeePriceDebt + } + info.Amount += amount + self.info.Set(staker.String(), info) + return + } + + info := StakerRewardInfo{ + StartHeight: currentHeight, + PriceDebt: self.PriceAccumulation.Clone(), + Amount: amount, + Claimed: 0, + ProtocolFeeClaimed: make(map[string]uint64), + ProtocolFeePriceDebt: make(map[string]*u256.Uint), + } + + for tokenPath, protocolFeePriceAccumulation := range self.ProtocolFeePriceAccumulation { + info.ProtocolFeePriceDebt[tokenPath] = protocolFeePriceAccumulation.Clone() + } + + self.info.Set(staker.String(), info) +} + +func (self *RewardState) Claim(staker std.Address, currentBalance uint64, currentProtocolFeeBalances map[string]uint64) (uint64, map[string]uint64) { + if !self.info.Has(staker.String()) { + return 0, make(map[string]uint64) + } + + println("Claim", staker.String()) + + self.finalize(currentBalance, currentProtocolFeeBalances) + + println("Claim2", staker.String()) + + reward, protocolFeeRewards := self.deductReward(staker, currentBalance) + + println("Claim3", staker.String()) + + return reward, protocolFeeRewards +} + +func (self *RewardState) RemoveStake(staker std.Address, amount uint64, currentBalance uint64, currentProtocolFeeBalances map[string]uint64) (uint64, map[string]uint64) { + self.finalize(currentBalance, currentProtocolFeeBalances) + + reward, protocolFeeRewards := self.deductReward(staker, currentBalance) + + self.info.Remove(staker.String()) + + self.TotalStake -= amount + + return reward, protocolFeeRewards +} + +var ( + q96 = u256.MustFromDecimal(consts.Q96) + lastCalculatedHeight uint64 // flag to prevent same block calculation +) + +var ( + gotGnsForEmission uint64 + leftGnsEmissionFromLast uint64 + alreadyCalculatedGnsEmission uint64 + + leftProtocolFeeFromLast = avl.NewTree() // tokenPath -> tokenAmount + alreadyCalculatedProtocolFee = avl.NewTree() // tokenPath -> tokenAmount +) + +// === LAUNCHPAD DEPOSIT +var ( + amountByProjectWallet = avl.NewTree() // recipient wallet => amount + rewardByProjectWallet = avl.NewTree() // recipient wallet => reward +) + +func GetRewardByProjectWallet(addr std.Address) uint64 { + value, exists := rewardByProjectWallet.Get(addr.String()) + if !exists { + return 0 + } + return value.(uint64) +} + +func getAmountByProjectWallet(addr std.Address) uint64 { + value, exists := amountByProjectWallet.Get(addr.String()) + if !exists { + return 0 + } + return value.(uint64) +} + +func SetAmountByProjectWallet(addr std.Address, amount uint64, add bool) { + caller := std.PrevRealm().Addr() + if err := common.LaunchpadOnly(caller); err != nil { + panic(err) + } + + common.IsHalted() + en.MintAndDistributeGns() + + currentAmount := getAmountByProjectWallet(addr) + if add { + amountByProjectWallet.Set(addr.String(), currentAmount+amount) + rewardState.AddStake(uint64(std.GetHeight()), caller, amount, getCurrentBalance(), getCurrentProtocolFeeBalance()) + } else { + amountByProjectWallet.Set(addr.String(), currentAmount-amount) + rewardState.RemoveStake(caller, amount, getCurrentBalance(), getCurrentProtocolFeeBalance()) + } +} + +// calculateXGnsRatio calculates the ratio of user's xGNS amount to total xGNS supply +func calculateXGnsRatio(amount uint64, xGnsX96 *u256.Uint) *u256.Uint { + xGnsAmountX96 := new(u256.Uint).Mul(u256.NewUint(amount), q96) + xGnsAmountX96 = new(u256.Uint).Mul(xGnsAmountX96, u256.NewUint(1_000_000_000)) + + ratio := new(u256.Uint).Div(xGnsAmountX96, xGnsX96) + ratio = ratio.Mul(ratio, q96) + return ratio.Div(ratio, u256.NewUint(1_000_000_000)) +} diff --git a/contract/r/gnoswap/gov/staker/reward_calculation_test.gno b/contract/r/gnoswap/gov/staker/reward_calculation_test.gno new file mode 100644 index 000000000..29ce49e70 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/reward_calculation_test.gno @@ -0,0 +1,107 @@ +package staker + +import ( + "testing" + + "gno.land/p/demo/testutils" +) + +func TestRewardCalculation_1_1(t *testing.T) { + state := NewRewardState() + + current := 100 + state.AddStake(10, testutils.TestAddress("alice"), 100, uint64(current), make(map[string]uint64)) + + current += 100 + reward, _ := state.RemoveStake(testutils.TestAddress("alice"), 100, uint64(current), make(map[string]uint64)) + + if reward != 100+100 { + t.Errorf("expected reward %d, got %d", 100+100, reward) + } +} + +func TestRewardCalculation_1_2(t *testing.T) { + state := NewRewardState() + + current := 100 + state.AddStake(10, testutils.TestAddress("alice"), 100, uint64(current), make(map[string]uint64)) + + current += 100 + reward, _ := state.RemoveStake(testutils.TestAddress("alice"), 100, uint64(current), make(map[string]uint64)) + current -= int(reward) + + if reward != 100+100 { + t.Errorf("expected reward %d, got %d", 100+100, reward) + } + + current += 100 + state.AddStake(12, testutils.TestAddress("bob"), 100, uint64(current), make(map[string]uint64)) + + current += 100 + reward, _ = state.RemoveStake(testutils.TestAddress("bob"), 100, uint64(current), make(map[string]uint64)) + current -= int(reward) + if reward != 100+100 { + t.Errorf("expected reward %d, got %d", 100+100, reward) + } +} + +func TestRewardCalculation_1_3(t *testing.T) { + state := NewRewardState() + + // Alice takes 100 GNS + current := 100 + state.AddStake(10, testutils.TestAddress("alice"), 10, uint64(current), make(map[string]uint64)) + + // Alice takes 100 GNS + current += 100 + state.AddStake(11, testutils.TestAddress("bob"), 10, uint64(current), make(map[string]uint64)) + + // Alice takes 50 GNS, Bob takes 50 GNS + current += 100 + reward, _ := state.RemoveStake(testutils.TestAddress("alice"), 10, uint64(current), make(map[string]uint64)) + current -= int(reward) + if reward != 100+100+50 { + t.Errorf("expected reward %d, got %d", 100+100+50, reward) + } + + // Bob takes 100 GNS + current += 100 + reward, _ = state.RemoveStake(testutils.TestAddress("bob"), 10, uint64(current), make(map[string]uint64)) + current -= int(reward) + if reward != 100+50 { + t.Errorf("expected reward %d, got %d", 100+50, reward) + } +} + + +func TestRewardCalculation_1_4(t *testing.T) { + state := NewRewardState() + + // Alice takes 100 GNS + current := 100 + state.AddStake(10, testutils.TestAddress("alice"), 10, uint64(current), make(map[string]uint64)) + + // Alice takes 200GNS + current += 200 + state.AddStake(11, testutils.TestAddress("bob"), 30, uint64(current), make(map[string]uint64)) + + // Alice 25, Bob 75 + current += 100 + state.AddStake(12, testutils.TestAddress("charlie"), 10, uint64(current), make(map[string]uint64)) + + // Alice 20, Bob 60, Charlie 20 + current += 100 + reward, _ := state.RemoveStake(testutils.TestAddress("alice"), 10, uint64(current), make(map[string]uint64)) + current -= int(reward) + if reward != 100+200+25+20 { + t.Errorf("expected reward %d, got %d", 100+200+25+20, reward) + } + + // Bob 75, Charlie 25 + current += 100 + reward, _ = state.RemoveStake(testutils.TestAddress("bob"), 30, uint64(current), make(map[string]uint64)) + current -= int(reward) + if reward != 75+60+75 { + t.Errorf("expected reward %d, got %d", 75+60+75, reward) + } +} diff --git a/contract/r/gnoswap/gov/staker/staker.gno b/contract/r/gnoswap/gov/staker/staker.gno new file mode 100644 index 000000000..dc24b1d72 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/staker.gno @@ -0,0 +1,412 @@ +package staker + +import ( + "std" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/wugnot" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/gnoswap/v1/gov/xgns" + + en "gno.land/r/gnoswap/v1/emission" +) + +type lockedGNS struct { + amount uint64 + unlock uint64 + collected bool +} + +const TIMESTAMP_7_DAYS = uint64(604800) // 7 days in seconds + +var ( + addrLockedGns = avl.NewTree() // address -> []lockedGNS + lockedAmount = uint64(0) +) + +var minimumAmount = uint64(1_000_000) + +// Delegate delegates GNS tokens to a specified address. +// ref: https://docs.gnoswap.io/contracts/governance/staker.gno#delegate +func Delegate(to std.Address, amount uint64) { + if running { + cleanDelegationStatHistory() + } + + if amount == 0 { + panic(addDetailToError( + errInvalidAmount, + "delegation amount cannot be 0", + )) + } + + CollectReward() + + if !to.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("invalid address %s to delegate", to.String()), + )) + } + + if amount < minimumAmount { + panic(addDetailToError( + errLessThanMinimum, + ufmt.Sprintf("minimum amount to delegate is %d (requested:%d)", minimumAmount, amount), + )) + } + + if amount%minimumAmount != 0 { + panic(addDetailToError( + errInvalidAmount, + ufmt.Sprintf("amount must be multiple of %d", minimumAmount), + )) + } + + caller := std.PrevRealm().Addr() + gnsBalance := gns.BalanceOf(caller) + if gnsBalance < amount { + panic(addDetailToError( + errNotEnoughBalance, + ufmt.Sprintf("invalid GNS balance(%d) to delegate(%d)", gnsBalance, amount), + )) + } + + rewardState.AddStake(uint64(std.GetHeight()), caller, amount, getCurrentBalance(), getCurrentProtocolFeeBalance()) + + // GNS // caller -> GovStaker + gns.TransferFrom(caller, std.CurrentRealm().Addr(), amount) + + // actual delegate + delegate(to, amount) + + // xGNS mint to caller + xgns.Mint(caller, amount) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "Delegate", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "from", caller.String(), + "to", to.String(), + "amount", formatUint(amount), + ) +} + +// Redelegate redelegates xGNS from existing delegate to another. +// ref: https://docs.gnoswap.io/contracts/governance/staker.gno#redelegate +func Redelegate(from, to std.Address, amount uint64) { + if running { + cleanDelegationStatHistory() + } + + CollectReward() + + if !from.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("invalid from address %s to redelegate", to.String()), + )) + } + + if !to.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("invalid to address %s to redelegate", to.String()), + )) + } + + if amount < minimumAmount { + panic(addDetailToError( + errLessThanMinimum, + ufmt.Sprintf("minimum amount to redelegate is %d (requested:%d)", minimumAmount, amount), + )) + } + + if amount%minimumAmount != 0 { + panic(addDetailToError( + errInvalidAmount, + ufmt.Sprintf("amount must be multiple of %d", minimumAmount), + )) + } + + caller := std.PrevRealm().Addr() + + delegated := xgns.BalanceOf(caller) + if delegated < amount { + panic(addDetailToError( + errNotEnoughBalance, + ufmt.Sprintf("not enough xGNS delegated(%d) to redelegate(%d)", delegated, amount), + )) + } + + undelegate(to, amount) + delegate(to, amount) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "Redelegate", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "from", from.String(), + "to", to.String(), + "amount", formatUint(amount), + ) +} + +// Undelegate undelegates xGNS from the existing delegate. +// ref: https://docs.gnoswap.io/contracts/governance/staker.gno#undelegate +func Undelegate(from std.Address, amount uint64) { + if running { + cleanDelegationStatHistory() + } + + if !from.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("invalid address %s to undelegate", from.String()), + )) + } + + if amount < minimumAmount { + panic(addDetailToError( + errLessThanMinimum, + ufmt.Sprintf("minimum amount to undelegate is %d (requested:%d)", minimumAmount, amount), + )) + } + + if amount%minimumAmount != 0 { + panic(addDetailToError( + errInvalidAmount, + ufmt.Sprintf("amount must be multiple of %d", minimumAmount), + )) + } + + caller := std.PrevRealm().Addr() + delegated := xgns.BalanceOf(caller) + + if delegated < amount { + panic(addDetailToError( + errNotEnoughBalance, + ufmt.Sprintf("not enough xGNS delegated(%d) to undelegate(%d)", delegated, amount), + )) + } + + reward, protocolFeeRewards := rewardState.RemoveStake(caller, amount, getCurrentBalance(), getCurrentProtocolFeeBalance()) + + // burn equivalent amount of xGNS + xgns.Burn(caller, amount) + + gns.Transfer(caller, reward) + + for tokenPath, amount := range protocolFeeRewards { + transferProtocolFee(tokenPath, from, amount) + } + + // actual undelegate + undelegate(from, amount) + + // lock up + userLocked := lockedGNS{ + amount: amount, + unlock: uint64(time.Now().Unix()) + TIMESTAMP_7_DAYS, // after 7 days, call Collect() to receive GNS + } + + var lockedList []lockedGNS + if value, exists := addrLockedGns.Get(caller.String()); exists { + lockedList = value.([]lockedGNS) + } + + lockedList = append(lockedList, userLocked) + addrLockedGns.Set(caller.String(), lockedList) + lockedAmount += amount + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "Undelegate", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "from", from.String(), + "amount", formatUint(amount), + ) +} + +// CollectUndelegatedGns collects the amount of the undelegated GNS. +// ref: https://docs.gnoswap.io/contracts/governance/staker.gno#collectundelegatedgns +func CollectUndelegatedGns() uint64 { + common.IsHalted() + en.MintAndDistributeGns() + + caller := std.PrevRealm().Addr() + + value, exists := addrLockedGns.Get(caller.String()) + if !exists { + return 0 + } + + lockedList := value.([]lockedGNS) + if len(lockedList) == 0 { + return 0 + } + + prevAddr, prevPkgPath := getPrev() + collected := uint64(0) + currentTime := uint64(time.Now().Unix()) + + newLockedList := make([]lockedGNS, 0) + for _, locked := range lockedList { + if currentTime >= locked.unlock { // passed 7 days + // transfer GNS to caller + gns.Transfer(caller, locked.amount) + lockedAmount -= locked.amount + collected += locked.amount + } else { + newLockedList = append(newLockedList, locked) + } + } + + if len(newLockedList) > 0 { + addrLockedGns.Set(caller.String(), newLockedList) + } else { + _, removed := addrLockedGns.Remove(caller.String()) + if !removed { + panic("failed to remove locked GNS list") + } + } + + if collected > 0 { + std.Emit( + "CollectUndelegatedGns", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "from", consts.GOV_STAKER_ADDR.String(), + "to", caller.String(), + "collectedAmount", formatUint(collected), + ) + } + + return collected +} + +// CollectReward collects the rewards from the protocol fee contract based on the holdings of xGNS. +// ref: https://docs.gnoswap.io/contracts/governance/staker.gno#collectreward +func CollectReward() { + common.IsHalted() + en.MintAndDistributeGns() + + // calculateReward() + + prevAddr, prevPkgPath := getPrev() + caller := std.PrevRealm().Addr() + + reward, protocolFeeRewards := rewardState.Claim(caller, getCurrentBalance(), getCurrentProtocolFeeBalance()) + + // XXX (@notJoon): There could be cases where the reward pool is empty, In such case, + // it seems appropriate to return 0 and continue processing. + // + // This isn't necessarily an abnormal situation, particularly + // since it could be because rewards haven't occurred yet or + // have already been fully collected. + // + // still, this is a tangled with the policy issue, so should be discussed. + + if reward > 0 { + gns.Transfer(caller, reward) + std.Emit( + "CollectEmissionReward", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "from", consts.GOV_STAKER_ADDR.String(), + "to", caller.String(), + "emissionRewardAmount", formatUint(reward), + ) + } + + for tokenPath, amount := range protocolFeeRewards { + if tokenPath == consts.WUGNOT_PATH { + if amount > 0 { + wugnot.Withdraw(amount) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(consts.GOV_STAKER_ADDR, caller, std.Coins{{"ugnot", int64(amount)}}) + } + } else { + transferProtocolFee(tokenPath, caller, amount) + } + + std.Emit( + "CollectProtocolFeeReward", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "tokenPath", tokenPath, + "from", consts.GOV_STAKER_ADDR.String(), + "to", caller.String(), + "collectedAmount", formatUint(amount), + ) + } +} + +// CollectRewardFromLaunchPad collects the rewards from the protocol fee contract based on the holdings of xGNS in the launchpad contract. +// Only launchpad contract can call this function +// ref: https://docs.gnoswap.io/contracts/governance/staker.gno#collectrewardfromlaunchpad +func CollectRewardFromLaunchPad(to std.Address) { + assertCallerIsLaunchpad() + + common.IsHalted() + en.MintAndDistributeGns() + + prevAddr, prevPkgPath := getPrev() + + emissionReward, protocolFeeRewards := rewardState.Claim(to, getCurrentBalance(), getCurrentProtocolFeeBalance()) + if emissionReward > 0 { + gns.Transfer(to, emissionReward) + std.Emit( + "CollectEmissionFromLaunchPad", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "from", consts.GOV_STAKER_ADDR.String(), + "to", to.String(), + "emissionRewardAmount", formatUint(emissionReward), + ) + } + + for tokenPath, amount := range protocolFeeRewards { + transferProtocolFee(tokenPath, to, amount) + std.Emit( + "CollectProtocolFeeFromLaunchPad", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "tokenPath", tokenPath, + "from", consts.GOV_STAKER_ADDR.String(), + "to", to.String(), + "collectedAmount", formatUint(amount), + ) + } + +} + +func transferProtocolFee(tokenPath string, to std.Address, amount uint64) { + common.MustRegistered(tokenPath) + if !to.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("invalid address %s to transfer protocol fee", to.String()), + )) + } + if amount <= 0 { + panic(addDetailToError( + errInvalidAmount, + ufmt.Sprintf("invalid amount %d to transfer protocol fee", amount), + )) + } + + token := common.GetTokenTeller(tokenPath) + checkTransferError(token.Transfer(to, amount)) +} diff --git a/contract/r/gnoswap/gov/staker/staker_test.gno b/contract/r/gnoswap/gov/staker/staker_test.gno new file mode 100644 index 000000000..1eec23aed --- /dev/null +++ b/contract/r/gnoswap/gov/staker/staker_test.gno @@ -0,0 +1,621 @@ +package staker + +import ( + "std" + "strings" + "testing" + "time" + + "gno.land/p/demo/testutils" + "gno.land/r/demo/wugnot" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/gnoswap/v1/gov/xgns" +) + +var ( + adminRealm = std.NewUserRealm(consts.ADMIN) + userRealm = std.NewUserRealm(testutils.TestAddress("alice")) + user2Realm = std.NewUserRealm(testutils.TestAddress("bob")) + user3Realm = std.NewUserRealm(testutils.TestAddress("charlie")) + invalidAddr = testutils.TestAddress("invalid") + + ugnotDenom string = "ugnot" + ugnotPath string = "ugnot" + wugnotPath string = "gno.land/r/demo/wugnot" + + realmPrefix = "/gno.land/r/gnoswap/v1/gov/staker" +) + +func makeFakeAddress(name string) std.Address { + return testutils.TestAddress(name) +} + +func ugnotTransfer(t *testing.T, from, to std.Address, amount uint64) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(from)) + std.TestSetOrigSend(std.Coins{{ugnotDenom, int64(amount)}}, nil) + banker := std.GetBanker(std.BankerTypeRealmSend) + + banker.SendCoins(from, to, std.Coins{{ugnotDenom, int64(amount)}}) +} + +func ugnotBalanceOf(t *testing.T, addr std.Address) uint64 { + t.Helper() + + banker := std.GetBanker(std.BankerTypeRealmIssue) + coins := banker.GetCoins(addr) + if len(coins) == 0 { + return 0 + } + + return uint64(coins.AmountOf(ugnotDenom)) +} + +func ugnotMint(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.IssueCoin(addr, ugnotDenom, amount) + std.TestIssueCoins(addr, std.Coins{{ugnotDenom, int64(amount)}}) +} + +func ugnotBurn(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.RemoveCoin(addr, ugnotDenom, amount) +} + +func ugnotFaucet(t *testing.T, to std.Address, amount uint64) { + t.Helper() + faucetAddress := consts.ADMIN + std.TestSetOrigCaller(faucetAddress) + + if ugnotBalanceOf(t, faucetAddress) < amount { + ugnotMint(t, faucetAddress, ugnotPath, int64(amount)) + std.TestSetOrigSend(std.Coins{{ugnotPath, int64(amount)}}, nil) + } + ugnotTransfer(t, faucetAddress, to, amount) +} + +func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(addr)) + wugnotAddr := consts.WUGNOT_ADDR + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(addr, wugnotAddr, std.Coins{{ugnotDenom, int64(amount)}}) + wugnot.Deposit() +} + +func TestDelegate(t *testing.T) { + std.TestSetOrigCaller(consts.ADMIN) + SetRunning(true) + + std.TestSetRealm(userRealm) + { + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(100) + gns.MintGns(userRealm.Addr()) // 2M gns + + std.TestSetRealm(userRealm) + to := makeFakeAddress("validator_1") + amount := uint64(1_000_000) + + gns.Approve(consts.GOV_STAKER_ADDR, amount) + Delegate(to, amount) + + minted := xgns.BalanceOf(userRealm.Addr()) + if minted != amount { + t.Errorf("Delegate minted xGNS = %d, want %d", minted, amount) + } + // verify "delegate" avl state, or any event logs, etc. + } + + // 4) below minimum + { + to := makeFakeAddress("validator_2") + amount := uint64(999_999) + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic when delegating below minimumAmount") + } + }() + Delegate(to, amount) + } + + // 5) invalid to address + { + to := invalidAddr + amount := uint64(1_000_000) + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic with invalid 'to' address") + } + }() + Delegate(to, amount) + } + + // 6) not enough GNS balance user + { + // user GNS = 0 now (they delegated all away above) + to := makeFakeAddress("validator_3") + amount := uint64(2_000_000) + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic due to not enough GNS balance in user address") + } + }() + Delegate(to, amount) + } + + // 7) running = false => optional check + { + std.TestSetOrigCaller(consts.ADMIN) + SetRunning(false) + // delegate? + // depending on code logic, might skip cleanDelegationStatHistory() or do something else + // if there's logic that forbids delegation when not running, test that + // For now, assume it still works but doesn't clean. + // Restore running to true for subsequent tests. + std.TestSetOrigCaller(consts.ADMIN) + SetRunning(true) + } + + { + to := makeFakeAddress("validator_2") + amount := uint64(1999_999) + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic when delegating below minimumAmount") + } + }() + Delegate(to, amount) + } + + { + to := makeFakeAddress("validator_2") + amount := uint64(2000_000) + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic when delegating below minimumAmount") + } + }() + Delegate(to, amount) + } +} + +func TestDelegate_Boundary_Values(t *testing.T) { + tests := []struct { + name string + to std.Address + amount uint64 + expectPanic bool + panicMsg string + }{ + { + name: "delegate zero amount", + to: makeFakeAddress("validator_1"), + amount: 0, + expectPanic: true, + panicMsg: "delegation amount cannot be 0", + }, + { + name: "delegate max uint64", + to: makeFakeAddress("validator_1"), + amount: ^uint64(0), // max uint64 + expectPanic: true, + panicMsg: "[GNOSWAP-GOV_STAKER-004] invalid amount || amount must be multiple of 1000000", + }, + { + name: "delegate near max uint64", + to: makeFakeAddress("validator_1"), + amount: ^uint64(0) - 1000, + expectPanic: true, + panicMsg: "[GNOSWAP-GOV_STAKER-004] invalid amount || amount must be multiple of 1000000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetOrigCaller(consts.ADMIN) + SetRunning(true) + std.TestSetRealm(userRealm) + + if tt.expectPanic { + defer func() { + r := recover() + if r == nil { + t.Errorf("Expected panic but got none") + } + if tt.panicMsg != "" && r != nil { + if msg, ok := r.(string); !ok || !strings.Contains(msg, tt.panicMsg) { + t.Errorf("Expected panic message containing '%s', got '%v'", tt.panicMsg, r) + } + } + }() + } + + Delegate(tt.to, tt.amount) + }) + } +} + +/* +func TestEmptyRewardPool(t *testing.T) { + tests := []struct { + name string + setupFn func() + expectPanic bool + expectPanicMsg string + checkFn func(t *testing.T) + }{ + { + name: "collect with empty reward pool", + setupFn: func() { + //userEmissionReward.Remove(userRealm.Addr().String()) + }, + expectPanic: false, + checkFn: func(t *testing.T) { + if v, exists := userEmissionReward.Get(userRealm.Addr().String()); exists { + t.Errorf("Expected userEmissionReward to be removed, but got %d", v.(uint64)) + } + }, + }, + { + name: "collect with empty protocol fee rewards", + setupFn: func() { + //userProtocolFeeReward.Remove(userRealm.Addr().String()) + }, + expectPanic: false, + checkFn: func(t *testing.T) { + if v, exists := userProtocolFeeReward.Get(userRealm.Addr().String()); exists { + t.Errorf("Expected userProtocolFeeReward to be removed, but got %d", v.(uint64)) + } + }, + }, + { + name: "collect with empty staker GNS balance", + setupFn: func() { + std.TestSetRealm(std.NewCodeRealm(consts.GOV_STAKER_PATH)) + userEmissionReward.Set(userRealm.Addr().String(), uint64(1000)) + }, + expectPanic: true, + expectPanicMsg: "not enough GNS", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == "collect with empty staker GNS balance" { + t.Skip("run this test separately") + } + std.TestSetRealm(userRealm) + + if tt.setupFn != nil { + tt.setupFn() + } + + if tt.expectPanic { + defer func() { + r := recover() + if r == nil { + t.Errorf("Expected panic but got none") + } + if tt.expectPanicMsg != "" && r != nil { + if msg, ok := r.(string); !ok || !strings.Contains(msg, tt.expectPanicMsg) { + t.Errorf("Expected panic message containing '%s', got '%v'", tt.expectPanicMsg, r) + } + } + }() + } + + CollectReward() + + if tt.checkFn != nil { + tt.checkFn(t) + } + }) + } +} +*/ +/* +func TestProtocolFee(t *testing.T) { + tests := []struct { + name string + setupFn func() + expectPanic bool + panicMsg string + checkFn func(t *testing.T) + }{ + { + name: "collect max uint64 protocol fee", + setupFn: func() { + //tree := avl.NewTree() + //tree.Set(consts.WUGNOT_PATH, uint64(^uint64(0))) + //userProtocolFeeReward.Set(userRealm.Addr().String(), tree) + currentProtocolFeeBalance[consts.WUGNOT_PATH] = ^uint64(0) + }, + expectPanic: true, + panicMsg: "insufficient balance", + }, + { + name: "collect multiple empty token balances", + setupFn: func() { + //tree := avl.NewTree() + //tree.Set(consts.WUGNOT_PATH, uint64(1000)) + //tree.Set("some/other/token", uint64(2000)) + //userProtocolFeeReward.Set(userRealm.Addr().String(), tree) + currentProtocolFeeBalance[consts.WUGNOT_PATH] = 1000 + currentProtocolFeeBalance["some/other/token"] = 2000 + }, + expectPanic: true, + panicMsg: "insufficient balance", + }, + { + name: "collect with zero amounts", + setupFn: func() { + //tree := avl.NewTree() + //tree.Set(consts.WUGNOT_PATH, uint64(0)) + //tree.Set("some/other/token", uint64(0)) + //userProtocolFeeReward.Set(userRealm.Addr().String(), tree) + currentProtocolFeeBalance[consts.WUGNOT_PATH] = 0 + currentProtocolFeeBalance["some/other/token"] = 0 + }, + expectPanic: false, + checkFn: func(t *testing.T) { + val, exists := userProtocolFeeReward.Get(userRealm.Addr().String()) + if !exists { + t.Errorf("Expected protocol fee reward entry to exist") + return + } + tree := val.(*avl.Tree) + wugnotAmt, _ := tree.Get(consts.WUGNOT_PATH) + if wugnotAmt.(uint64) != 0 { + t.Errorf("Expected WUGNOT amount to be 0, got %d", wugnotAmt) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetRealm(userRealm) + + if tt.setupFn != nil { + tt.setupFn() + } + + if tt.expectPanic { + defer func() { + r := recover() + if r == nil { + t.Errorf("Expected panic but got none") + } + if tt.panicMsg != "" && r != nil { + if msg, ok := r.(string); !ok || !strings.Contains(strings.ToLower(msg), strings.ToLower(tt.panicMsg)) { + t.Errorf("Expected panic message containing '%s', got '%v'", tt.panicMsg, r) + } + } + }() + } + + CollectReward() + + if tt.checkFn != nil { + tt.checkFn(t) + } + }) + } +} +*/ +func TestRedelegate(t *testing.T) { + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(100) + gns.MintGns(userRealm.Addr()) // 2M gns + + // user has xGNS from previous test (some minted) + // try re-delegating + std.TestSetRealm(userRealm) + from := userRealm.Addr() + to := makeFakeAddress("validator_1") + amount := uint64(1_000_000) + + std.TestSetOrigCaller(userRealm.Addr()) + gns.Approve(consts.GOV_STAKER_ADDR, amount) + Delegate(to, amount) + + std.TestSetOrigCaller(userRealm.Addr()) + gns.Approve(consts.GOV_STAKER_ADDR, amount) + Redelegate(from, to, amount) + + // check data: user xGNS must remain the same, but from-> to shift in staker structure? + + // not enough xGNS => panic + { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic if user tries to re-delegate more than xGNS balance") + } + }() + Redelegate(from, to, 999_999_999) + } +} + +func TestUndelegate(t *testing.T) { + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(101) + gns.MintGns(user3Realm.Addr()) // 2M gns + + std.TestSetRealm(user3Realm) + to := makeFakeAddress("validator_1") + amount := uint64(1_000_000) + + std.TestSetOrigCaller(user3Realm.Addr()) + gns.Approve(consts.GOV_STAKER_ADDR, amount) + Delegate(to, amount) + Undelegate(to, amount) + + if xgns.BalanceOf(user3Realm.Addr()) == 1_000_000 { + t.Errorf("Expected user xGNS to be 0 after undelegating 1_000_000, got %d", + xgns.BalanceOf(user3Realm.Addr())) + } + + lockedList, exist := addrLockedGns.Get(user3Realm.Addr().String()) + if !exist { + t.Errorf("Expected lockedGNS to be created after Undelegate") + } + locked := lockedList.([]lockedGNS) + if len(locked) == 0 { + t.Errorf("Expected at least 1 lockedGNS after Undelegate") + } + if locked[0].amount != 1_000_000 { + t.Errorf("LockedGNS amount mismatch, got %d, want 1_000_000", locked[0].amount) + } + // check lockedAmount incremented? + if lockedAmount != 1_000_000 { + t.Errorf("lockedAmount expected 1_000_000, got %d", lockedAmount) + } + + // below minimum => panic + { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic when undelegating below minimum") + } + }() + Undelegate(to, 999_999) // or 999_990, etc. if min = 1_000_000 + } +} + +func TestCollectUndelegatedGns(t *testing.T) { + std.TestSetRealm(userRealm) + + // 1) no locked => expect 0 + addrLockedGns.Remove(userRealm.Addr().String()) // ensure no locked + collected := CollectUndelegatedGns() + if collected != 0 { + t.Errorf("Expected 0 when no locked gns, got %d", collected) + } + + // 2) add locked but time not passed => 0 + now := uint64(time.Now().Unix()) + locked := lockedGNS{ + amount: 100_000, + unlock: now + TIMESTAMP_7_DAYS, + } + addrLockedGns.Set(userRealm.Addr().String(), []lockedGNS{locked}) + lockedAmount = 100_000 + + collected = CollectUndelegatedGns() + if collected != 0 { + t.Errorf("Expected 0 if 7days not passed, got %d", collected) + } + // verify still in locked + lockedList, exist := addrLockedGns.Get(userRealm.Addr().String()) + if !exist { + t.Errorf("Expected lockedGNS to remain after CollectUndelegatedGns") + } + lkList := lockedList.([]lockedGNS) + if len(lkList) != 1 { + t.Errorf("Locked list should remain, but length = %d", len(lkList)) + } + + // 3) set time => unlocked + locked.unlock = uint64(time.Now().Unix()) - 1 // forcibly make it past time + addrLockedGns.Set(userRealm.Addr().String(), []lockedGNS{locked}) + + collected = CollectUndelegatedGns() + if collected != 100_000 { + t.Errorf("Expected 100_000 collected, got %d", collected) + } + // check locked removed from the tree + if _, exists := addrLockedGns.Get(userRealm.Addr().String()); exists { + t.Errorf("Expected addrLockedGns key to be removed if empty after collecting all locked gns") + } + if lockedAmount != 0 { + t.Errorf("lockedAmount should have been decremented to 0, got %d", lockedAmount) + } +} + +func TestCollectReward(t *testing.T) { + t.Skip("minting ugnot is not working") + + std.TestSetRealm(user2Realm) + { + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(100) + gns.MintGns(consts.GOV_STAKER_ADDR) + + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + ugnotFaucet(t, consts.GOV_STAKER_ADDR, 1_000_000) + ugnotDeposit(t, consts.GOV_STAKER_ADDR, 1_000_000) + ugnotFaucet(t, derivePkgAddr(wugnotPath), 1_000_000) + ugnotFaucet(t, user2Realm.Addr(), 1_000_000) + } + + std.TestSetRealm(user2Realm) + user := user2Realm.Addr().String() + + rewardState.AddStake(uint64(std.GetHeight()), std.Address(user), 10, 0, make(map[string]uint64)) + + // set a fake emission reward + //userEmissionReward.Set(user, uint64(50_000)) + currentGNSBalance = 50_000 + // set a fake protocol fee reward + //tree := avl.NewTree() + //tree.Set(consts.WUGNOT_PATH, uint64(10_000)) + //userProtocolFeeReward.Set(user, tree) + currentProtocolFeeBalance[consts.WUGNOT_PATH] = 10_000 + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(100) + gns.MintGns(user2Realm.Addr()) + + std.TestSetRealm(user2Realm) + std.TestSkipHeights(100) + + // call CollectReward + std.TestSetOrigCaller(user2Realm.Addr()) + CollectReward() + + // expect user emissionReward = 0 + claimableAmount, protocolFeeClaimable := rewardState.CalculateReward(std.Address(user)) + if claimableAmount != 0 { + t.Errorf("Expected userEmissionReward to be 0 after collect, got %d", claimableAmount) + } + if len(protocolFeeClaimable) != 1 { + t.Errorf("Expected size=1, but let's check the actual value = 0?") + } + for _, amount := range protocolFeeClaimable { + if amount != 0 { + t.Errorf("Expected protocolFeeClaimable to be 0 after collect, got %d", amount) + } + } + + // If GOV_STAKER_ADDR had less GNS than 50_000 => we expect panic + // can test in separate subcase +} + +func TestCollectRewardFromLaunchPad(t *testing.T) { + // set realm to LAUNCHPAD_ADDR? + // or we do a quick scenario: if current caller != LAUNCHPAD_ADDR => panic + // => check it with a user realm to ensure panic + std.TestSetRealm(userRealm) + { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic, because caller is not launchpad") + } + }() + CollectRewardFromLaunchPad(userRealm.Addr()) + } + + // Then set realm to a custom "launchpadRealm" whose .Addr() == consts.LAUNCHPAD_ADDR + // => call normal and see if distribution works +} diff --git a/contract/r/gnoswap/gov/staker/util.gno b/contract/r/gnoswap/gov/staker/util.gno new file mode 100644 index 000000000..1e9fb2235 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/util.gno @@ -0,0 +1,127 @@ +package staker + +import ( + b64 "encoding/base64" + "std" + "strconv" + + "gno.land/p/demo/avl" + "gno.land/p/demo/json" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +// marshal data to json string +func marshal(data *json.Node) string { + b, err := json.Marshal(data) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +// b64Encode encodes data to base64 string +func b64Encode(data string) string { + return string(b64.StdEncoding.EncodeToString([]byte(data))) +} + +// getPrevRealm returns the previous realm's package path +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// isUserCall returns true if the previous realm is a user +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// getPrev returns the previous realm's address and package path +func getPrev() (string, string) { + prev := getPrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// derivePkgAddr derives the Realm address from it's pkgPath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) +} + +// formatUint formats a uint64 to a string +func formatUint(v uint64) string { + return strconv.FormatUint(v, 10) +} + +// formatInt formats an int64 to a string +func formatInt(v int64) string { + return strconv.FormatInt(v, 10) +} + +// assertCallerIsAdmin panics if the caller is not an admin +func assertCallerIsAdmin() { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } +} + +// assertCallerIsLaunchpad panics if the caller is not the launchpad +func assertCallerIsLaunchpad() { + caller := std.PrevRealm().Addr() + if caller != consts.LAUNCHPAD_ADDR { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only launchpad can call CollectRewardFromLaunchPad(), called from %s", caller.String()), + )) + } +} + +// getUint64FromTree returns the uint64 value from the tree +func getUint64FromTree(tree *avl.Tree, key string) uint64 { + value, exists := tree.Get(key) + if !exists { + return 0 + } + + return value.(uint64) +} + +// updateUint64InTree updates the uint64 value in the tree +func updateUint64InTree(tree *avl.Tree, key string, delta uint64, add bool) uint64 { + current := getUint64FromTree(tree, key) + var newValue uint64 + if add { + newValue = current + delta + } else { + if current < delta { + panic(addDetailToError( + errNotEnoughBalance, + ufmt.Sprintf("not enough balance: current(%d) < requested(%d)", current, delta), + )) + } + newValue = current - delta + } + + tree.Set(key, newValue) + + return newValue +} + +// getOrCreateInnerTree returns the inner tree for the given key +func getOrCreateInnerTree(tree *avl.Tree, key string) *avl.Tree { + value, exists := tree.Get(key) + if !exists { + innerTree := avl.NewTree() + tree.Set(key, innerTree) + return innerTree + } + + return value.(*avl.Tree) +} diff --git a/contract/r/gnoswap/gov/staker/util_test.gno b/contract/r/gnoswap/gov/staker/util_test.gno new file mode 100644 index 000000000..6bc2a09f6 --- /dev/null +++ b/contract/r/gnoswap/gov/staker/util_test.gno @@ -0,0 +1,241 @@ +package staker + +import ( + b64 "encoding/base64" + "math" + "std" + "testing" + "time" + + "gno.land/p/demo/json" + "gno.land/p/demo/uassert" + + "gno.land/p/demo/avl" + "gno.land/p/gnoswap/consts" +) + +func TestMarshal(t *testing.T) { + fakeNode := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + got := marshal(fakeNode) + if !(len(got) > 0) { + t.Errorf("TestMarshal() failed, got empty string") + } + t.Run("TestMarshal", func(t *testing.T) { + uassert.Equal(t, "{\"height\":123,\"timestamp\":1234567890}", got) + }) +} + +func TestB64Encode(t *testing.T) { + input := "Hello World" + want := b64.StdEncoding.EncodeToString([]byte(input)) + got := b64Encode(input) + if got != want { + t.Errorf("TestB64Encode() = %s; want %s", got, want) + } +} + +func TestGetPrevRealm(t *testing.T) { + got := getPrevRealm() + if got.PkgPath() != "" { + t.Errorf("TestPrevRealm() got package path") + } +} + +func TestIsUserCall(t *testing.T) { + tests := []struct { + name string + action func() bool + expected bool + }{ + { + name: "called from user", + action: func() bool { + userRealm := std.NewUserRealm(std.Address("user")) + std.TestSetRealm(userRealm) + return isUserCall() + }, + expected: true, + }, + { + name: "called from realm", + action: func() bool { + fromRealm := std.NewCodeRealm("gno.land/r/realm") + std.TestSetRealm(fromRealm) + return isUserCall() + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uassert.Equal(t, tt.expected, tt.action()) + }) + } +} + +func TestGetPrev(t *testing.T) { + tests := []struct { + name string + action func() (string, string) + expectedAddr string + expectedPkgPath string + }{ + { + name: "user call", + action: func() (string, string) { + userRealm := std.NewUserRealm(std.Address("user")) + std.TestSetRealm(userRealm) + return getPrev() + }, + expectedAddr: "user", + expectedPkgPath: "", + }, + { + name: "code call", + action: func() (string, string) { + codeRealm := std.NewCodeRealm("gno.land/r/demo/realm") + std.TestSetRealm(codeRealm) + return getPrev() + }, + expectedAddr: std.DerivePkgAddr("gno.land/r/demo/realm").String(), + expectedPkgPath: "gno.land/r/demo/realm", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, pkgPath := tt.action() + uassert.Equal(t, tt.expectedAddr, addr) + uassert.Equal(t, tt.expectedPkgPath, pkgPath) + }) + } +} + +func TestFormatUint(t *testing.T) { + tests := []struct { + input uint64 + expected string + }{ + {0, "0"}, + {12345, "12345"}, + {math.MaxUint64, "18446744073709551615"}, + } + + for _, tt := range tests { + result := formatUint(tt.input) + if result != tt.expected { + t.Errorf("formatUint(%d) = %s; want %s", tt.input, result, tt.expected) + } + } +} + +func TestFormatInt(t *testing.T) { + tests := []struct { + input int64 + expected string + }{ + {0, "0"}, + {-12345, "-12345"}, + {math.MaxInt64, "9223372036854775807"}, + {math.MinInt64, "-9223372036854775808"}, + } + + for _, tt := range tests { + result := formatInt(tt.input) + if result != tt.expected { + t.Errorf("formatInt(%d) = %s; want %s", tt.input, result, tt.expected) + } + } +} + +func TestAssertCallerIsAdmin(t *testing.T) { + adminAddr := consts.ADMIN + nonAdminAddr := std.Address("user456") + + std.TestSetOrigCaller(adminAddr) + assertCallerIsAdmin() + + // Test non-admin + std.TestSetOrigCaller(nonAdminAddr) + defer func() { + if r := recover(); r == nil { + t.Errorf("assertCallerIsAdmin() did not panic for non-admin") + } + }() + assertCallerIsAdmin() +} + +func TestAssertCallerIsLaunchpad(t *testing.T) { + launchpadAddr := consts.LAUNCHPAD_ADDR + otherAddr := std.Address("other123") + + // Test valid launchpad + std.TestSetOrigCaller(launchpadAddr) + assertCallerIsLaunchpad() + + // Test invalid caller + std.TestSetOrigCaller(otherAddr) + defer func() { + if r := recover(); r == nil { + t.Errorf("assertCallerIsLaunchpad() did not panic for non-launchpad caller") + } + }() + assertCallerIsLaunchpad() +} + +func TestGetUint64FromTree(t *testing.T) { + tree := avl.NewTree() + tree.Set("key1", uint64(100)) + + if value := getUint64FromTree(tree, "key1"); value != 100 { + t.Errorf("getUint64FromTree(tree, 'key1') = %d; want 100", value) + } + + if value := getUint64FromTree(tree, "key2"); value != 0 { + t.Errorf("getUint64FromTree(tree, 'key2') = %d; want 0", value) + } +} + +func TestUpdateUint64InTree(t *testing.T) { + tree := avl.NewTree() + tree.Set("key1", uint64(100)) + + // Add value + if newValue := updateUint64InTree(tree, "key1", 50, true); newValue != 150 { + t.Errorf("updateUint64InTree add failed; got %d, want 150", newValue) + } + + // Subtract value + if newValue := updateUint64InTree(tree, "key1", 50, false); newValue != 100 { + t.Errorf("updateUint64InTree subtract failed; got %d, want 100", newValue) + } + + // Attempt to subtract too much + defer func() { + if r := recover(); r == nil { + t.Errorf("updateUint64InTree did not panic on insufficient balance") + } + }() + updateUint64InTree(tree, "key1", 150, false) +} + +func TestGetOrCreateInnerTree(t *testing.T) { + tree := avl.NewTree() + + // Test creation of new inner tree + innerTree := getOrCreateInnerTree(tree, "key1") + if innerTree == nil { + t.Errorf("getOrCreateInnerTree did not create a new inner tree") + } + + // Test retrieval of existing inner tree + retrievedTree := getOrCreateInnerTree(tree, "key1") + if retrievedTree != innerTree { + t.Errorf("getOrCreateInnerTree did not return the existing inner tree") + } +} diff --git a/contract/r/gnoswap/gov/xgns/errors.gno b/contract/r/gnoswap/gov/xgns/errors.gno new file mode 100644 index 000000000..d9c245655 --- /dev/null +++ b/contract/r/gnoswap/gov/xgns/errors.gno @@ -0,0 +1,22 @@ +package xgns + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-XGNS-001] caller has no permission") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} + +func checkErr(err error) { + if err != nil { + panic(err.Error()) + } +} diff --git a/contract/r/gnoswap/gov/xgns/gno.mod b/contract/r/gnoswap/gov/xgns/gno.mod new file mode 100644 index 000000000..ab5d422a3 --- /dev/null +++ b/contract/r/gnoswap/gov/xgns/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/gov/xgns diff --git a/contract/r/gnoswap/gov/xgns/xgns.gno b/contract/r/gnoswap/gov/xgns/xgns.gno new file mode 100644 index 000000000..3ea385ac6 --- /dev/null +++ b/contract/r/gnoswap/gov/xgns/xgns.gno @@ -0,0 +1,188 @@ +package xgns + +import ( + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +var ( + token *grc20.Token + ledger *grc20.PrivateLedger + admin *ownable.Ownable +) + +func init() { + admin = ownable.NewWithAddress(std.DerivePkgAddr(consts.GOV_STAKER_PATH)) + + token, ledger = grc20.NewToken("XGNS", "xGNS", 6) +} + +func TotalSupply() uint64 { return token.TotalSupply() } + +// VotingSupply calculates the total supply of tokens eligible for voting. +// +// This function determines the total voting supply by subtracting the amount +// of tokens held by the launchpad contract from the total minted token supply. +// Tokens held by the launchpad contract do not participate in voting. +// +// Returns: +// - uint64: The total supply of tokens available for voting. +// +// Notes: +// - `TotalSupply`: Represents the total amount of xGNS tokens minted. +// - `BalanceOf(consts.LAUNCHPAD_ADDR)`: Retrieves the amount of xGNS tokens held by the launchpad contract. +func VotingSupply() uint64 { + total := token.TotalSupply() // this is entire amount of xGNS minted + + // this is amount of xGNS held by launchpad + // this xGNS doesn't participate in voting + launchpad := token.BalanceOf(consts.LAUNCHPAD_ADDR) + return total - launchpad +} + +// BalanceOf retrieves the token balance of a specified address. +// +// This function resolves the provided address or name and queries the token balance +// associated with the resolved address. +// +// Parameters: +// - owner: The address or name of the user whose balance is being queried. +// +// Returns: +// - uint64: The current token balance of the specified address. +func BalanceOf(owner std.Address) uint64 { + return token.BalanceOf(owner) +} + +// xGNS is non-transferable +// Therefore it doesn't have transfer and transferFrom functions +func Render(path string) string { + parts := strings.Split(path, "/") + c := len(parts) + + switch { + case path == "": + return token.RenderHome() + case c == 2 && parts[0] == "balance": + balance := token.BalanceOf(std.Address(parts[1])) + return ufmt.Sprintf("%d\n", balance) + default: + return "404\n" + } +} + +// Mint increases the balance of a specified address by a given amount. +// +// This function is restricted to be called only by specific authorized contracts: +// - Governance staker contract +// +// If the caller is not one of these contracts, the function will panic with an error. +// +// Parameters: +// - to: The address or name of the user whose balance will be increased. +// - amount: The amount of tokens to be minted. +// +// Errors: +// - Panics if the caller is unauthorized. +// - Propagates any error from the ledger.Mint function. +func Mint(to std.Address, amount uint64) { + common.IsHalted() + + // only (gov staker) or (launchpad) contract can call Mint + caller := std.PrevRealm().Addr() + if caller != consts.GOV_STAKER_ADDR && caller != consts.LAUNCHPAD_ADDR { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only gov/staker(%s) or launchpad(%s) contract can call Mint, called from %s", consts.GOV_STAKER_ADDR.String(), consts.LAUNCHPAD_ADDR.String(), caller.String()), + )) + } + + checkErr(ledger.Mint(to, amount)) +} + +// MintByLaunchPad increases the balance of a specified address by a given amount. +// +// This function is restricted to be called only by specific authorized contracts: +// - Launchpad contract +// +// If the caller is not one of these contracts, the function will panic with an error. +// +// Parameters: +// - to: The address or name of the user whose balance will be increased. +// - amount: The amount of tokens to be minted. +// +// Errors: +// - Panics if the caller is unauthorized. +// - Propagates any error from the ledger.Mint function. +func MintByLaunchPad(to std.Address, amount uint64) { + common.IsHalted() + + caller := std.PrevRealm().Addr() + if caller != consts.LAUNCHPAD_ADDR { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only launchpad(%s) contract can call MintByLaunchPad, called from %s", consts.LAUNCHPAD_ADDR.String(), caller.String()), + )) + } + + mint(to, amount) +} + +// mint increases the balance of a specified address by a given amount. +func mint(to std.Address, amount uint64) { + checkErr(ledger.Mint(to, amount)) +} + +// Burn reduces the balance of a specified address by a given amount. +// +// This function is restricted to be called only by specific authorized contracts: +// - Governance staker contract +// - Launchpad contract +// +// If the caller is not one of these contracts, the function will panic with an error. +// +// Parameters: +// - from: The address or name of the user whose balance will be reduced. +// - amount: The amount of tokens to be burned. +// +// Errors: +// - Panics if the caller is unauthorized. +// - Propagates any error from the ledger.Burn function. +func Burn(from std.Address, amount uint64) { + common.IsHalted() + + // only (gov staker) or (launchpad) contract can call Mint + caller := std.PrevRealm().Addr() + if !(caller == consts.GOV_STAKER_ADDR || caller == consts.LAUNCHPAD_ADDR) { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only gov/staker(%s) or launchpad(%s) contract can call Burn, called from %s", consts.GOV_STAKER_ADDR.String(), consts.LAUNCHPAD_ADDR.String(), caller.String()), + )) + } + + burn(from, amount) +} + +func BurnByLaunchPad(from std.Address, amount uint64) { + common.IsHalted() + + caller := std.PrevRealm().Addr() + if caller != consts.LAUNCHPAD_ADDR { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only launchpad(%s) contract can call BurnByLaunchPad, called from %s", consts.LAUNCHPAD_ADDR.String(), caller.String()), + )) + } + burn(from, amount) +} + +func burn(from std.Address, amount uint64) { + checkErr(ledger.Burn(from, amount)) +} diff --git a/contract/r/gnoswap/gov/xgns/xgns_test.gno b/contract/r/gnoswap/gov/xgns/xgns_test.gno new file mode 100644 index 000000000..6949f449c --- /dev/null +++ b/contract/r/gnoswap/gov/xgns/xgns_test.gno @@ -0,0 +1,75 @@ +package xgns + +import ( + "std" + "testing" + + "gno.land/p/gnoswap/consts" + + "gno.land/p/demo/uassert" +) + +func TestTotalSupply(t *testing.T) { + expectedSupply := uint64(0) + actualSupply := TotalSupply() + if actualSupply != expectedSupply { + t.Errorf("TotalSupply() failed. Expected %d, got %d", expectedSupply, actualSupply) + } +} + +func TestVotingSupply(t *testing.T) { + initialSupply := uint64(1000) + launchpadBalance := uint64(200) + + std.TestSetRealm(std.NewCodeRealm(consts.GOV_STAKER_PATH)) + Mint(consts.GOV_STAKER_ADDR, initialSupply-launchpadBalance) + + std.TestSetRealm(std.NewCodeRealm(consts.LAUNCHPAD_PATH)) + MintByLaunchPad(consts.LAUNCHPAD_ADDR, launchpadBalance) + + expectedVotingSupply := initialSupply - launchpadBalance + actualVotingSupply := VotingSupply() + if actualVotingSupply != expectedVotingSupply { + t.Errorf("VotingSupply() failed. Expected %d, got %d", expectedVotingSupply, actualVotingSupply) + } + + expectedBalance := launchpadBalance + actualBalance := BalanceOf(consts.LAUNCHPAD_ADDR) + if actualBalance != expectedBalance { + t.Errorf("BalanceOf() failed. Expected %d, got %d", expectedBalance, actualBalance) + } +} + +func TestMintFail(t *testing.T) { + amount := uint64(100) + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + uassert.PanicsWithMessage(t, "[GNOSWAP-XGNS-001] caller has no permission || only gov/staker(g17e3ykyqk9jmqe2y9wxe9zhep3p7cw56davjqwa) or launchpad(g122mau2lp2rc0scs8d27pkkuys4w54mdy2tuer3) contract can call Mint, called from g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", func() { + Mint(consts.GOV_STAKER_ADDR, amount) + }) + uassert.PanicsWithMessage(t, "[GNOSWAP-XGNS-001] caller has no permission || only launchpad(g122mau2lp2rc0scs8d27pkkuys4w54mdy2tuer3) contract can call MintByLaunchPad, called from g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", func() { + MintByLaunchPad(consts.GOV_STAKER_ADDR, amount) + }) +} + +func TestBurn(t *testing.T) { + burnAmount := uint64(200) + + std.TestSetRealm(std.NewCodeRealm(consts.LAUNCHPAD_PATH)) + BurnByLaunchPad(consts.LAUNCHPAD_ADDR, burnAmount) + expectedBalance := uint64(0) + actualBalance := BalanceOf(consts.LAUNCHPAD_ADDR) + if actualBalance != expectedBalance { + t.Errorf("Burn() failed. Expected %d, got %d", expectedBalance, actualBalance) + } +} + +func TestBurnFail(t *testing.T) { + amount := uint64(100) + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + uassert.PanicsWithMessage(t, "[GNOSWAP-XGNS-001] caller has no permission || only gov/staker(g17e3ykyqk9jmqe2y9wxe9zhep3p7cw56davjqwa) or launchpad(g122mau2lp2rc0scs8d27pkkuys4w54mdy2tuer3) contract can call Burn, called from g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", func() { + Burn(consts.GOV_STAKER_ADDR, amount) + }) + uassert.PanicsWithMessage(t, "[GNOSWAP-XGNS-001] caller has no permission || only launchpad(g122mau2lp2rc0scs8d27pkkuys4w54mdy2tuer3) contract can call BurnByLaunchPad, called from g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", func() { + BurnByLaunchPad(consts.GOV_STAKER_ADDR, amount) + }) +} diff --git a/contract/r/gnoswap/launchpad/_helper_test.gno b/contract/r/gnoswap/launchpad/_helper_test.gno new file mode 100644 index 000000000..cf0045245 --- /dev/null +++ b/contract/r/gnoswap/launchpad/_helper_test.gno @@ -0,0 +1,138 @@ +package launchpad + +import ( + "std" + "testing" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/testutils" + "gno.land/p/gnoswap/consts" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) +) + +// MockXGNSToken implements basic functionality for testing xgns token +type MockXGNSToken struct { + token *grc20.Token + ledger *grc20.PrivateLedger +} + +func NewMockXGNSToken() *MockXGNSToken { + token, ledger := grc20.NewToken("XGNS", "xGNS", 6) + return &MockXGNSToken{ + token: token, + ledger: ledger, + } +} + +func (m *MockXGNSToken) TotalSupply() uint64 { + return m.token.TotalSupply() +} + +func (m *MockXGNSToken) VotingSupply() uint64 { + total := m.token.TotalSupply() + launchpad := m.token.BalanceOf(consts.LAUNCHPAD_ADDR) + return total - launchpad +} + +func (m *MockXGNSToken) BalanceOf(owner std.Address) uint64 { + ownerAddr := owner + return m.token.BalanceOf(ownerAddr) +} + +func (m *MockXGNSToken) Mint(to std.Address, amount uint64) error { + toAddr := to + return m.ledger.Mint(toAddr, amount) +} + +func (m *MockXGNSToken) Burn(from std.Address, amount uint64) error { + fromAddr := from + return m.ledger.Burn(fromAddr, amount) +} + +// Helper functions for tests +func setupXGNSTest(t *testing.T) (*MockXGNSToken, std.Address) { + t.Helper() + mockToken := NewMockXGNSToken() + testAddr := testutils.TestAddress("test") + + // Set up mock token with initial state + return mockToken, testAddr +} + +func xgnsMint(t *testing.T, token *MockXGNSToken, to std.Address, amount uint64) { + t.Helper() + // Set realm to gov/staker or launchpad to have permission + std.TestSetRealm(std.NewCodeRealm(consts.LAUNCHPAD_PATH)) + err := token.Mint(to, amount) + if err != nil { + t.Fatalf("Failed to mint XGNS: %v", err) + } +} + +func xgnsBurn(t *testing.T, token *MockXGNSToken, from std.Address, amount uint64) { + t.Helper() + // Set realm to gov/staker or launchpad to have permission + std.TestSetRealm(std.NewCodeRealm(consts.LAUNCHPAD_PATH)) + err := token.Burn(from, amount) + if err != nil { + t.Fatalf("Failed to burn XGNS: %v", err) + } +} + +func xgnsCheckBalance(t *testing.T, token *MockXGNSToken, owner std.Address, expected uint64) { + t.Helper() + balance := token.BalanceOf(owner) + if balance != expected { + t.Errorf("XGNS balance mismatch for %s: got %d, want %d", owner, balance, expected) + } +} + +func xgnsCheckTotalSupply(t *testing.T, token *MockXGNSToken, expected uint64) { + t.Helper() + supply := token.TotalSupply() + if supply != expected { + t.Errorf("XGNS total supply mismatch: got %d, want %d", supply, expected) + } +} + +func xgnsCheckVotingSupply(t *testing.T, token *MockXGNSToken, expected uint64) { + t.Helper() + supply := token.VotingSupply() + if supply != expected { + t.Errorf("XGNS voting supply mismatch: got %d, want %d", supply, expected) + } +} + +// Example test using helper functions +func TestXGNSHelpers(t *testing.T) { + token, testAddr := setupXGNSTest(t) + testUser := testAddr + + // Test minting + xgnsMint(t, token, testUser, 1000) + xgnsCheckBalance(t, token, testUser, 1000) + xgnsCheckTotalSupply(t, token, 1000) + + // Test burning + xgnsBurn(t, token, testUser, 500) + xgnsCheckBalance(t, token, testUser, 500) + xgnsCheckTotalSupply(t, token, 500) + + // Test Voting Supply + launchpadAddr := consts.LAUNCHPAD_ADDR + xgnsMint(t, token, launchpadAddr, 200) + xgnsCheckVotingSupply(t, token, 500) // Only testUser's balance counts +} + +// Additional helper that combines with existing TokenTeller mock +func createMockXGNSTeller() *MockTokenTeller { + return &MockTokenTeller{ + balance: 1000, + addr: testutils.TestAddress("mock_xgns"), + } +} diff --git a/contract/r/gnoswap/launchpad/api_deposit.gno b/contract/r/gnoswap/launchpad/api_deposit.gno new file mode 100644 index 000000000..d2c6eeda2 --- /dev/null +++ b/contract/r/gnoswap/launchpad/api_deposit.gno @@ -0,0 +1,106 @@ +package launchpad + +import ( + "std" + "strings" + + "gno.land/p/demo/ufmt" +) + +func ApiGetClaimableDepositByAddress(address std.Address) uint64 { + if !address.IsValid() { + return 0 + } + + gnsToUser := uint64(0) + for _, depositId := range depositsByUser[address] { + deposit := deposits[depositId] + + project, exist := projects[deposit.projectId] + if !exist { + continue + } + + tier := getTier(project, deposit.tier) + if checkTierActive(project, tier) { + continue + } + + if deposit.depositCollectHeight != 0 { + continue + } + + rwd, _ := rewardStates.Get(deposit.projectId, deposit.tier) + reward := rwd.Claim(depositId, uint64(std.GetHeight())) + + gnsToUser += reward + } + + return gnsToUser +} + +func ApiGetDepositByDepositId(depositId string) string { + deposit, exist := deposits[depositId] + if !exist { + return "" + } + + builder := MetaBuilder(). + WriteString("depositId", depositId) + DepositBuilder(builder, deposit) + + return marshal(builder.Node()) +} + +func ApiGetDepositFullByDepositId(depositId string) string { + deposit, exist := deposits[depositId] + if !exist { + return "" + } + + project, exist := projects[deposit.projectId] + if !exist { + return "" + } + + var tier Tier + switch deposit.tier { + case "30": + tier = project.tiers[30] + case "90": + tier = project.tiers[90] + case "180": + tier = project.tiers[180] + } + + builder := MetaBuilder(). + WriteString("depositId", depositId) + + // Add project info + ProjectBuilder(builder, project) + + // Add tier info + TierBuilder(builder, "", tier) + + // Add deposit info + DepositBuilder(builder, deposit) + + return marshal(builder.Node()) +} + +func makeConditionsToStr(conditions map[string]Condition) (string, string) { + var tokenPathList string + var amountList string + + for tokenPath, condition := range conditions { + // append with *PAD*, except last one + tokenPathList += tokenPath + PAD_SEP + amountList += ufmt.Sprintf("%d", condition.minAmount) + PAD_SEP + } + + // remove last *PAD* + tokenPathList = strings.TrimSuffix(tokenPathList, PAD_SEP) + amountList = strings.TrimSuffix(amountList, PAD_SEP) + + return tokenPathList, amountList +} diff --git a/contract/r/gnoswap/launchpad/api_project.gno b/contract/r/gnoswap/launchpad/api_project.gno new file mode 100644 index 000000000..36a742f25 --- /dev/null +++ b/contract/r/gnoswap/launchpad/api_project.gno @@ -0,0 +1,62 @@ +package launchpad + +import ( + "gno.land/p/demo/ufmt" +) + +func ApiGetProjectAndTierStatisticsByProjectId(projectId string) string { + project, exist := projects[projectId] + if !exist { + return "" + } + + builder := MetaBuilder(). + WriteString("projectId", projectId) + ProjectBuilder(builder, project) + + return marshal(builder.Node()) +} + +func ApiGetProjectStatisticsByProjectId(projectId string) string { + project, exist := projects[projectId] + if !exist { + return "" + } + + builder := MetaBuilder(). + WriteString("projectId", projectId) + ProjectBuilder(builder, project) + + return marshal(builder.Node()) +} + +func ApiGetProjectStatisticsByProjectTierId(tierId string) string { + projectId, tierStr := splitProjectIdAndTier(tierId) + project, exist := projects[projectId] + if !exist { + return "" + } + + var tier Tier + switch tierStr { + case "30": + tier = project.tiers[30] + case "90": + tier = project.tiers[90] + case "180": + tier = project.tiers[180] + default: + return "" + } + + builder := MetaBuilder(). + WriteString("projectId", projectId). + WriteString("tierId", tierId). + WriteString("tierAmount", ufmt.Sprintf("%d", tier.tierAmount)). + WriteString("tierTotalDepositAmount", ufmt.Sprintf("%d", tier.totalDepositAmount)). + WriteString("tierActualDepositAmount", ufmt.Sprintf("%d", tier.actualDepositAmount)). + WriteString("tierTotalParticipant", ufmt.Sprintf("%d", tier.totalParticipant)). + WriteString("tierActualParticipant", ufmt.Sprintf("%d", tier.actualParticipant)) + + return marshal(builder.Node()) +} diff --git a/contract/r/gnoswap/launchpad/api_reward.gno b/contract/r/gnoswap/launchpad/api_reward.gno new file mode 100644 index 000000000..9ecec390c --- /dev/null +++ b/contract/r/gnoswap/launchpad/api_reward.gno @@ -0,0 +1,66 @@ +package launchpad + +import ( + "std" + + gs "gno.land/r/gnoswap/v1/gov/staker" +) + +// protocol_fee reward for project's recipient +func ApiGetProjectRecipientRewardByProjectId(projectId string) string { + project, exist := projects[projectId] + if !exist { + return "0" + } + + return gs.GetClaimableRewardByAddress(project.recipient) +} + +func ApiGetProjectRecipientRewardByAddress(address std.Address) string { + if !address.IsValid() { + return "0" + } + + return gs.GetClaimableRewardByAddress(address) +} + +// project reward for deposit +func ApiGetDepositRewardByDepositId(depositId string) uint64 { + deposit, exist := deposits[depositId] + if !exist { + return 0 + } + + rewardState, err := rewardStates.Get(deposit.projectId, deposit.tier) + if err != nil { + return 0 + } + + return rewardState.CalculateReward(depositId) +} + +func ApiGetDepositRewardByAddress(address std.Address) uint64 { + if !address.IsValid() { + return 0 + } + + depositIds, exist := depositsByUser[address] + if !exist { + return 0 + } + + totalReward := uint64(0) + for _, depositId := range depositIds { + deposit, exist := deposits[depositId] + if !exist { + continue + } + rewardState, err := rewardStates.Get(deposit.projectId, deposit.tier) + if err != nil { + continue + } + totalReward += rewardState.CalculateReward(depositId) + } + + return totalReward +} diff --git a/contract/r/gnoswap/launchpad/consts.gno b/contract/r/gnoswap/launchpad/consts.gno new file mode 100644 index 000000000..908d432e2 --- /dev/null +++ b/contract/r/gnoswap/launchpad/consts.gno @@ -0,0 +1,24 @@ +package launchpad + +const ( + minimumGnsAmount = 1_000_000 +) + +const ( + TIMESTAMP_180DAYS = uint64(180 * 24 * 60 * 60) // 15552000 + TIMESTAMP_14DAYS = uint64(14 * 24 * 60 * 60) // 1209600 + TIMESTAMP_90DAYS = uint64(90 * 24 * 60 * 60) // 7776000 + TIMESTAMP_7DAYS = uint64(7 * 24 * 60 * 60) // 604800 + TIMESTAMP_30DAYS = uint64(30 * 24 * 60 * 60) // 2592000 + TIMESTAMP_3DAYS = uint64(3 * 24 * 60 * 60) // 259200 + TIMESTAMP_DAY = uint64(24 * 60 * 60) // 86400 + + TIER30 = uint64(30) + TIER90 = uint64(90) + TIER180 = uint64(180) + TOTAL_TIERS = 3 + + PAD_SEP = "*PAD*" + + MINIMUM_DEPOSIT_AMOUNT = uint64(1_000_000) +) diff --git a/contract/r/gnoswap/launchpad/deposit.gno b/contract/r/gnoswap/launchpad/deposit.gno new file mode 100644 index 000000000..0dff13aca --- /dev/null +++ b/contract/r/gnoswap/launchpad/deposit.gno @@ -0,0 +1,918 @@ +package launchpad + +import ( + "errors" + "std" + "strconv" + "time" + + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + en "gno.land/r/gnoswap/v1/emission" + "gno.land/r/gnoswap/v1/gns" + gs "gno.land/r/gnoswap/v1/gov/staker" + "gno.land/r/gnoswap/v1/gov/xgns" +) + +var ( + // depositId -> deposit + deposits = make(map[string]Deposit) + + // project -> tier -> []depositId + depositsByProject = make(map[string]map[string][]string) + + // user -> []depositId + depositsByUser = make(map[std.Address][]string) + + // user -> project -> []depositId + depositsByUserByProject = make(map[std.Address]map[string][]string) + + // depositState contains all deposit-related indices + depositState DepositState +) + +// DepositState represents the state of the deposits +type DepositState struct { + Deposits map[string]Deposit + DepositsByProject map[string]map[string][]string + DepositsByUser map[std.Address][]string + DepositsByUserProject map[std.Address]map[string][]string +} + +// ProjectTierInfo contains information about a project tier +type ProjectTierInfo struct { + Project Project + Tier Tier + TierType string + CurrentHeight uint64 + CurrentTime uint64 +} + +// DepositGns handles the deposit process for a specific project tier. +// +// This function validates the deposit conditions, checks the activation status of the project and tier, +// and creates a new deposit for the caller. It updates the project, tier, and reward states accordingly. +// +// Parameters: +// - targetProjectTierId (string): The ID of the target project and tier in the format "{projectId}:{tierType}". +// - amount (uint64): The amount to deposit. +// +// Returns: +// - string: The ID of the newly created deposit. +// +// Panics: +// - If the system is halted. +// - If the amount is invalid (e.g., zero). +// - If the project or tier is inactive. +// - If the deposit conditions are not satisfied. +// - If an error occurs during deposit creation or reward state update. +// +// Logic: +// 1. **Validation**: +// - Check if the system is halted or the deposit amount is valid. +// - Retrieve the project and validate its activation status. +// - Validate deposit conditions for the caller. +// - Retrieve and validate the tier's activation status. +// +// 2. **Deposit Creation**: +// - Create a deposit using the project and tier information. +// - Update deposit indices and tier details (e.g., `StartHeight`, `TierAmountPerBlockX128`). +// +// 3. **Reward Updates**: +// - Update the tier's reward state, including reward per deposit and participant counts. +// - Update the project's stats (e.g., total deposit amount, participant count). +// +// 4. **Transfer and Lock Tokens**: +// - Transfer the deposit amount to the `launchpad` contract and mint `xGNS` tokens. +// +// 5. **Emit Events**: +// - Emit events for both the deposit and the first deposit for the tier (if applicable). +// +// ref: https://docs.gnoswap.io/contracts/launchpad/launchpad_deposit.gno#depositgns +func DepositGns(targetProjectTierId string, amount uint64) string { + assertOnlyNotHalted() + assertValidAmount(amount) + + projectId, tierType := splitProjectIdAndTier(targetProjectTierId) + project, err := getProject(projectId) + if err != nil { + panic(err.Error()) + } + + currentHeight := uint64(std.GetHeight()) + if !project.isActivated(currentHeight) { + panic(addDetailToError(errInactiveProject, projectId)) + } + caller := getPrevAddr() + project.checkDepositConditions(caller) + + tier, err := project.Tier(convertTierTypeStrToUint64(tierType)) + if err != nil { + panic(err.Error()) + } + + if !tier.isActivated(currentHeight) { + panic(addDetailToError(errInactiveTier, tierType)) + } + + en.MintAndDistributeGns() + + now := uint64(time.Now().Unix()) + info := ProjectTierInfo{ + Project: project, + Tier: tier, + TierType: tierType, + CurrentHeight: currentHeight, + CurrentTime: now, + } + deposit, err := createDeposit(info, amount) + if err != nil { + panic(addDetailToError(errInvalidInput, err.Error())) + } + updateDeposit(deposit) + + if tier.isFirstDeposit() { + // update StartHeight, StartTime, TierAmountPerBlockX128 + tier.updateStarted(info.CurrentHeight, info.CurrentTime) + duration := tier.Ended().height - info.CurrentHeight + tier.SetTierAmountPerBlockX128(calcRewardPerBlockX128(tier.TierAmount(), duration)) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "FirstDepositForProjectTier", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "targetProjectTierId", targetProjectTierId, + "amount", formatUint(amount), + "depositId", deposit.id, + "claimableHeight", formatUint(deposit.claimableHeight), + "claimableTime", formatUint(deposit.claimableTime), + "tierAmountPerBlockX128", tier.TierAmountPerBlockX128().ToString(), + ) + + rewardState := NewRewardState(tier.tierAmountPerBlockX128, tier.started.height, tier.ended.height) + rewardStates.Set(info.Project.id, info.TierType, rewardState) + } + + reward := tier.Reward() + rewardPerDeposit, err := reward.calculateRewardPerDeposit(tier.TierAmountPerBlockX128(), tier.ActualDepositAmount()) + if err != nil { + panic(err.Error()) + } + reward.addRewardPerDeposit(rewardPerDeposit) + reward.SetLastHeight(currentHeight) + + rewardInfo := NewRewardInfo( + reward.AccumRewardPerDeposit(), + deposit.Amount(), + 0, + 0, + deposit.DepositHeight(), + tier.Ended().height, + currentHeight) + reward.Info().Set(deposit.ID(), rewardInfo) + tier.SetReward(reward) + tier.SetTotalDepositAmount(tier.TotalDepositAmount() + amount) + tier.SetActualDepositAmount(tier.ActualDepositAmount() + amount) + tier.SetTotalParticipant(tier.TotalParticipant() + 1) + tier.SetActualParticipant(tier.ActualParticipant() + 1) + + project.setTier(convertTierTypeStrToUint64(tierType), tier) + project.Stats().SetTotalDeposit(project.Stats().TotalDeposit() + amount) + project.Stats().SetActualDeposit(project.Stats().ActualDeposit() + amount) + project.Stats().SetTotalParticipant(project.Stats().TotalParticipant() + 1) + project.Stats().SetActualParticipant(project.Stats().ActualParticipant() + 1) + + projects[projectId] = project + + rewardState, err := rewardStates.Get(projectId, tierType) + if err != nil { + panic(addDetailToError( + errInvalidRewardState, err.Error())) + } + rewardState.AddStake(uint64(std.GetHeight()), deposit.id, amount) + + // Update gov_staker to calculate project's recipient's reward + gs.SetAmountByProjectWallet(project.recipient, amount, true) + + // gns will be locked in `launchpad` contract + gns.TransferFrom( + getPrevAddr(), + consts.LAUNCHPAD_ADDR, + amount, + ) + xgns.Mint(consts.LAUNCHPAD_ADDR, amount) + + // Emit deposit event + prevAddr, prevPkgPath := getPrev() + std.Emit( + "DepositGns", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "targetProjectTierId", targetProjectTierId, + "amount", strconv.FormatUint(amount, 10), + "depositId", deposit.id, + "claimableHeight", strconv.FormatUint(deposit.claimableHeight, 10), + "claimableTime", strconv.FormatUint(deposit.claimableTime, 10), + ) + + return deposit.id +} + +// CollectDepositGns collects rewards from all deposits associated with the caller. +// +// This function retrieves all deposits for the caller, validates their claimable status, +// and processes the collection of rewards. +// +// Returns: +// - uint64: The total amount of rewards collected. +// +// Panics: +// - If the caller has no deposits. +// - If an error occurs during the reward collection process. +func CollectDepositGns() uint64 { + caller := getPrevAddr() + depositList, ok := depositsByUser[caller] + if depositList == nil || !ok || len(depositList) == 0 { + panic(addDetailToError( + errNotExistDeposit, + ufmt.Sprintf("no deposit list for user(%s)", caller), + )) + } + + return processCollectedDeposits(depositList, "") +} + +// CollectDepositGnsByProjectId collects rewards from all deposits associated with a specific project. +// +// This function retrieves all deposits for the caller under the specified project ID, +// validates their claimable status, and processes the collection of rewards. +// +// Parameters: +// - projectId (string): The ID of the project. +// +// Returns: +// - uint64: The total amount of rewards collected for the project. +// +// Panics: +// - If the project ID is invalid. +// - If the caller has no deposits in the specified project. +// - If an error occurs during the reward collection process. +func CollectDepositGnsByProjectId(projectId string) uint64 { + assertValidProjectId(projectId) + + _, err := getProject(projectId) + if err != nil { + panic(err.Error()) + } + + caller := getPrevAddr() + if _, exists := depositsByUserByProject[caller]; !exists { + panic(addDetailToError(errInvalidCaller, ufmt.Sprintf("caller(%s) not found", caller))) + } + depositList, exists := depositsByUserByProject[caller][projectId] + if depositList == nil || !exists { + panic(addDetailToError( + errNotExistDeposit, ufmt.Sprintf("no deposit list for project(%s)", projectId))) + } + + return processCollectedDeposits(depositList, projectId) +} + +// CollectDepositGnsByDepositId collects rewards from a specific deposit by ID. +// +// This function retrieves a specific deposit by its ID, validates its claimable status, +// and processes the collection of rewards. +// +// Parameters: +// - depositId (string): The ID of the deposit to collect rewards from. +// +// Returns: +// - uint64: The amount of rewards collected. +// +// Panics: +// - If the system is halted. +// - If the deposit ID is invalid. +// - If the deposit does not belong to the caller. +// - If the deposit is not yet claimable or has already been collected. +// - If an error occurs during the reward collection process. +func CollectDepositGnsByDepositId(depositId string) uint64 { + assertOnlyNotHalted() + assertValidProjectId(depositId) + + caller := getPrevAddr() + + deposit, project, err := validateDepositCollection(depositId, caller) + if err != nil { + panic(addDetailToError(errInvalidInput, err.Error())) + } + + en.MintAndDistributeGns() + + project = projects[deposit.projectId] + + // Process deposit collection + height := uint64(std.GetHeight()) + amount, err := processDeposit(&deposit, &project, height) + if err != nil { + panic(err.Error()) + } + + // Update gov_staker contract + gs.SetAmountByProjectWallet(project.recipient, amount, false) + + // Emit collection event + prevAddr, prevPkgPath := getPrev() + std.Emit( + "CollectDepositGnsByDepositId", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "depositId", depositId, + "amount", strconv.FormatUint(amount, 10), + ) + + // Process token transfers if amount > 0 + if amount > 0 { + xgns.Burn(consts.LAUNCHPAD_ADDR, amount) + gns.Transfer(caller, amount) + return amount + } + + return 0 +} + +// splitProjectIdAndTier extracts the project ID and tier from a given tier ID. +// +// This function splits the provided `tierId` into three parts using the `:` separator and +// returns the first two parts as the project ID and the third part as the tier. +// +// Parameters: +// - tierId (string): The tier ID in the format "{projectId}:{tier}". +// +// Returns: +// - string: The project ID derived from the first two parts of the `tierId`. +// - string: The tier derived from the third part of the `tierId`. +func splitProjectIdAndTier(tierId string) (string, string) { + res, err := common.Split(tierId, ":", 3) + if err != nil { + panic(addDetailToError( + errInvalidTier, + ufmt.Sprintf("invalid tierId: %s", tierId), + )) + } + + return ufmt.Sprintf("%s:%s", res[0], res[1]), res[2] +} + +// validateProjectTier validates a project and its tier +func validateProjectTier(projectId string, tierStr string) (Project, Tier, error) { + project, exists := projects[projectId] + if !exists { + return Project{}, Tier{}, errors.New(ufmt.Sprintf("project(%s) not found", projectId)) + } + + currentHeight := uint64(std.GetHeight()) + + project.Tiers() + + tier := getTier(project, tierStr) + if !isProjectActive(project, currentHeight) { + return Project{}, Tier{}, errors.New("project is not active") + } + + if !isTierActive(tier, currentHeight) { + return Project{}, Tier{}, errors.New("tier is not active") + } + + return project, tier, nil +} + +func getProjectIdFromTierId(tierId string) string { + // input: gno.land/r/gnoswap/gns:123:30 + // output: gno.land/r/gnoswap/gns:123 + res, err := common.Split(tierId, ":", 3) + if err != nil { + panic(addDetailToError( + errInvalidTier, + ufmt.Sprintf("invalid tierId: %s", tierId), + )) + } + + return ufmt.Sprintf("%s:%s", res[0], res[1]) +} + +// checkDepositConditions validates whether the caller meets the deposit conditions of the given project. +// +// This function checks if the caller satisfies all conditions required for depositing into the specified project. +// Each condition specifies a minimum token balance that must be met. If any condition is not met, the function +// panics with an appropriate error message. +// +// Parameters: +// - project (Project): The project containing the deposit conditions to validate. +func checkDepositConditions(project Project) { + if project.conditions == nil { + return + } + + caller := getPrevAddr() + for _, condition := range project.conditions { + if condition.minAmount == 0 { + continue + } + + // check balance + var balance uint64 + if condition.tokenPath == consts.GOV_XGNS_PATH { + balance = xgns.BalanceOf(caller) + } else { + balance = common.BalanceOf(condition.tokenPath, caller) + } + + if balance < condition.minAmount { + panic(addDetailToError( + errNotEnoughBalance, + ufmt.Sprintf("insufficient balance(%d) for token(%s)", balance, condition.tokenPath), + )) + } + } +} + +// checkTierActive checks whether a specific tier in a project is active. +// +// This function verifies if the tier's end height is greater than or equal to the current block height. +// +// Parameters: +// - project (Project): The project to which the tier belongs. +// - tier (Tier): The tier to check. +// +// Returns: +// - bool: True if the tier is active, otherwise false. +func checkTierActive(project Project, tier Tier) bool { + if tier.ended.height < uint64(std.GetHeight()) { + return false + } + + return true +} + +// convertTierTypeStrToUint64 converts a tier type string to its corresponding uint64 value. +// +// This function maps tier strings ("30", "90", "180") to predefined constants (`TIER30`, `TIER90`, `TIER180`). +// If the input string is invalid, the function panics. +// +// Parameters: +// - tierType (string): The tier type as a string. +// +// Returns: +// - uint64: The corresponding uint64 constant for the tier type. +// +// Panics: +// - If the tierType is invalid, an error is raised. +func convertTierTypeStrToUint64(tierType string) uint64 { + switch tierType { + case "30": + return TIER30 + case "90": + return TIER90 + case "180": + return TIER180 + default: + panic(addDetailToError( + errInvalidTier, + ufmt.Sprintf("invalid tierType: %s", tierType), + )) + } +} + +// getTier retrieves a specific tier from a project based on the tier string. +// +// This function fetches the tier object corresponding to the input string ("30", "90", "180"). +// If the input string is invalid, the function panics. +// +// Parameters: +// - project (Project): The project containing the tiers. +// - tierStr (string): The tier identifier as a string ("30", "90", "180"). +// +// Returns: +// - Tier: The tier object corresponding to the specified tier string. +// +// Panics: +// - If the tierStr is invalid, an error is raised. +func getTier(project Project, tierStr string) Tier { + switch tierStr { + case "30": + return project.tiers[TIER30] + case "90": + return project.tiers[TIER90] + case "180": + return project.tiers[TIER180] + default: + panic(addDetailToError( + errInvalidTier, + ufmt.Sprintf("invalid tierStr: %s", tierStr), + )) + } +} + +// setTier updates the tier information for a specific project. +// +// This function sets the tier data for the specified tier string ("30", "90", "180"). +// If the input string is invalid, the function panics. +// +// Parameters: +// - project (Project): The project to update. +// - tierStr (string): The tier identifier as a string ("30", "90", "180"). +// - tier (Tier): The tier object to set. +// +// Returns: +// - Project: The updated project object with the modified tier. +// +// Panics: +// - If the tierStr is invalid, an error is raised. +func setTier(project Project, tierStr string, tier Tier) Project { + switch tierStr { + case "30": + project.tiers[TIER30] = tier + case "90": + project.tiers[TIER90] = tier + case "180": + project.tiers[TIER180] = tier + default: + panic(addDetailToError( + errInvalidTier, + ufmt.Sprintf("invalid tierStr: %s", tierStr), + )) + } + + return project +} + +// isPassedClaimableHeight checks if a deposit has reached its claimable block height. +// +// Parameters: +// - deposit (Deposit): The deposit to check. +// - height (uint64): The current block height. +// +// Returns: +// - bool: True if the deposit's claimable height has been reached, otherwise false. +func isPassedClaimableHeight(deposit Deposit, height uint64) bool { + return deposit.claimableHeight <= height +} + +// isAlreadyCollected checks if a deposit has already been collected. +// +// Parameters: +// - deposit (Deposit): The deposit to check. +// +// Returns: +// - bool: True if the deposit has already been collected, otherwise false. +func isAlreadyCollected(deposit Deposit) bool { + return deposit.depositCollectHeight != 0 +} + +// isProjectActive checks if a project is currently active based on the block height. +// +// Parameters: +// - project (Project): The project to check. +// - height (uint64): The current block height. +// +// Returns: +// - bool: True if the project is active, otherwise false. +func isProjectActive(project Project, height uint64) bool { + return project.started.height <= height && height <= project.ended.height +} + +// isTierActive checks if a specific tier is currently active based on the block height. +// +// Parameters: +// - tier (Tier): The tier to check. +// - height (uint64): The current block height. +// +// Returns: +// - bool: True if the tier is active, otherwise false. +func isTierActive(tier Tier, height uint64) bool { + return height <= tier.ended.height +} + +// createDeposit creates a new deposit for a specific project tier and depositor. +// +// This function generates a unique deposit ID, calculates the claimable height and time, +// and initializes a `Deposit` object with the provided details. +// +// Parameters: +// - info (ProjectTierInfo): Information about the project tier, including project details, +// tier string, current block height, and current time. +// - amount (uint64): The amount to be deposited. +// +// Returns: +// - Deposit: The newly created `Deposit` object containing all relevant details. +// - error: Returns an error if the deposit creation fails (currently always returns `nil`). +func createDeposit(info ProjectTierInfo, amount uint64) (Deposit, error) { + // Validate input parameters + if amount == 0 { + return Deposit{}, errors.New(ufmt.Sprintf("deposit amount cannot be zero")) + } + if info.Project.id == "" || info.TierType == "" { + return Deposit{}, errors.New(ufmt.Sprintf("invalid project or tier information")) + } + + depositor := getPrevAddr() + depositId := generateDepositId(info, depositor) + claimableHeight, claimableTime := calculateClaimableTimes(info) + + deposit := Deposit{ + id: depositId, + projectId: info.Project.id, + tier: info.TierType, + depositor: depositor, + amount: amount, + depositHeight: info.CurrentHeight, + depositTime: info.CurrentTime, + claimableHeight: claimableHeight, + claimableTime: claimableTime, + } + + return deposit, nil +} + +func generateDepositId(info ProjectTierInfo, depositor std.Address) string { + return ufmt.Sprintf("%s:%s:%s:%d", info.Project.id, info.TierType, depositor.String(), info.CurrentHeight) +} + +// calculateClaimableTimes calculates the claimable height and time for a deposit +func calculateClaimableTimes(info ProjectTierInfo) (uint64, uint64) { + var waitDuration uint64 + switch info.TierType { + case "30": + waitDuration = TIMESTAMP_3DAYS + case "90": + waitDuration = TIMESTAMP_7DAYS + case "180": + waitDuration = TIMESTAMP_14DAYS + } + + calcHeight := info.CurrentHeight + info.Tier.collectWaitDuration + calcTime := info.CurrentTime + waitDuration + + return minU64(calcHeight, info.Tier.ended.height), + minU64(calcTime, info.Tier.ended.time) +} + +func updateDeposit(deposit Deposit) { + // Update deposits + _, ok := deposits[deposit.id] + if !ok { + deposits[deposit.id] = deposit + } else { + panic(addDetailToError( + errDuplicateDeposit, + ufmt.Sprintf("deposit(%s) already exists", deposit.id), + )) + } + + // Update depositsByUser + if depositsByUser[deposit.depositor] == nil { + depositsByUser[deposit.depositor] = []string{} + } + depositsByUser[deposit.depositor] = append(depositsByUser[deposit.depositor], deposit.id) + + // Update depositsByProject + if depositsByProject[deposit.projectId] == nil { + depositsByProject[deposit.projectId] = make(map[string][]string) + } + if depositsByProject[deposit.projectId][deposit.tier] == nil { + depositsByProject[deposit.projectId][deposit.tier] = []string{} + } + depositsByProject[deposit.projectId][deposit.tier] = append( + depositsByProject[deposit.projectId][deposit.tier], + deposit.id, + ) + + // Update depositsByUserByProject + if depositsByUserByProject[deposit.depositor] == nil { + depositsByUserByProject[deposit.depositor] = make(map[string][]string) + } + if depositsByUserByProject[deposit.depositor][deposit.projectId] == nil { + depositsByUserByProject[deposit.depositor][deposit.projectId] = []string{} + } + depositsByUserByProject[deposit.depositor][deposit.projectId] = append( + depositsByUserByProject[deposit.depositor][deposit.projectId], + deposit.id, + ) +} + +// updateDepositIndices updates all deposit-related indices +func updateDepositIndices(deposit Deposit, state *DepositState) { + // Update deposits + _, ok := state.Deposits[deposit.id] + if !ok { + state.Deposits[deposit.id] = deposit + } else { + panic(addDetailToError( + errDuplicateDeposit, + ufmt.Sprintf("deposit(%s) already exists", deposit.id), + )) + } + + // Update depositsByUser + if state.DepositsByUser[deposit.depositor] == nil { + state.DepositsByUser[deposit.depositor] = []string{} + } + state.DepositsByUser[deposit.depositor] = append( + state.DepositsByUser[deposit.depositor], + deposit.id, + ) + + // Update depositsByProject + if state.DepositsByProject[deposit.projectId] == nil { + state.DepositsByProject[deposit.projectId] = make(map[string][]string) + } + if state.DepositsByProject[deposit.projectId][deposit.tier] == nil { + state.DepositsByProject[deposit.projectId][deposit.tier] = []string{} + } + state.DepositsByProject[deposit.projectId][deposit.tier] = append( + state.DepositsByProject[deposit.projectId][deposit.tier], + deposit.id, + ) + + // Update depositsByUserProject + if state.DepositsByUserProject[deposit.depositor] == nil { + state.DepositsByUserProject[deposit.depositor] = make(map[string][]string) + } + if state.DepositsByUserProject[deposit.depositor][deposit.projectId] == nil { + state.DepositsByUserProject[deposit.depositor][deposit.projectId] = []string{} + } + state.DepositsByUserProject[deposit.depositor][deposit.projectId] = append( + state.DepositsByUserProject[deposit.depositor][deposit.projectId], + deposit.id, + ) +} + +// processDepositCollection processes the collection of deposits +func processDepositCollection( + depositList []string, + projectId string, +) (uint64, error) { + totalAmount := uint64(0) // gnsToUser + currentHeight := uint64(std.GetHeight()) + now := uint64(time.Now().Unix()) + prevAddr, prevPkgPath := getPrev() + + for _, depositId := range depositList { + deposit := deposits[depositId] + // Skip if deposit is already collected + if deposit.depositCollectHeight != 0 { + continue + } + + project, exists := projects[deposit.projectId] + if !exists || projectId != "" && project.id != projectId { + continue + } + + if currentHeight < project.Ended().height { + continue + } + + tier, err := project.Tier(convertTierTypeStrToUint64(deposit.tier)) + if err != nil { + panic(err.Error()) + } + if !isPassedClaimableHeight(deposit, currentHeight) || isAlreadyCollected(deposit) { + continue + } + + deposit.SetDepositCollectHeight(currentHeight) + deposit.SetDepositCollectTime(now) + deposits[deposit.id] = deposit + + totalAmount += deposit.amount + + gs.SetAmountByProjectWallet(project.recipient, deposit.amount, false) + + tier.SetActualDepositAmount(tier.ActualDepositAmount() - deposit.amount) + tier.SetActualParticipant(tier.ActualParticipant() - 1) + + project.setTier(convertTierTypeStrToUint64(deposit.tier), tier) + project.Stats().SetActualDeposit(project.Stats().actualDeposit - deposit.amount) + project.Stats().SetActualParticipant(project.Stats().actualParticipant - 1) + + projects[deposit.projectId] = project + + if projectId != "" { + std.Emit( + "CollectDepositGnsByProjectId", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "projectId", projectId, + "depositId", depositId, + "amount", strconv.FormatUint(deposit.amount, 10), + ) + } else { + std.Emit( + "CollectDepositGns", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "depositId", depositId, + "amount", strconv.FormatUint(deposit.amount, 10), + ) + } + } + + if totalAmount == 0 { + return 0, errors.New("no deposit to collect") + } + + return totalAmount, nil +} + +func processCollectedDeposits(depositList []string, pid string) uint64 { + assertOnlyNotHalted() + en.MintAndDistributeGns() + + amount, err := processDepositCollection(depositList, pid) + if err != nil { + panic(err) + } + + if amount > 0 { + caller := getPrevAddr() + xgns.Burn(consts.LAUNCHPAD_ADDR, amount) + gns.Transfer(caller, amount) + } + + return amount +} + +func validateDepositCollection(depositId string, caller std.Address) (Deposit, Project, error) { + deposit, exists := deposits[depositId] + if !exists { + return Deposit{}, Project{}, errors.New("deposit not found") + } + + project, exists := projects[deposit.projectId] + if !exists { + return Deposit{}, Project{}, errors.New("project not found") + } + + projMap, userHasProject := depositsByUserByProject[caller] + if !userHasProject { + return Deposit{}, Project{}, errors.New("user has no deposits in any project") + } + + depositIds, userHasDepositsInProject := projMap[deposit.projectId] + if !userHasDepositsInProject { + return Deposit{}, Project{}, errors.New("user has no deposits in this project") + } + + if !containsString(depositIds, depositId) { + return Deposit{}, Project{}, errors.New("user has no deposit with this id") + } + + currentHeight := uint64(std.GetHeight()) + + if !isPassedClaimableHeight(deposit, currentHeight) { + return Deposit{}, Project{}, errors.New("not passed claimable block height yet") + } + + if isAlreadyCollected(deposit) { + return Deposit{}, Project{}, errors.New("deposit already collected") + } + + return deposit, project, nil +} + +// receive deposit and project as a pointer to avoid copying +func processDeposit(deposit *Deposit, project *Project, height uint64) (uint64, error) { + tier := getTier(*project, deposit.tier) + //if isTierActive(tier, height) { + // return 0, errors.New(addDetailToError(errActiveProject, "tier is still active")) + //} + tier, err := project.Tier(convertTierTypeStrToUint64(deposit.tier)) + if err != nil { + return 0, err + } + + if deposit.depositCollectHeight != 0 { + return 0, errors.New(addDetailToError(errAlreadyCollected, ufmt.Sprintf("depositId (%s)", deposit.id))) + } + + // Update deposit status + deposit.SetDepositCollectHeight(height) + deposit.SetDepositCollectTime(uint64(time.Now().Unix())) + deposits[deposit.id] = *deposit + + // Update project tier + tier.SetActualDepositAmount(tier.ActualDepositAmount() - deposit.amount) + tier.SetActualParticipant(tier.ActualParticipant() - 1) + + // Update project + project.setTier(convertTierTypeStrToUint64(deposit.tier), tier) + project.Stats().SetActualDeposit(project.Stats().actualDeposit - deposit.amount) + project.Stats().SetActualParticipant(project.Stats().actualParticipant - 1) + + projects[deposit.projectId] = *project + + return deposit.amount, nil +} diff --git a/contract/r/gnoswap/launchpad/deposit_test.gno b/contract/r/gnoswap/launchpad/deposit_test.gno new file mode 100644 index 000000000..2e40315c1 --- /dev/null +++ b/contract/r/gnoswap/launchpad/deposit_test.gno @@ -0,0 +1,788 @@ +package launchpad + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/gnoswap/v1/gov/xgns" +) + +// Mock data structs +type MockTokenTeller struct { + balance uint64 + addr std.Address +} + +func (m *MockTokenTeller) TransferFrom(from std.Address, to std.Address, amount uint64) { + m.balance = amount + m.addr = to +} + +func (m *MockTokenTeller) BalanceOf(addr std.Address) uint64 { + return m.balance +} + +func setupTestDeposit(t *testing.T) (*MockTokenTeller, std.Address) { + t.Helper() + projects = make(map[string]Project) + deposits = make(map[string]Deposit) + depositsByProject = make(map[string]map[string][]string) + depositsByUser = make(map[std.Address][]string) + depositsByUserByProject = make(map[std.Address]map[string][]string) + + mockTeller := &MockTokenTeller{balance: 1000} + testAddr := testutils.TestAddress("test") + + return mockTeller, testAddr +} + +func createTestProject(t *testing.T) Project { + now := uint64(time.Now().Unix()) + height := uint64(std.GetHeight()) + + tiers := make(map[uint64]Tier) + tiersRatios := make(map[uint64]uint64) + + tiers[30] = Tier{ + id: ufmt.Sprintf("gno.land/r/gnoswap/v1/gns:%d:30", height), + collectWaitDuration: TIMESTAMP_3DAYS * 1000 / 1000, + tierAmount: 300, + started: TimeInfo{ + height: height, + time: now, + }, + ended: TimeInfo{ + height: height + (30 * 24 * 60 * 60), + time: now + (30 * 24 * 60 * 60), + }, + reward: *NewReward(u256.Zero(), 0, 0), + } + tiersRatios[30] = 30 + + tiers[90] = Tier{ + id: ufmt.Sprintf("gno.land/r/gnoswap/v1/gns:%d:90", height), + collectWaitDuration: TIMESTAMP_7DAYS * 1000 / 1000, + tierAmount: 300, + started: TimeInfo{ + height: height, + time: now, + }, + ended: TimeInfo{ + height: height + (90 * 24 * 60 * 60), + time: now + (90 * 24 * 60 * 60), + }, + reward: *NewReward(u256.Zero(), 0, 0), + } + tiersRatios[90] = 30 + + tiers[180] = Tier{ + id: ufmt.Sprintf("gno.land/r/gnoswap/v1/gns:%d:180", height), + collectWaitDuration: TIMESTAMP_14DAYS * 1000 / 1000, + tierAmount: 400, + started: TimeInfo{ + height: height, + time: now, + }, + ended: TimeInfo{ + height: height + (180 * 24 * 60 * 60), + time: now + (180 * 24 * 60 * 60), + }, + reward: *NewReward(u256.Zero(), 0, 0), + } + tiersRatios[180] = 40 + + return Project{ + id: ufmt.Sprintf("gno.land/r/gnoswap/v1/gns:%d", height), + name: "Test Project", + tokenPath: "gno.land/r/gnoswap/v1/gns", + depositAmount: 1000, + recipient: testutils.TestAddress("recipient"), + conditions: make(map[string]Condition), + tiers: tiers, + tiersRatios: tiersRatios, + created: TimeInfo{ + height: height, + time: now, + }, + started: TimeInfo{ + height: height, + time: now, + }, + ended: TimeInfo{ + height: height + (180 * 24 * 60 * 60), + time: now + (180 * 24 * 60 * 60), + }, + stats: ProjectStats{ + totalDeposit: 0, + actualDeposit: 0, + totalParticipant: 0, + actualParticipant: 0, + totalCollected: 0, + }, + refund: RefundInfo{ + amount: 0, + height: 0, + time: 0, + }, + } +} + +func TestValidateProjectTier(t *testing.T) { + mockTeller, _ := setupTestDeposit(t) + project := createTestProject(t) + projects[project.id] = project + + tests := []struct { + name string + projectId string + tierStr string + shouldError bool + }{ + { + name: "Valid project and tier", + projectId: project.id, + tierStr: "30", + shouldError: false, + }, + { + name: "Invalid project id", + projectId: "nonexistent", + tierStr: "30", + shouldError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err := validateProjectTier(tt.projectId, tt.tierStr) + switch tt.shouldError { + case true: + uassert.Error(t, err) + case false: + uassert.NoError(t, err) + } + }) + } +} + +func TestCreateDeposit(t *testing.T) { + _, testAddr := setupTestDeposit(t) + project := createTestProject(t) + now := uint64(time.Now().Unix()) + height := uint64(std.GetHeight()) + + info := ProjectTierInfo{ + Project: project, + Tier: project.tiers[30], + TierType: "30", + CurrentHeight: height, + CurrentTime: now, + } + + deposit, err := createDeposit(info, 100) + if err != nil { + t.Fatalf("Failed to create deposit: %v", err) + } + + uassert.Equal(t, deposit.amount, uint64(100)) + uassert.Equal(t, deposit.depositHeight, height) + uassert.Equal(t, deposit.depositTime, now) + uassert.Equal(t, deposit.projectId, project.id) +} + +func TestCalculateClaimableTimes(t *testing.T) { + tests := []struct { + name string + info ProjectTierInfo + wantHeight uint64 + wantTime uint64 + }{ + { + name: "Tier 30 - Normal Case", + info: ProjectTierInfo{ + Project: Project{}, + Tier: Tier{ + collectWaitDuration: 100, + ended: TimeInfo{ + height: 1000, + time: 1000000, + }, + }, + TierType: "30", + CurrentHeight: 500, + CurrentTime: 500000, + }, + wantHeight: 600, // CurrentHeight(500) + collectWaitDuration(100) + wantTime: 759200, // CurrentTime(500000) + TIMESTAMP_3DAYS + }, + { + name: "Tier 90 - Exceeds End CurrentHeight/Time", + info: ProjectTierInfo{ + Project: Project{}, + Tier: Tier{ + collectWaitDuration: 200, + ended: TimeInfo{ + height: 550, + time: 600000, + }, + }, + TierType: "90", + CurrentHeight: 400, + CurrentTime: 400000, + }, + wantHeight: 550, // min(400+200, endHeight(550)) + wantTime: 600000, // min(400000+TIMESTAMP_7DAYS, endTime(600000)) + }, + { + name: "Tier 180 - Early Stage", + info: ProjectTierInfo{ + Project: Project{}, + Tier: Tier{ + collectWaitDuration: 300, + ended: TimeInfo{ + height: 2000, + time: 2000000, + }, + }, + TierType: "180", + CurrentHeight: 100, + CurrentTime: 100000, + }, + wantHeight: 400, // CurrentHeight(100) + collectWaitDuration(300) + wantTime: 1309600, // CurrentTime(100000) + TIMESTAMP_14DAYS + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotHeight, gotTime := calculateClaimableTimes(tt.info) + uassert.Equal(t, gotHeight, tt.wantHeight) + uassert.Equal(t, gotTime, tt.wantTime) + }) + } +} + +func TestCollectDeposit(t *testing.T) { + _, testAddr := setupTestDeposit(t) + project := createTestProject(t) + projects[project.id] = project + currentTime := uint64(time.Now().Unix()) + height := uint64(std.GetHeight()) + + // Create test deposit + info := ProjectTierInfo{ + Project: project, + Tier: project.tiers[30], + TierType: "30", + CurrentHeight: height, + CurrentTime: currentTime, + } + + deposit, err := createDeposit(info, 100) + if err != nil { + t.Fatalf("Failed to create deposit: %v", err) + } + + // Add deposit to indices + deposits[deposit.id] = deposit + updateDepositIndices(deposit, &DepositState{ + Deposits: make(map[string]Deposit), + DepositsByProject: depositsByProject, + DepositsByUser: depositsByUser, + DepositsByUserProject: depositsByUserByProject, + }) + + t.Run("Cannot collect before claimable height", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(testAddr)) + amount, _ := processDepositCollection([]string{deposit.id}, "") + uassert.Equal(t, uint64(0), amount) + }) + + t.Run("Can collect after claimable height", func(t *testing.T) { + // Skip to after claimable height + std.TestSkipHeights(int64(deposit.claimableHeight - height + 1)) + + std.TestSetRealm(std.NewUserRealm(testAddr)) + amount, errCode := processDepositCollection([]string{deposit.id}, "") + uassert.NoError(t, errCode) + uassert.Equal(t, uint64(100), amount) + }) +} + +func TestDepositGns(t *testing.T) { + _, testAddr := setupTestDeposit(t) + project := createTestProject(t) + projects[project.id] = project + + t.Run("Fail with insufficient balance", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(testAddr)) + + defer func() { + r := recover() + if r != nil { + uassert.Equal(t, "[GNOSWAP-LAUNCHPAD-017] invalid amount || amount(1000) should greater than minimum deposit amount(1000000)", r.(string)) + } + }() + + DepositGns(project.id+":30", 1000) + std.TestSkipHeights(1) + }) + + t.Run("Success with sufficient balance", func(t *testing.T) { + // transfer enough gns + std.TestSetRealm(adminRealm) + gns.Transfer(testAddr, 1000000) + + std.TestSetRealm(std.NewUserRealm(testAddr)) + gns.Approve(consts.LAUNCHPAD_ADDR, 1000000) + + std.TestSkipHeights(10) + depositId := DepositGns(project.id+":30", 1000000) + deposit := deposits[depositId] + + uassert.Equal(t, deposit.amount, uint64(1000000)) + uassert.Equal(t, deposit.depositor, testAddr) + }) +} + +func TestUpdateDepositIndices(t *testing.T) { + _, testAddr := setupTestDeposit(t) + project := createTestProject(t) + + deposit := Deposit{ + id: "test_deposit_1", + projectId: project.id, + tier: "30", + depositor: testAddr, + amount: 100, + } + + state := &DepositState{ + Deposits: make(map[string]Deposit), + DepositsByProject: make(map[string]map[string][]string), + DepositsByUser: make(map[std.Address][]string), + DepositsByUserProject: make(map[std.Address]map[string][]string), + } + + updateDepositIndices(deposit, state) + + // Check depositsByUser + userDeposits := state.DepositsByUser[testAddr] + uassert.Equal(t, 1, len(userDeposits)) + uassert.Equal(t, deposit.id, userDeposits[0]) + + // Check depositsByProject + projectDeposits := state.DepositsByProject[project.id][deposit.tier] + uassert.Equal(t, 1, len(projectDeposits)) + uassert.Equal(t, deposit.id, projectDeposits[0]) + + // Check depositsByUserProject + userProjectDeposits := state.DepositsByUserProject[testAddr][project.id] + uassert.Equal(t, 1, len(userProjectDeposits)) + uassert.Equal(t, deposit.id, userProjectDeposits[0]) +} + +func TestMultipleDeposits(t *testing.T) { + _, testAddr := setupTestDeposit(t) + project := createTestProject(t) + projects[project.id] = project + + state := &DepositState{ + Deposits: make(map[string]Deposit), + DepositsByProject: depositsByProject, + DepositsByUser: depositsByUser, + DepositsByUserProject: depositsByUserByProject, + } + + // Create multiple deposits + for i := 0; i < 3; i++ { + info := ProjectTierInfo{ + Project: project, + Tier: project.tiers[30], + TierType: "30", + CurrentHeight: uint64(std.GetHeight()), + CurrentTime: uint64(time.Now().Unix()), + } + + std.TestSetRealm(std.NewUserRealm(testAddr)) + deposit, _ := createDeposit(info, 100) + deposits[deposit.id] = deposit + updateDepositIndices(deposit, state) + + std.TestSkipHeights(100) // Skip some blocks between deposits + } + + // Verify indices are correct + userDeposits := state.DepositsByUser[testAddr] + uassert.Equal(t, 3, len(userDeposits)) + + projectDeposits := state.DepositsByProject[project.id]["30"] + uassert.Equal(t, 3, len(projectDeposits)) +} + +func TestUpdateDepositIndices_EmptyState(t *testing.T) { + _, testAddr := setupTestDeposit(t) + deposit := Deposit{ + id: "test_deposit", + projectId: "test_project", + tier: "30", + depositor: testAddr, + amount: 100, + } + + state := &DepositState{ + Deposits: make(map[string]Deposit), + DepositsByProject: make(map[string]map[string][]string), + DepositsByUser: make(map[std.Address][]string), + DepositsByUserProject: make(map[std.Address]map[string][]string), + } + + updateDepositIndices(deposit, state) + + uassert.NotEqual(t, nil, state.DepositsByUser) + uassert.NotEqual(t, nil, state.DepositsByProject) + uassert.NotEqual(t, nil, state.DepositsByUserProject) + + uassert.Equal(t, 1, len(state.DepositsByUser[testAddr])) + uassert.Equal(t, deposit.id, state.DepositsByUser[testAddr][0]) +} + +func TestIsProjectActive(t *testing.T) { + project := Project{ + started: TimeInfo{height: 100}, + ended: TimeInfo{height: 200}, + } + + tests := []struct { + name string + height uint64 + shouldBeActive bool + }{ + { + name: "Before start", + height: 50, + shouldBeActive: false, + }, + { + name: "At start", + height: 100, + shouldBeActive: true, + }, + { + name: "During active period", + height: 150, + shouldBeActive: true, + }, + { + name: "At end", + height: 200, + shouldBeActive: true, + }, + { + name: "After end", + height: 250, + shouldBeActive: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + active := isProjectActive(project, tt.height) + uassert.Equal(t, tt.shouldBeActive, active) + }) + } +} + +func TestIsTierActive(t *testing.T) { + tier := Tier{ + ended: TimeInfo{height: 200}, + } + + tests := []struct { + name string + height uint64 + shouldBeActive bool + }{ + { + name: "Before end", + height: 150, + shouldBeActive: true, + }, + { + name: "At end", + height: 200, + shouldBeActive: true, + }, + { + name: "After end", + height: 250, + shouldBeActive: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + active := isTierActive(tier, tt.height) + uassert.Equal(t, tt.shouldBeActive, active) + }) + } +} + +// Mock for GRC20 Token and Ledger +type MockPrivateLedger struct { + burnCalled bool + burnAmount uint64 + burnFrom std.Address + mintCalled bool + mintAmount uint64 + mintTo std.Address +} + +func (m *MockPrivateLedger) Burn(from std.Address, amount uint64) error { + m.burnCalled = true + m.burnFrom = from + m.burnAmount = amount + return nil +} + +func (m *MockPrivateLedger) Mint(to std.Address, amount uint64) error { + m.mintCalled = true + m.mintTo = to + m.mintAmount = amount + return nil +} + +type MockToken struct { + balances map[string]uint64 +} + +func NewMockToken() *MockToken { + return &MockToken{ + balances: make(map[string]uint64), + } +} + +func (m *MockToken) BalanceOf(addr std.Address) uint64 { + return m.balances[addr.String()] +} + +// Mock for emission +type MockEmission struct { + mintCalled bool +} + +func (m *MockEmission) MintAndDistributeGns() { + m.mintCalled = true +} + +type MockCommon struct { + haltCheckCalled bool +} + +func (m *MockCommon) IsHalted() { + m.haltCheckCalled = true +} + +func TestProcessCollectedDeposits_TransactionFail(t *testing.T) { + mockXGNS, _ := setupXGNSTest(t) + project := createTestProject(t) + projects[project.id] = project + + info := ProjectTierInfo{ + Project: project, + Tier: project.tiers[30], + TierType: "30", + CurrentHeight: uint64(std.GetHeight()), + CurrentTime: uint64(time.Now().Unix()), + } + + deposit, _ := createDeposit(info, 1000) + deposits[deposit.id] = deposit + + xgnsMint(t, mockXGNS, consts.LAUNCHPAD_ADDR, 500) + + std.TestSetRealm(std.NewCodeRealm(consts.LAUNCHPAD_PATH)) + + processCollectedDeposits([]string{deposit.id}, "") +} + +func TestProcessCollectedDeposits_BalanceChanges(t *testing.T) { + mockXGNS, _ := setupXGNSTest(t) + project := createTestProject(t) + projects[project.id] = project + + depositAmount := uint64(1000) + xgnsMint(t, mockXGNS, consts.LAUNCHPAD_ADDR, depositAmount) + + info := ProjectTierInfo{ + Project: project, + Tier: project.tiers[30], + TierType: "30", + CurrentHeight: uint64(std.GetHeight()), + CurrentTime: uint64(time.Now().Unix()), + } + + deposit, _ := createDeposit(info, depositAmount) + deposits[deposit.id] = deposit + + std.TestSkipHeights(int64(deposit.claimableHeight - uint64(std.GetHeight()) + 1)) + + std.TestSetRealm(std.NewCodeRealm(consts.LAUNCHPAD_PATH)) + amount := processCollectedDeposits([]string{deposit.id}, "") + + xgnsCheckTotalSupply(t, mockXGNS, 1000) +} + +func TestCollectDepositGnsByDepositId_ClaimableCheck(t *testing.T) { + _, depositor := setupTestDeposit(t) + std.TestSkipHeights(10) + project := createTestProject(t) + projects[project.id] = project + + // TierInfo + now := uint64(time.Now().Unix()) + height := uint64(std.GetHeight()) + info := ProjectTierInfo{ + Project: project, + Tier: project.tiers[30], + TierType: "30", + CurrentHeight: height, + CurrentTime: now, + } + + std.TestSkipHeights(10) + deposit, err := createDeposit(info, 100) + if err != nil { + t.Fatalf("failed create deposit: %v", err) + } + + updateDepositIndices(deposit, &DepositState{ + Deposits: deposits, + DepositsByProject: depositsByProject, + DepositsByUser: depositsByUser, + DepositsByUserProject: depositsByUserByProject, + }) + + t.Run("Collect early => Should panic", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(depositor)) + + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but did not occur") + } + }() + CollectDepositGnsByDepositId(deposit.id) + }) +} + +func TestCollectDepositGnsByDepositId_OtherOwner(t *testing.T) { + _, userA := setupTestDeposit(t) + userB := testutils.TestAddress("userB") + + project := createTestProject(t) + projects[project.id] = project + + // create deposit (owner = userA) + now := uint64(time.Now().Unix()) + height := uint64(std.GetHeight()) + info := ProjectTierInfo{ + Project: project, + Tier: project.tiers[30], + TierType: "30", + CurrentHeight: height, + CurrentTime: now, + } + + deposit, err := createDeposit(info, 100) + if err != nil { + t.Fatalf("failed to create deposit: %v", err) + } + + updateDepositIndices(deposit, &DepositState{ + Deposits: deposits, + DepositsByProject: depositsByProject, + DepositsByUser: depositsByUser, + DepositsByUserProject: depositsByUserByProject, + }) + + t.Run("UserB tries to collect userA's deposit => panic expected", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(userB)) + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but did not occur") + } else { + uassert.Equal(t, "[GNOSWAP-LAUNCHPAD-007] invalid input data || user has no deposits in any project", r.(string)) + } + }() + CollectDepositGnsByDepositId(deposit.id) + }) +} + +func TestCollectDepositGnsByDepositId_ExactClaimable(t *testing.T) { + _, userA := setupTestDeposit(t) + project := createTestProject(t) + projects[project.id] = project + + now := uint64(time.Now().Unix()) + initialHeight := uint64(std.GetHeight()) + + info := ProjectTierInfo{ + Project: project, + Tier: project.tiers[30], + TierType: "30", + CurrentHeight: initialHeight, + CurrentTime: now, + } + + std.TestSetRealm(std.NewUserRealm(userA)) + // deposit + deposit, err := createDeposit(info, 200) + if err != nil { + t.Fatalf("failed create deposit: %v", err) + } + deposit.claimableHeight = initialHeight + 10 + deposit.claimableTime = now + 60 + + updateDepositIndices(deposit, &DepositState{ + Deposits: deposits, + DepositsByProject: depositsByProject, + DepositsByUser: depositsByUser, + DepositsByUserProject: depositsByUserByProject, + }) + + std.TestSetRealm(std.NewCodeRealm(consts.EMISSION_PATH)) + std.TestSkipHeights(int64(10)) + gns.MintGns(consts.LAUNCHPAD_ADDR) + std.TestSkipHeights(int64(10)) + + std.TestSetRealm(std.NewUserRealm(userA)) + gns.Approve(consts.LAUNCHPAD_ADDR, 1000) + token, _ := setupXGNSTest(t) + + std.TestSetRealm(std.NewCodeRealm(consts.LAUNCHPAD_PATH)) + std.TestSetOrigCaller(std.Address(consts.LAUNCHPAD_PATH)) + xgnsMint(t, token, consts.LAUNCHPAD_ADDR, 1000) + + std.TestSetRealm(std.NewUserRealm(consts.LAUNCHPAD_ADDR)) + xgns.Mint( + consts.LAUNCHPAD_ADDR, + 1000, + ) + + skipCount := deposit.claimableHeight - uint64(std.GetHeight()) + if skipCount > 0 { + std.TestSkipHeights(int64(skipCount)) + } + + std.TestSetRealm(std.NewUserRealm(userA)) + collected := CollectDepositGnsByDepositId(deposit.id) + uassert.Equal(t, uint64(200), collected, "should collect the deposit fully after claimableHeight/time") +} diff --git a/contract/r/gnoswap/launchpad/doc.gno b/contract/r/gnoswap/launchpad/doc.gno new file mode 100644 index 000000000..8b953a31f --- /dev/null +++ b/contract/r/gnoswap/launchpad/doc.gno @@ -0,0 +1,3 @@ +// Package launchpad manages project creation and reward distribution, enabling users to +// participate with GNS and allowing projects to distribute tokens and receive funding via yield redirection. +package launchpad diff --git a/contract/r/gnoswap/launchpad/errors.gno b/contract/r/gnoswap/launchpad/errors.gno new file mode 100644 index 000000000..c61dba7fa --- /dev/null +++ b/contract/r/gnoswap/launchpad/errors.gno @@ -0,0 +1,44 @@ +package launchpad + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoLeftReward = errors.New("[GNOSWAP-LAUNCHPAD-001] no left reward") + errInvalidAddress = errors.New("[GNOSWAP-LAUNCHPAD-002] invalid address") + errDataNotFound = errors.New("[GNOSWAP-LAUNCHPAD-003] requested data not found") + errActiveProject = errors.New("[GNOSWAP-LAUNCHPAD-004] project is active") + errInactiveProject = errors.New("[GNOSWAP-LAUNCHPAD-005] project is inactive") + errInactiveTier = errors.New("[GNOSWAP-LAUNCHPAD-006] pool is inactive") + errInvalidInput = errors.New("[GNOSWAP-LAUNCHPAD-007] invalid input data") + errDuplicateProject = errors.New("[GNOSWAP-LAUNCHPAD-008] can not create same project in same block") + errInvalidTier = errors.New("[GNOSWAP-LAUNCHPAD-009] invalid pool") + errInvalidTierRatio = errors.New("[GNOSWAP-LAUNCHPAD-010] invalid pool ratio") + errInvalidLength = errors.New("[GNOSWAP-LAUNCHPAD-011] invalid length") + errNotEnoughBalance = errors.New("[GNOSWAP-LAUNCHPAD-012] not enough balance") + errInvalidCondition = errors.New("[GNOSWAP-LAUNCHPAD-013] invalid transfer condition") + errConvertFail = errors.New("[GNOSWAP-LAUNCHPAD-014] convert fail") + errNotUserCaller = errors.New("[GNOSWAP-LAUNCHPAD-015] only user caller") + errInvalidData = errors.New("[GNOSWAP-LAUNCHPAD-016] invalid data") + errInvalidAmount = errors.New("[GNOSWAP-LAUNCHPAD-017] invalid amount") + errDuplicateDeposit = errors.New("[GNOSWAP-LAUNCHPAD-018] duplicate deposit") + errInvalidRewardState = errors.New("[GNOSWAP-LAUNCHPAD-019] invalid reward state") + errNotExistDeposit = errors.New("[GNOSWAP-LAUNCHPAD-020] not exist deposit") + errAlreadyExistDeposit = errors.New("[GNOSWAP-LAUNCHPAD-021] already exist deposit") + errInvalidProjectId = errors.New("[GNOSWAP-LAUNCHPAD-022] invalid project id") + errAlreadyCollected = errors.New("[GNOSWAP-LAUNCHPAD-023] already collected") + errNotYetClaimReward = errors.New("[GNOSWAP-LAUNCHPAD-024] not yet claim reward") + errInvalidCaller = errors.New("[GNOSWAP-LAUNCHPAD-025] invalid caller") + errInvalidOwner = errors.New("[GNOSWAP-LAUNCHPAD-026] invalid owner") + errInvalidAvgBlockTime = errors.New("[GNOSWAP-LAUNCHPAD-027] invalid average block time") + errInvalidHeight = errors.New("[GNOSWAP-LAUNCHPAD-028] invalid height") + errTierHasParticipants = errors.New("[GNOSWAP-LAUNCHPAD-029] tier has participants") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/launchpad/gno.mod b/contract/r/gnoswap/launchpad/gno.mod new file mode 100644 index 000000000..15bfa1540 --- /dev/null +++ b/contract/r/gnoswap/launchpad/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/launchpad diff --git a/contract/r/gnoswap/launchpad/json_builder.gno b/contract/r/gnoswap/launchpad/json_builder.gno new file mode 100644 index 000000000..e674aa939 --- /dev/null +++ b/contract/r/gnoswap/launchpad/json_builder.gno @@ -0,0 +1,118 @@ +package launchpad + +import ( + "std" + "time" + + "gno.land/p/demo/json" + "gno.land/p/demo/ufmt" +) + +// TimeInfoBuilder adds TimeInfo fields to JSON +func TimeInfoBuilder(b *json.NodeBuilder, prefix string, info TimeInfo) *json.NodeBuilder { + return b. + WriteString(prefix+"Height", ufmt.Sprintf("%d", info.height)). + WriteString(prefix+"Time", ufmt.Sprintf("%d", info.time)) +} + +// ProjectStatsBuilder adds ProjectStats fields to JSON +func ProjectStatsBuilder(b *json.NodeBuilder, stats ProjectStats) *json.NodeBuilder { + return b. + WriteString("totalDeposit", ufmt.Sprintf("%d", stats.totalDeposit)). + WriteString("actualDeposit", ufmt.Sprintf("%d", stats.actualDeposit)). + WriteString("totalParticipant", ufmt.Sprintf("%d", stats.totalParticipant)). + WriteString("actualParticipant", ufmt.Sprintf("%d", stats.actualParticipant)). + WriteString("totalCollected", ufmt.Sprintf("%d", stats.totalCollected)) +} + +// RefundInfoBuilder adds RefundInfo fields to JSON +func RefundInfoBuilder(b *json.NodeBuilder, info RefundInfo) *json.NodeBuilder { + return b. + WriteString("refundedAmount", ufmt.Sprintf("%d", info.amount)). + WriteString("refundedHeight", ufmt.Sprintf("%d", info.height)). + WriteString("refundedTime", ufmt.Sprintf("%d", info.time)) +} + +// TierBuilder adds Tier fields to JSON +func TierBuilder(b *json.NodeBuilder, prefix string, tier Tier) *json.NodeBuilder { + b.WriteString(prefix+"CollectWaitDuration", ufmt.Sprintf("%d", tier.collectWaitDuration)) + b.WriteString(prefix+"TierAmount", ufmt.Sprintf("%d", tier.tierAmount)) + if tier.tierAmountPerBlockX128 != nil { + b.WriteString(prefix+"TierAmountPerBlockX128", tier.tierAmountPerBlockX128.ToString()) + } + TimeInfoBuilder(b, prefix+"Started", tier.started) + TimeInfoBuilder(b, prefix+"Ended", tier.ended) + b.WriteString(prefix+"TotalDepositAmount", ufmt.Sprintf("%d", tier.totalDepositAmount)) + b.WriteString(prefix+"ActualDepositAmount", ufmt.Sprintf("%d", tier.actualDepositAmount)) + b.WriteString(prefix+"TotalParticipant", ufmt.Sprintf("%d", tier.totalParticipant)) + b.WriteString(prefix+"ActualParticipant", ufmt.Sprintf("%d", tier.actualParticipant)) + b.WriteString(prefix+"UserCollectedAmount", ufmt.Sprintf("%d", tier.userCollectedAmount)) + b.WriteString(prefix+"CalculatedAmount", ufmt.Sprintf("%d", tier.calculatedAmount)) + return b +} + +// ProjectBuilder adds Project fields to JSON +func ProjectBuilder(b *json.NodeBuilder, project Project) *json.NodeBuilder { + b.WriteString("name", project.name) + b.WriteString("tokenPath", project.tokenPath) + b.WriteString("depositAmount", ufmt.Sprintf("%d", project.depositAmount)) + b.WriteString("recipient", project.recipient.String()) + + tokenPathList, amountList := makeConditionsToStr(project.conditions) + b.WriteString("conditionsToken", tokenPathList) + b.WriteString("conditionsAmount", amountList) + + // Add tier ratios + for _, duration := range []uint64{30, 90, 180} { + b.WriteString(ufmt.Sprintf("tier%dRatio", duration), + ufmt.Sprintf("%d", project.tiersRatios[duration])) + } + + // Add time info + TimeInfoBuilder(b, "created", project.created) + TimeInfoBuilder(b, "started", project.started) + TimeInfoBuilder(b, "ended", project.ended) + + // Add stats + ProjectStatsBuilder(b, project.stats) + + // Add refund info + RefundInfoBuilder(b, project.refund) + + // Add tiers info + for _, duration := range []uint64{30, 90, 180} { + if tier, exists := project.tiers[duration]; exists { + TierBuilder(b, ufmt.Sprintf("tier%d", duration), tier) + } + } + + return b +} + +// DepositBuilder adds Deposit fields to JSON +func DepositBuilder(b *json.NodeBuilder, deposit Deposit) *json.NodeBuilder { + return b. + WriteString("projectId", deposit.projectId). + WriteString("tier", deposit.tier). + WriteString("depositor", deposit.depositor.String()). + WriteString("amount", ufmt.Sprintf("%d", deposit.amount)). + WriteString("depositHeight", ufmt.Sprintf("%d", deposit.depositHeight)). + WriteString("depositTime", ufmt.Sprintf("%d", deposit.depositTime)). + WriteString("depositCollectHeight", ufmt.Sprintf("%d", deposit.depositCollectHeight)). + WriteString("depositCollectTime", ufmt.Sprintf("%d", deposit.depositCollectTime)). + WriteString("claimableHeight", ufmt.Sprintf("%d", deposit.claimableHeight)). + WriteString("claimableTime", ufmt.Sprintf("%d", deposit.claimableTime)). + WriteString("rewardCollected", ufmt.Sprintf("%d", deposit.rewardCollected)). + WriteString("rewardCollectHeight", ufmt.Sprintf("%d", deposit.rewardCollectHeight)). + WriteString("rewardCollectTime", ufmt.Sprintf("%d", deposit.rewardCollectTime)) +} + +// MetaBuilder adds metadata fields to JSON +func MetaBuilder() *json.NodeBuilder { + height := std.GetHeight() + now := time.Now().Unix() + + return json.Builder(). + WriteString("height", ufmt.Sprintf("%d", height)). + WriteString("now", ufmt.Sprintf("%d", now)) +} diff --git a/contract/r/gnoswap/launchpad/json_builder_test.gno b/contract/r/gnoswap/launchpad/json_builder_test.gno new file mode 100644 index 000000000..d6fefef5d --- /dev/null +++ b/contract/r/gnoswap/launchpad/json_builder_test.gno @@ -0,0 +1,117 @@ +package launchpad + +import ( + "gno.land/p/demo/json" + u256 "gno.land/p/gnoswap/uint256" + "std" + "testing" + + "gno.land/p/demo/uassert" +) + +func TestTimeInfoBuilder(t *testing.T) { + info := TimeInfo{ + height: 100, + time: 1234567890, + } + + b := json.Builder() + TimeInfoBuilder(b, "test", info) + + result := b.Node() + expected := `{"testHeight":"100","testTime":"1234567890"}` + + res, err := json.Marshal(result) + uassert.NoError(t, err) + uassert.Equal(t, string(res), expected) +} + +func TestProjectStatsBuilder(t *testing.T) { + stats := ProjectStats{ + totalDeposit: 1000, + actualDeposit: 800, + totalParticipant: 50, + actualParticipant: 40, + totalCollected: 200, + } + + b := json.Builder() + ProjectStatsBuilder(b, stats) + + result := b.Node() + expected := `{"totalDeposit":"1000","actualDeposit":"800","totalParticipant":"50","actualParticipant":"40","totalCollected":"200"}` + + res, err := json.Marshal(result) + uassert.NoError(t, err) + uassert.Equal(t, string(res), expected) +} + +func TestTierBuilder(t *testing.T) { + amountPerBlock := u256.MustFromDecimal("1000000") + tier := Tier{ + collectWaitDuration: 100, + tierAmount: 5000, + tierAmountPerBlockX128: amountPerBlock, + started: TimeInfo{height: 100, time: 1234567890}, + ended: TimeInfo{height: 200, time: 1234567899}, + totalDepositAmount: 1000, + actualDepositAmount: 800, + totalParticipant: 50, + actualParticipant: 40, + userCollectedAmount: 300, + calculatedAmount: 400, + } + + b := json.Builder() + TierBuilder(b, "test", tier) + + result := b.Node() + expected := `{"testCollectWaitDuration":"100","testTierAmount":"5000","testTierAmountPerBlockX128":"1000000","testStartedHeight":"100","testStartedTime":"1234567890","testEndedHeight":"200","testEndedTime":"1234567899","testTotalDepositAmount":"1000","testActualDepositAmount":"800","testTotalParticipant":"50","testActualParticipant":"40","testUserCollectedAmount":"300","testCalculatedAmount":"400"}` + + res, err := json.Marshal(result) + uassert.NoError(t, err) + uassert.Equal(t, string(res), expected) +} + +func TestProjectBuilder(t *testing.T) { + project := Project{ + name: "Test Project", + tokenPath: "gno.land/r/test", + depositAmount: 1000, + recipient: std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5"), + conditions: map[string]Condition{ + "token1": {tokenPath: "path1", minAmount: 100}, + }, + tiersRatios: map[uint64]uint64{ + 30: 20, + 90: 30, + 180: 50, + }, + created: TimeInfo{height: 100, time: 1234567890}, + started: TimeInfo{height: 110, time: 1234567899}, + ended: TimeInfo{height: 200, time: 1234568000}, + stats: ProjectStats{ + totalDeposit: 1000, + actualDeposit: 800, + totalParticipant: 50, + actualParticipant: 40, + totalCollected: 200, + }, + refund: RefundInfo{ + amount: 100, + height: 150, + time: 1234567950, + }, + } + + b := json.Builder() + ProjectBuilder(b, project) + + data := b.Node() + + result, err := json.Marshal(data) + uassert.NoError(t, err) + + expected := `{"name":"Test Project","tokenPath":"gno.land/r/test","depositAmount":"1000","recipient":"g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5","conditionsToken":"token1","conditionsAmount":"100","tier30Ratio":"20","tier90Ratio":"30","tier180Ratio":"50","createdHeight":"100","createdTime":"1234567890","startedHeight":"110","startedTime":"1234567899","endedHeight":"200","endedTime":"1234568000","totalDeposit":"1000","actualDeposit":"800","totalParticipant":"50","actualParticipant":"40","totalCollected":"200","refundedAmount":"100","refundedHeight":"150","refundedTime":"1234567950"}` + uassert.Equal(t, string(result), expected) +} diff --git a/contract/r/gnoswap/launchpad/launchpad.gno b/contract/r/gnoswap/launchpad/launchpad.gno new file mode 100644 index 000000000..79daeb476 --- /dev/null +++ b/contract/r/gnoswap/launchpad/launchpad.gno @@ -0,0 +1,711 @@ +package launchpad + +import ( + "errors" + "std" + "strconv" + "strings" + "time" + + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + en "gno.land/r/gnoswap/v1/emission" + "gno.land/r/gnoswap/v1/gns" +) + +var ( + projects = make(map[string]Project) // projectId -> project + + // project tier should distribute project token if deposit ever happened + // therefore we need to keep track of project tiers without deposit + projectTiersWithoutDeposit = make(map[string]bool) // tierId -> true + + q128 = u256.Zero() +) + +func init() { + q128 = u256.MustFromDecimal(consts.Q128) +} + +// ProjectInput defines the input parameters required to create a project. +// +// Fields: +// - Name (string): The name of the project. +// - TokenPath (string): The token path for the deposit token. +// - Recipient (std.Address): The recipient address for the project's rewards. +// - DepositAmount (uint64): The total deposit amount for the project. +// - ConditionsToken (string): A string containing token paths separated by `*PAD*`. +// - ConditionsAmount (string): A string containing corresponding amounts separated by `*PAD*`. +// - Tier30Ratio (uint64): The percentage of the deposit allocated to the 30-day tier. +// - Tier90Ratio (uint64): The percentage of the deposit allocated to the 90-day tier. +// - Tier180Ratio (uint64): The percentage of the deposit allocated to the 180-day tier. +// - StartTime (uint64): The start time of the project in Unix timestamp. +type ProjectInput struct { + Name string + TokenPath string + Recipient std.Address + DepositAmount uint64 + ConditionsToken string // separated by PAD_SEP (`*PAD*`) + ConditionsAmount string // separated by PAD_SEP (`*PAD*`) + Tier30Ratio uint64 + Tier90Ratio uint64 + Tier180Ratio uint64 + StartTime uint64 +} + +// NewProjectInput creates and returns a new `ProjectInput` object. +// +// Parameters: +// - name (string): The name of the project. +// - tokenPath (string): The token path for the deposit token. +// - recipient (std.Address): The recipient address for the project's rewards. +// - depositAmount (uint64): The total deposit amount for the project. +// - conditionsToken (string): Token paths separated by `*PAD*`. +// - conditionsAmount (string): Corresponding amounts separated by `*PAD*`. +// - tier30Ratio (uint64): Allocation percentage for the 30-day tier. +// - tier90Ratio (uint64): Allocation percentage for the 90-day tier. +// - tier180Ratio (uint64): Allocation percentage for the 180-day tier. +// - startTime (uint64): The project's start time in Unix timestamp. +// +// Returns: +// - ProjectInput: The initialized `ProjectInput` object. +func NewProjectInput( + name string, + tokenPath string, + recipient std.Address, + depositAmount uint64, + conditionsToken string, + conditionsAmount string, + tier30Ratio uint64, + tier90Ratio uint64, + tier180Ratio uint64, + startTime uint64, +) ProjectInput { + return ProjectInput{ + Name: name, + TokenPath: tokenPath, + Recipient: recipient, + DepositAmount: depositAmount, + ConditionsToken: conditionsToken, + ConditionsAmount: conditionsAmount, + Tier30Ratio: tier30Ratio, + Tier90Ratio: tier90Ratio, + Tier180Ratio: tier180Ratio, + StartTime: startTime, + } +} + +// ValidateName checks if the project name is valid. +// +// Returns: +// - error: If the name is empty, returns an error. +func (in ProjectInput) ValidateName() error { + if in.Name == "" { + return errors.New(addDetailToError( + errInvalidInput, "project name cannot be empty")) + } + if len(in.Name) > 100 { + return errors.New(addDetailToError(errInvalidInput, "project name is too long")) + } + return nil +} + +// ValidateTokenPath validates the token path. +// +// Ensures that the token path is not empty and is registered. +// +// Returns: +// - error: If the token path is invalid or unregistered, returns an error. +func (in ProjectInput) ValidateTokenPath() error { + if in.TokenPath == "" { + return errors.New(addDetailToError( + errInvalidInput, "tokenPath cannot be empty")) + } + if err := common.IsRegistered(in.TokenPath); err != nil { + return errors.New(addDetailToError( + errInvalidInput, ufmt.Sprintf("tokenPath(%s) not registered", in.TokenPath))) + } + return nil +} + +// ValidateRecipient checks if the recipient address is valid. +// +// Returns: +// - error: If the recipient address is invalid, returns an error. +func (in ProjectInput) ValidateRecipient() error { + if !in.Recipient.IsValid() { + return errors.New(addDetailToError( + errInvalidAddress, ufmt.Sprintf("recipient address(%s)", in.Recipient.String()))) + } + return nil +} + +// ValidateDepositAmount ensures that the deposit amount is greater than zero. +// +// Returns: +// - error: If the deposit amount is zero, returns an error. +func (in ProjectInput) ValidateDepositAmount() error { + if in.DepositAmount == 0 { + return errors.New(addDetailToError( + errInvalidInput, "deposit amount cannot be 0")) + } + return nil +} + +// ValidateRatio checks if the sum of the tier ratios equals 100. +// +// Returns: +// - error: If the sum of `Tier30Ratio`, `Tier90Ratio`, and `Tier180Ratio` is not 100, returns an error. +func (in ProjectInput) ValidateRatio() error { + sum := in.Tier30Ratio + in.Tier90Ratio + in.Tier180Ratio + if sum != 100 { + return errors.New(addDetailToError( + errInvalidInput, ufmt.Sprintf("invalid ratio, sum of all tiers(30:%d, 90:%d, 180:%d) should be 100", + in.Tier30Ratio, in.Tier90Ratio, in.Tier180Ratio))) + } + return nil +} + +// ValidateStartTime checks if the start time is in the future. +// +// Parameters: +// - now (uint64): The current time in Unix timestamp. +// +// Returns: +// - error: If the start time is in the past or equal to `now`, returns an error. +func (in ProjectInput) ValidateStartTime(now uint64) error { + if in.StartTime <= now { + return errors.New(addDetailToError( + errInvalidInput, ufmt.Sprintf("start time(%d) must be greater than now(%d)", in.StartTime, now))) + } + return nil +} + +// ParseConditions parses the conditions from `ConditionsToken` and `ConditionsAmount`. +// +// Splits the token and amount strings using `*PAD*`, validates each token path, and converts +// the amounts to integers. Returns a map of token paths to their corresponding conditions. +// +// Returns: +// - map[string]Condition: A map of token paths to `Condition` objects. +// - error: If the tokens and amounts are mismatched, invalid, or unregistered, returns an error. +func (in ProjectInput) ParseConditions() (map[string]Condition, error) { + if in.ConditionsToken == "" || in.ConditionsAmount == "" { + return nil, errors.New(addDetailToError( + errInvalidInput, "conditionsToken or conditionsAmount cannot be empty")) + } + + tokens := strings.Split(in.ConditionsToken, PAD_SEP) + amounts := strings.Split(in.ConditionsAmount, PAD_SEP) + if len(tokens) != len(amounts) { + return nil, errors.New(addDetailToError( + errInvalidLength, ufmt.Sprintf("invalid conditions(numTokens(%d) != numAmounts(%d))", len(tokens), len(amounts)))) + } + + condition := make(map[string]Condition) + for i, token := range tokens { + if token == "" { + return nil, errors.New(addDetailToError( + errInvalidInput, ufmt.Sprintf("invalid token(%s)", token))) + } + if err := common.IsRegistered(token); err != nil { + return nil, errors.New(addDetailToError( + errInvalidInput, ufmt.Sprintf("token(%s) not registered", token))) + } + if _, exists := condition[token]; exists { + return nil, errors.New(addDetailToError( + errInvalidInput, ufmt.Sprintf("duplicated condition token(%s)", token))) + } + + amtStr := amounts[i] + amtVal, err := strconv.ParseUint(amtStr, 10, 64) + if err != nil { + return nil, errors.New(addDetailToError( + errConvertFail, err.Error())) + } + + condition[token] = Condition{ + tokenPath: token, + minAmount: amtVal, + } + } + return condition, nil +} + +// ValidateAll validates all the fields of the `ProjectInput`. +// +// Calls individual validation methods for the name, token path, recipient, deposit amount, +// tier ratios, and start time. Also parses the conditions. +// +// Parameters: +// - now (uint64): The current time in Unix timestamp. +// +// Returns: +// - map[string]Condition: The parsed conditions map. +// - error: If any validation fails, returns an error. +func (in ProjectInput) ValidateAll(now uint64) (map[string]Condition, error) { + if err := in.ValidateName(); err != nil { + return nil, err + } + if err := in.ValidateTokenPath(); err != nil { + return nil, err + } + if err := in.ValidateRecipient(); err != nil { + return nil, err + } + if err := in.ValidateDepositAmount(); err != nil { + return nil, err + } + if err := in.ValidateRatio(); err != nil { + return nil, err + } + if err := in.ValidateStartTime(now); err != nil { + return nil, err + } + + condMap, err := in.ParseConditions() + if err != nil { + return nil, err + } + return condMap, nil +} + +// ConvertStartTimeToHeight converts the project's start time to a block height. +// +// Calculates the block height at which the project should start based on the average block time +// and the difference between the start time and the current time. +// +// Parameters: +// - now (uint64): The current time in Unix timestamp. +// +// Returns: +// - uint64: The calculated block height for the project's start. +// - error: If the average block time is unavailable or the start time is in the past, returns an error. +func (in ProjectInput) ConvertStartTimeToHeight(now uint64) (uint64, error) { + avgBlockTimeMs := uint64(gns.GetAvgBlockTimeInMs()) + if avgBlockTimeMs <= 0 { + panic(addDetailToError( + errInvalidInput, "average block time not available")) + } + + diffSec := int64(in.StartTime) - int64(now) + if diffSec <= 0 { + return 0, errors.New(addDetailToError( + errInvalidInput, ufmt.Sprintf("invalid start time(%d), cannot start project in past(now:%d)", + in.StartTime, now))) + } + + diffBlockHeight := uint64(diffSec * 1000 / int64(avgBlockTimeMs)) + currentHeight := uint64(std.GetHeight()) + return currentHeight + diffBlockHeight, nil +} + +// ProjectCalculationResult middle result of project params calculation +type ProjectCalculationResult struct { + Tier30Amount uint64 + Tier90Amount uint64 + Tier180Amount uint64 + StartHeight uint64 +} + +// CreateProject creates a new project with specified tiers, conditions, and token deposit. +// +// This function initializes a project by validating the input, transferring the deposit amount, +// and creating the project's tiers. The project is stored globally and an event is emitted. +// +// Parameters: +// - name (string): The name of the project. +// - tokenPath (string): The token path for the deposit token. +// - recipient (std.Address): The address to receive the project's rewards. +// - depositAmount (uint64): The total amount of tokens to be deposited for the project. +// - conditionsToken (string): The token used for reward conditions. +// - conditionsAmount (string): The amount of the conditions token required. +// - tier30Ratio (uint64): The percentage of the deposit allocated to the 30-day tier. +// - tier90Ratio (uint64): The percentage of the deposit allocated to the 90-day tier. +// - tier180Ratio (uint64): The percentage of the deposit allocated to the 180-day tier. +// - startTime (uint64): The project's start time in Unix timestamp (seconds). +// +// Returns: +// - string: The unique ID of the created project. +func CreateProject( + name string, + tokenPath string, + recipient std.Address, + depositAmount uint64, + conditionsToken string, + conditionsAmount string, + tier30Ratio uint64, + tier90Ratio uint64, + tier180Ratio uint64, + startTime uint64, +) string { + assertOnlyNotHalted() + assertOnlyAdmin() + + input := NewProjectInput(name, tokenPath, recipient, depositAmount, conditionsToken, conditionsAmount, tier30Ratio, tier90Ratio, tier180Ratio, startTime) + now := uint64(time.Now().Unix()) + conditions, err := input.ValidateAll(now) + if err != nil { + panic(err.Error()) + } + + en.MintAndDistributeGns() + + projectId := generateProjectId(input.TokenPath) + if _, exists := projects[projectId]; exists { + panic(addDetailToError( + errDuplicateProject, + ufmt.Sprintf("project(%s) already exists", projectId))) + } + + calcResult, err := calculateProjectParams(input, now) + if err != nil { + panic(addDetailToError(errInvalidInput, err.Error())) + } + + startHeight := calcResult.StartHeight + tier30 := createTier(projectId, 30, calcResult.Tier30Amount, startHeight, input.StartTime, convertTimeToHeight(TIMESTAMP_3DAYS)) + tier90 := createTier(projectId, 90, calcResult.Tier90Amount, startHeight, input.StartTime, convertTimeToHeight(TIMESTAMP_7DAYS)) + tier180 := createTier(projectId, 180, calcResult.Tier180Amount, startHeight, input.StartTime, convertTimeToHeight(TIMESTAMP_14DAYS)) + tierMap := map[uint64]Tier{ + TIER30: tier30, + TIER90: tier90, + TIER180: tier180, + } + + tierRatioMap := map[uint64]uint64{ + TIER30: input.Tier30Ratio, + TIER90: input.Tier90Ratio, + TIER180: input.Tier180Ratio, + } + + projectCreated := TimeInfo{ + height: uint64(std.GetHeight()), + time: now, + } + + projectStarted := TimeInfo{ + height: startHeight, + time: input.StartTime, + } + + projectEnded := TimeInfo{ + height: tier180.ended.height, + time: tier180.ended.time, + } + + projectStats := NewProjectStats(0, 0, 0, 0, 0) + + refundInfo := NewRefundInfo(0, 0, 0) + + project := NewProject( + projectId, + name, + tokenPath, + depositAmount, + recipient, + conditions, + tierMap, + tierRatioMap, + projectCreated, + projectStarted, + projectEnded, + *projectStats, + refundInfo, + ) + + projects[projectId] = *project + + // Transfer tokens + tokenTeller := common.GetTokenTeller(tokenPath) + tokenTeller.TransferFrom( + getPrevAddr(), + std.Address(GetOrigPkgAddr()), + depositAmount, + ) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "CreateProject", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "name", name, + "tokenPath", tokenPath, + "recipient", recipient.String(), + "depositAmount", formatUint(depositAmount), + "conditionsToken", conditionsToken, + "conditionsAmount", conditionsAmount, + "tier30Ratio", formatUint(tier30Ratio), + "tier90Ratio", formatUint(tier90Ratio), + "tier180Ratio", formatUint(tier180Ratio), + "startTime", formatUint(startTime), + "startHeight", formatUint(startHeight), + "projectId", projectId, + "tier30Amount", formatUint(calcResult.Tier30Amount), + "tier30EndHeight", formatUint(tier30.ended.height), + "tier90Amount", formatUint(calcResult.Tier90Amount), + "tier90EndHeight", formatUint(tier90.ended.height), + "tier180Amount", formatUint(calcResult.Tier180Amount), + "tier180EndHeight", formatUint(tier180.ended.height), + ) + + return projectId +} + +// TransferLeftFromProjectByAdmin transfers the remaining rewards of a project to a specified recipient. +// +// This function is called by an admin to transfer any unclaimed rewards from a project to a recipient address. +// It validates the project ID, checks the recipient conditions, calculates the remaining rewards, and performs the transfer. +// +// Parameters: +// - projectId (string): The unique identifier of the project. +// - recipient (std.Address): The recipient address to transfer the remaining rewards. +// +// Returns: +// - uint64: The amount of rewards transferred to the recipient. +func TransferLeftFromProjectByAdmin(projectId string, recipient std.Address) uint64 { + if err := common.SatisfyCond(isUserCall()); err != nil { + panic(addDetailToError(errNotUserCaller, err.Error())) + } + assertOnlyAdmin() + assertOnlyNotHalted() + + project, err := getProject(projectId) + if err != nil { + panic(err.Error()) + } + + currentHeight := uint64(std.GetHeight()) + if err := validateTransferLeft(project, recipient, currentHeight); err != nil { + panic(addDetailToError( + errInvalidCondition, err.Error())) + } + + en.MintAndDistributeGns() + + accumTotalReward := uint64(0) + accumLeftReward := uint64(0) + accumCollectedReward := uint64(0) + tierMap := project.Tiers() + for _, tier := range tierMap { + if !tier.isEnded(currentHeight) { + panic(addDetailToError( + errActiveProject, ufmt.Sprintf("tier(%d) is not ended", tier.ID()))) + } + if tier.ActualParticipant() > 0 { + panic(addDetailToError( + errTierHasParticipants, ufmt.Sprintf("tier(%d) has (%d) participants", tier.ID(), tier.ActualParticipant()))) + } + leftReward := tier.calculateLeftReward() + accumLeftReward += leftReward + accumCollectedReward += tier.UserCollectedAmount() + accumTotalReward += tier.TierAmount() + } + + if accumTotalReward != accumCollectedReward+accumLeftReward { + panic(addDetailToError( + errInvalidRewardState, ufmt.Sprintf("accumTotalReward(%d) != accumCollectedReward(%d)+accumLeftReward(%d)", accumTotalReward, accumCollectedReward, accumLeftReward))) + } + + projectLeftReward := project.calculateLeftReward() + project.refund = RefundInfo{ + amount: projectLeftReward, + height: currentHeight, + time: uint64(time.Now().Unix()), + } + + projects[projectId] = project + + if projectLeftReward > 0 { + tokenTeller := common.GetTokenTeller(project.tokenPath) + tokenTeller.Transfer(recipient, projectLeftReward) + } + + std.Emit( + "TransferLeftFromProjectByAdmin", + "projectId", projectId, + "recipient", recipient.String(), + "tokenPath", project.tokenPath, + "leftReward", formatUint(projectLeftReward), + "tier30Full", formatUint(project.tiers[30].tierAmount), + "tier30Left", formatUint(project.tiers[30].tierAmount-project.tiers[30].calculatedAmount), + "tier90Full", formatUint(project.tiers[90].tierAmount), + "tier90Left", formatUint(project.tiers[90].tierAmount-project.tiers[90].calculatedAmount), + "tier180Full", formatUint(project.tiers[180].tierAmount), + "tier180Left", formatUint(project.tiers[180].tierAmount-project.tiers[180].calculatedAmount), + ) + + return projectLeftReward +} + +// calculateProjectParams calculates project-related parameters based on the input and current time. +// +// This function computes the amounts for three tiers (30, 90, and 180), as well as the block height +// when the project starts based on the provided start time. +// +// Parameters: +// - input (ProjectInput): The input parameters for the project. +// - DepositAmount (uint64): The total deposit amount. +// - Tier30Ratio (uint64): The percentage of the deposit allocated to the 30-day tier. +// - Tier90Ratio (uint64): The percentage of the deposit allocated to the 90-day tier. +// - Tier180Ratio (uint64): The percentage of the deposit allocated to the 180-day tier. +// - StartTime (uint64): The project start time in Unix timestamp (seconds). +// +// - now (uint64): The current time in Unix timestamp (seconds). +// +// Returns: +// - *ProjectCalculationResult: A pointer to the calculated results, including tier amounts and start height. +// - error: Returns an error if the start time is earlier than the current time or if the average block time is unavailable. +func calculateProjectParams(input ProjectInput, now uint64) (*ProjectCalculationResult, error) { + if input.StartTime <= now { + return nil, errors.New(addDetailToError( + errInvalidInput, ufmt.Sprintf("start time(%d) must be greater than now(%d)", input.StartTime, now))) + } + + tier30Amount := input.DepositAmount * input.Tier30Ratio / 100 + tier90Amount := input.DepositAmount * input.Tier90Ratio / 100 + tier180Amount := input.DepositAmount * input.Tier180Ratio / 100 + + // calculate when the block starts + timeUntilStart := input.StartTime - now + blockDurationToStart := convertTimeToHeight(timeUntilStart) + currentHeight := uint64(std.GetHeight()) + startHeight := currentHeight + blockDurationToStart + + return &ProjectCalculationResult{ + Tier30Amount: tier30Amount, + Tier90Amount: tier90Amount, + Tier180Amount: tier180Amount, + StartHeight: startHeight, + }, nil +} + +// generateProjectId generates a unique project ID based on the given token path and the current block height. +// +// The generated ID combines the `tokenPath` and the current block height in the following format: +// "{tokenPath}:{height}" +// +// Parameters: +// - tokenPath (string): The path of the token associated with the project. +// +// Returns: +// - string: A unique project ID in the format "tokenPath:height". +func generateProjectId(tokenPath string) string { + // gno.land/r/gnoswap/gns:{height} + // gno.land/r/gnoswap/gns:30 + return ufmt.Sprintf("%s:%d", tokenPath, std.GetHeight()) +} + +// generateTierId generates a unique tier ID based on the given project ID and the tier duration. +// +// The generated ID combines the `projectId` and the `duration` in the following format: +// "{projectId}:{duration}" +// +// Parameters: +// - projectId (string): The unique ID of the project associated with the tier. +// - duration (uint64): The duration of the tier (e.g., 30, 90, 180 days). +// +// Returns: +// - string: A unique tier ID in the format "projectId:duration". +func generateTierId(projectId string, duration uint64) string { + // gno.land/r/gnoswap/gns:{height}:{duration} + // gno.land/r/gnoswap/gns:30:30(90,180) + return ufmt.Sprintf("%s:%d", projectId, duration) +} + +// createTier creates a new tier with the given parameters +func createTier( + projectId string, + duration uint64, + amount uint64, + startHeight uint64, + startTime uint64, + collectWaitDuration uint64) Tier { + + tierId := generateTierId(projectId, duration) + + tierStart := TimeInfo{ + height: startHeight, + time: startTime, + } + + durationSecond := duration * TIMESTAMP_DAY + endTime := startTime + durationSecond + durationHeight := convertTimeToHeight(durationSecond) + endHeight := startHeight + durationHeight + tierEnded := TimeInfo{ + height: endHeight, + time: endTime, + } + + rewardPerBlockX128 := calcRewardPerBlockX128(amount, durationHeight) + + // reward + reward := NewReward(u256.Zero(), startHeight, endHeight) + + tier := Tier{ + id: tierId, + collectWaitDuration: collectWaitDuration, + tierAmount: amount, + tierAmountPerBlockX128: rewardPerBlockX128, + started: tierStart, + ended: tierEnded, + totalDepositAmount: 0, + actualDepositAmount: 0, + totalParticipant: 0, + actualParticipant: 0, + userCollectedAmount: 0, + calculatedAmount: 0, + reward: *reward, + } + + rewardState := NewRewardState(tier.tierAmountPerBlockX128, startHeight, endHeight) + rewardStates.Set(projectId, strconv.FormatUint(duration, 10), rewardState) + + projectTiersWithoutDeposit[tierId] = true + return tier +} + +// validateTransferLeft validates the transfer of remaining tokens +func validateTransferLeft(project Project, recipient std.Address, height uint64) error { + if !recipient.IsValid() { + return errors.New(ufmt.Sprintf("invalid recipient address(%s)", recipient.String())) + } + + if height < project.Ended().height { + return errors.New(ufmt.Sprintf("project not ended yet(current:%d, endHeight: %d)", + height, project.ended.height)) + } + + if project.Refund().height != 0 { + return errors.New(ufmt.Sprintf("project already refunded(height:%d)", + project.Refund().height)) + } + + return nil +} + +// calculateLeftReward calculates the remaining reward amount for each tier +func calculateLeftReward(project Project) uint64 { + tier30 := project.tiers[30] + tier90 := project.tiers[90] + tier180 := project.tiers[180] + + left30 := tier30.tierAmount - tier30.calculatedAmount + left90 := tier90.tierAmount - tier90.calculatedAmount + left180 := tier180.tierAmount - tier180.calculatedAmount + return left30 + left90 + left180 +} + +// getProject returns the project with the given project ID +func getProject(projectId string) (Project, error) { + project, exists := projects[projectId] + if !exists { + return Project{}, errors.New(addDetailToError( + errDataNotFound, ufmt.Sprintf("projectId(%s) not found", projectId))) + } + return project, nil +} diff --git a/contract/r/gnoswap/launchpad/launchpad_test.gno b/contract/r/gnoswap/launchpad/launchpad_test.gno new file mode 100644 index 000000000..36ffeac3a --- /dev/null +++ b/contract/r/gnoswap/launchpad/launchpad_test.gno @@ -0,0 +1,1652 @@ +package launchpad + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/obl" + + u256 "gno.land/p/gnoswap/uint256" +) + +func setupTest(t *testing.T) { + projects = make(map[string]Project) + projectTiersWithoutDeposit = make(map[string]bool) +} + +func createTestAddress(t *testing.T) std.Address { + addr := testutils.TestAddress("test") + return addr +} + +func TestValidateName(t *testing.T) { + tests := []struct { + name string + inputName string + wantErr bool + }{ + {"ValidName", "MyProject", false}, + {"EmptyName", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := ProjectInput{Name: tt.inputName} + err := in.ValidateName() + if tt.wantErr { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestValidateTokenPath(t *testing.T) { + tests := []struct { + name string + tokenPath string + wantErr bool + }{ + {"ValidToken", "gno.land/r/gnoswap/v1/gns", false}, + {"EmptyToken", "", true}, + {"UnregisteredToken", "invalid/token", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := ProjectInput{TokenPath: tt.tokenPath} + err := in.ValidateTokenPath() + if tt.wantErr { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestValidateRecipient(t *testing.T) { + tests := []struct { + name string + recipient std.Address + wantErr bool + }{ + {"ValidRecipient", std.Address(consts.ADMIN), false}, + {"InvalidRecipient", std.Address("xyz..."), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := ProjectInput{Recipient: tt.recipient} + err := in.ValidateRecipient() + if tt.wantErr { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestValidateDepositAmount(t *testing.T) { + tests := []struct { + name string + depositAmount uint64 + wantErr bool + }{ + {"Valid", 100, false}, + {"ZeroDeposit", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := ProjectInput{DepositAmount: tt.depositAmount} + err := in.ValidateDepositAmount() + if tt.wantErr { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestValidateRatio(t *testing.T) { + tests := []struct { + name string + t30, t90, t180 uint64 + wantErr bool + }{ + {"ValidSum100", 30, 50, 20, false}, + {"Not100", 10, 10, 10, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := ProjectInput{ + Tier30Ratio: tt.t30, + Tier90Ratio: tt.t90, + Tier180Ratio: tt.t180, + } + err := in.ValidateRatio() + if tt.wantErr { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestValidateStartTime(t *testing.T) { + now := uint64(1000) + tests := []struct { + name string + startTime uint64 + wantErr bool + }{ + {"ValidFuture", 2000, false}, + {"Now", 1000, true}, + {"Past", 999, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := ProjectInput{StartTime: tt.startTime} + err := in.ValidateStartTime(now) + if tt.wantErr { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestParseConditions(t *testing.T) { + tests := []struct { + name string + condToken string + condAmount string + wantErr bool + expectedLen int + }{ + { + name: "ValidMultiple", + condToken: "gno.land/r/demo/wugnot*PAD*gno.land/r/gnoswap/v1/gns", + condAmount: "100*PAD*200", + wantErr: false, + expectedLen: 2, + }, + { + name: "EmptyToken", + condToken: "", + condAmount: "100", + wantErr: true, + expectedLen: 0, + }, + { + name: "LenMismatch", + condToken: "valid/T0*PAD*valid/T1*PAD*valid/T2", + condAmount: "100*PAD*200", + wantErr: true, + expectedLen: 0, + }, + { + name: "UnregisteredToken", + condToken: "invalid/token", + condAmount: "100", + wantErr: true, + expectedLen: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + in := ProjectInput{ + ConditionsToken: tt.condToken, + ConditionsAmount: tt.condAmount, + } + conds, err := in.ParseConditions() + if tt.wantErr { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + uassert.Equal(t, tt.expectedLen, len(conds)) + if tt.expectedLen > 0 { + uassert.Equal(t, uint64(100), conds["gno.land/r/demo/wugnot"].minAmount) + uassert.Equal(t, uint64(200), conds["gno.land/r/gnoswap/v1/gns"].minAmount) + } + } + }) + } +} + +func TestValidateAll(t *testing.T) { + now := uint64(1000) + tests := []struct { + name string + input ProjectInput + wantErr bool + }{ + { + name: "AllValid", + input: ProjectInput{ + Name: "ProjectName", + TokenPath: "gno.land/r/gnoswap/v1/gns", + Recipient: std.Address(consts.ADMIN), + DepositAmount: 1000, + ConditionsToken: "gno.land/r/demo/wugnot", + ConditionsAmount: "10", + Tier30Ratio: 50, + Tier90Ratio: 30, + Tier180Ratio: 20, + StartTime: 2000, + }, + wantErr: false, + }, + { + name: "InvalidName", + input: ProjectInput{ + Name: "", + TokenPath: "valid/token", + Recipient: std.Address("g1valid"), + DepositAmount: 10, + ConditionsToken: "valid/T0", + ConditionsAmount: "1", + Tier30Ratio: 80, + Tier90Ratio: 10, + Tier180Ratio: 10, + StartTime: 2000, + }, + wantErr: true, + }, + { + name: "BadTokenPath", + input: ProjectInput{ + Name: "Proj", + TokenPath: "invalid/token", + Recipient: std.Address("g1valid"), + DepositAmount: 10, + ConditionsToken: "valid/T0", + ConditionsAmount: "1", + Tier30Ratio: 80, + Tier90Ratio: 10, + Tier180Ratio: 10, + StartTime: 2000, + }, + wantErr: true, + }, + { + name: "BadRecipient", + input: ProjectInput{ + Name: "Proj", + TokenPath: "valid/token", + Recipient: std.Address("xinvalid"), + DepositAmount: 10, + ConditionsToken: "valid/T0", + ConditionsAmount: "1", + Tier30Ratio: 80, + Tier90Ratio: 10, + Tier180Ratio: 10, + StartTime: 2000, + }, + wantErr: true, + }, + { + name: "BadDepositAmount", + input: ProjectInput{ + Name: "Proj", + TokenPath: "valid/token", + Recipient: std.Address("g1abc"), + DepositAmount: 0, + ConditionsToken: "valid/T0", + ConditionsAmount: "1", + Tier30Ratio: 80, + Tier90Ratio: 10, + Tier180Ratio: 10, + StartTime: 2000, + }, + wantErr: true, + }, + { + name: "BadRatio", + input: ProjectInput{ + Name: "Proj", + TokenPath: "valid/token", + Recipient: std.Address("g1abc"), + DepositAmount: 100, + ConditionsToken: "valid/T0", + ConditionsAmount: "1", + Tier30Ratio: 80, + Tier90Ratio: 10, + Tier180Ratio: 30, // sum=120 !=100 + StartTime: 2000, + }, + wantErr: true, + }, + { + name: "StartTimePast", + input: ProjectInput{ + Name: "Proj", + TokenPath: "valid/token", + Recipient: std.Address("g1abc"), + DepositAmount: 100, + ConditionsToken: "valid/T0", + ConditionsAmount: "1", + Tier30Ratio: 80, + Tier90Ratio: 10, + Tier180Ratio: 10, + StartTime: 999, // <=now + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.input.ValidateAll(now) + if tt.wantErr { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestValidateProjectInput(t *testing.T) { + now := uint64(time.Now().Unix()) + testCases := []struct { + name string + input ProjectInput + shouldError bool + }{ + { + name: "Valid input", + input: ProjectInput{ + Name: "Test Project", + TokenPath: "gno.land/r/gnoswap/v1/gns", + Recipient: createTestAddress(t), + DepositAmount: 1000, + ConditionsToken: "gno.land/r/demo/wugnot", + ConditionsAmount: "10", + Tier30Ratio: 30, + Tier90Ratio: 30, + Tier180Ratio: 40, + StartTime: now + 3600, // 1 hour in future + }, + shouldError: false, + }, + { + name: "Invalid tier ratios", + input: ProjectInput{ + Name: "Test Project", + TokenPath: "test/token", + Recipient: createTestAddress(t), + DepositAmount: 1000, + Tier30Ratio: 30, + Tier90Ratio: 30, + Tier180Ratio: 30, // Sum != 100 + StartTime: now + 3600, + }, + shouldError: true, + }, + { + name: "Start time in past", + input: ProjectInput{ + Name: "Test Project", + TokenPath: "test/token", + Recipient: createTestAddress(t), + DepositAmount: 1000, + Tier30Ratio: 30, + Tier90Ratio: 30, + Tier180Ratio: 40, + StartTime: now - 3600, // 1 hour in past + }, + shouldError: true, + }, + { + name: "Zero deposit amount", + input: ProjectInput{ + Name: "Test Project", + TokenPath: "test/token", + Recipient: createTestAddress(t), + DepositAmount: 0, + Tier30Ratio: 30, + Tier90Ratio: 30, + Tier180Ratio: 40, + StartTime: now + 3600, + }, + shouldError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.input.ValidateAll(now) + if tc.shouldError && err == nil { + t.Errorf("Expected error but got none") + } + if !tc.shouldError && err != nil { + t.Errorf("Unexpected error: %v", err.Error()) + } + }) + } +} + +func TestProject_GetSet(t *testing.T) { + conditions := make(map[string]Condition) + tiers := make(map[uint64]Tier) + tiersRatios := make(map[uint64]uint64) + + p := NewProject( + "pid", + "MyProject", + "gno.land/r/gnoswap/v1/gns", + 1000, + "g1qxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + conditions, + tiers, + tiersRatios, + TimeInfo{height: 1, time: 1001}, + TimeInfo{height: 2, time: 1002}, + TimeInfo{height: 3, time: 1003}, + *NewProjectStats(10, 20, 30, 40, 50), + RefundInfo{amount: 99, height: 10, time: 999}, + ) + + if p.ID() != "pid" { + t.Errorf("NewProject id mismatch, got=%s", p.ID()) + } + if p.Name() != "MyProject" { + t.Errorf("NewProject name mismatch, got=%s", p.Name()) + } + if p.TokenPath() != "gno.land/r/gnoswap/v1/gns" { + t.Errorf("NewProject tokenPath mismatch, got=%s", p.TokenPath()) + } + if p.DepositAmount() != 1000 { + t.Errorf("NewProject depositAmount mismatch, got=%d", p.DepositAmount()) + } + if p.Recipient().String() != "g1qxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" { + t.Errorf("NewProject recipient mismatch, got=%s", p.Recipient()) + } + created := p.Created() + if created.Height() != 1 { + t.Errorf("NewProject Created Height mismatch, got=%d", created.Height()) + } + ended := p.Ended() + if ended.Time() != 1003 { + t.Errorf("NewProject Ended Time mismatch, got=%d", ended.Time()) + } + stats := p.Stats() + if stats.TotalCollected() != 50 { + t.Errorf("NewProject Stats mismatch, got=%d", stats.TotalCollected()) + } + refund := p.Refund() + if refund.Amount() != 99 { + t.Errorf("NewProject Refund mismatch, got=%d", refund.Amount()) + } + + p.SetID("pid2") + if p.ID() != "pid2" { + t.Errorf("SetID failed, got=%s", p.ID()) + } + p.SetName("MyProject2") + if p.Name() != "MyProject2" { + t.Errorf("SetName failed, got=%s", p.Name()) + } + p.SetTokenPath("new/tokenPath") + if p.TokenPath() != "new/tokenPath" { + t.Errorf("SetTokenPath failed, got=%s", p.TokenPath()) + } + p.SetDepositAmount(2000) + if p.DepositAmount() != 2000 { + t.Errorf("SetDepositAmount failed, got=%d", p.DepositAmount()) + } + t.Run("TestProject_isActivated", func(t *testing.T) { + now := uint64(1) + if p.isActivated(now) { + t.Errorf("isActivated(%d) returned true, want false (before started)", now) + } + + now = uint64(2) + if !p.isActivated(now) { + t.Errorf("isActivated(%d) returned false, want true (exactly started)", now) + } + + now = uint64(2) + if !p.isActivated(now) { + t.Errorf("isActivated(%d) returned false, want true (between start and end)", now) + } + + now = uint64(3) + if p.isActivated(now) { + t.Errorf("isActivated(%d) returned true, want false (exactly ended)", now) + } + + now = uint64(4) + if p.isActivated(now) { + t.Errorf("isActivated(%d) returned true, want false (after ended)", now) + } + }) +} + +func TestTier_GetSet(t *testing.T) { + emptyUint := u256.Zero() + + tr := NewTier( + "tierId", 10, 100, + emptyUint.Clone(), + TimeInfo{height: 0, time: 0}, + TimeInfo{height: 100, time: 9999}, + 1000, 800, 50, 45, 12, 200, + *NewReward(u256.NewUint(1234), 10, 9999), + ) + + if tr.ID() != "tierId" { + t.Errorf("Tier ID mismatch, got=%s", tr.ID()) + } + if tr.CollectWaitDuration() != 10 { + t.Errorf("Tier CollectWaitDuration mismatch, got=%d", tr.CollectWaitDuration()) + } + if tr.TierAmount() != 100 { + t.Errorf("Tier TierAmount mismatch, got=%d", tr.TierAmount()) + } + if tr.TierAmountPerBlockX128().Cmp(u256.Zero()) != 0 { + t.Errorf("Tier TierAmountPerBlockX128 mismatch, expect=0") + } + if tr.TotalDepositAmount() != 1000 { + t.Errorf("Tier TotalDepositAmount mismatch, got=%d", tr.TotalDepositAmount()) + } + if tr.ActualDepositAmount() != 800 { + t.Errorf("Tier ActualDepositAmount mismatch, got=%d", tr.ActualDepositAmount()) + } + if tr.TotalParticipant() != 50 { + t.Errorf("Tier TotalParticipant mismatch, got=%d", tr.TotalParticipant()) + } + if tr.ActualParticipant() != 45 { + t.Errorf("Tier ActualParticipant mismatch, got=%d", tr.ActualParticipant()) + } + if tr.UserCollectedAmount() != 12 { + t.Errorf("Tier UserCollectedAmount mismatch, got=%d", tr.UserCollectedAmount()) + } + if tr.CalculatedAmount() != 200 { + t.Errorf("Tier CalculatedAmount mismatch, got=%d", tr.CalculatedAmount()) + } + + tr.SetID("tierIdNew") + if tr.ID() != "tierIdNew" { + t.Errorf("SetID failed, got=%s", tr.ID()) + } + tr.SetTierAmount(999) + if tr.TierAmount() != 999 { + t.Errorf("SetTierAmount failed, got=%d", tr.TierAmount()) + } + tr.SetActualDepositAmount(777) + if tr.ActualDepositAmount() != 777 { + t.Errorf("SetActualDepositAmount failed, got=%d", tr.ActualDepositAmount()) + } +} + +func TestRewardInfo_GetSet(t *testing.T) { + ri := NewRewardInfo( + u256.NewUint(111), + 222, 333, 444, 555, 666, 777, + ) + if ri.PriceDebt().ToString() != "111" { + t.Errorf("PriceDebt mismatch, got=%s", ri.PriceDebt().ToString()) + } + if ri.DepositAmount() != 222 { + t.Errorf("DepositAmount mismatch, got=%d", ri.DepositAmount()) + } + if ri.RewardAmount() != 333 { + t.Errorf("RewardAmount mismatch, got=%d", ri.RewardAmount()) + } + if ri.Claimed() != 444 { + t.Errorf("Claimed mismatch, got=%d", ri.Claimed()) + } + if ri.StartHeight() != 555 { + t.Errorf("StartHeight mismatch, got=%d", ri.StartHeight()) + } + if ri.GetEndHeight() != 666 { + t.Errorf("EndHeight mismatch, got=%d", ri.GetEndHeight()) + } + if ri.GetLastHeight() != 777 { + t.Errorf("LastHeight mismatch, got=%d", ri.GetLastHeight()) + } + ri.SetPriceDebt(u256.NewUint(999)) + if ri.PriceDebt().ToString() != "999" { + t.Errorf("SetPriceDebt failed, got=%s", ri.PriceDebt().ToString()) + } + ri.SetRewardAmount(1234) + if ri.RewardAmount() != 1234 { + t.Errorf("SetRewardAmount failed, got=%d", ri.RewardAmount()) + } +} + +func TestTimeInfo_GetSet(t *testing.T) { + ti := NewTimeInfo(50, 5000) + + if ti.Height() != 50 { + t.Errorf("CurrentHeight mismatch, got=%d", ti.Height()) + } + if ti.Time() != 5000 { + t.Errorf("Time mismatch, got=%d", ti.Time()) + } + + ti.SetHeight(100) + ti.SetTime(9999) + if ti.Height() != 100 { + t.Errorf("SetHeight failed, got=%d", ti.Height()) + } + if ti.Time() != 9999 { + t.Errorf("SetTime failed, got=%d", ti.Time()) + } +} + +func TestRefundInfo_GetSet(t *testing.T) { + ri := NewRefundInfo(11, 22, 33) + if ri.Amount() != 11 { + t.Errorf("Amount mismatch, got=%d", ri.Amount()) + } + if ri.Height() != 22 { + t.Errorf("CurrentHeight mismatch, got=%d", ri.Height()) + } + if ri.Time() != 33 { + t.Errorf("Time mismatch, got=%d", ri.Time()) + } + + ri.SetAmount(111) + ri.SetHeight(222) + ri.SetTime(333) + if ri.Amount() != 111 { + t.Errorf("SetAmount failed, got=%d", ri.Amount()) + } + if ri.Height() != 222 { + t.Errorf("SetHeight failed, got=%d", ri.Height()) + } + if ri.Time() != 333 { + t.Errorf("SetTime failed, got=%d", ri.Time()) + } +} + +func TestCondition_GetSet(t *testing.T) { + c := NewCondition("my/token", 1000) + if c.TokenPath() != "my/token" { + t.Errorf("Condition TokenPath mismatch, got=%s", c.TokenPath()) + } + if c.MinAmount() != 1000 { + t.Errorf("Condition MinAmount mismatch, got=%d", c.MinAmount()) + } + + c.SetTokenPath("new/token") + if c.TokenPath() != "new/token" { + t.Errorf("SetTokenPath failed, got=%s", c.TokenPath()) + } + c.SetMinAmount(2000) + if c.MinAmount() != 2000 { + t.Errorf("SetMinAmount failed, got=%d", c.MinAmount()) + } +} + +func TestReward_GetSet(t *testing.T) { + acc := u256.NewUint(1000) + r := NewReward(acc, 10, 1000) + + if r.AccumRewardPerDeposit().Cmp(u256.NewUint(1000)) != 0 { + t.Errorf("AccumRewardPerDeposit mismatch, got=%s", r.AccumRewardPerDeposit().ToString()) + } + if r.GetLastHeight() != 10 { + t.Errorf("LastHeight mismatch, got=%d", r.GetLastHeight()) + } + if r.GetEndHeight() != 1000 { + t.Errorf("EndHeight mismatch, got=%d", r.GetEndHeight()) + } + if r.Info() == nil { + t.Errorf("Info mismatch, got=%v, want=not nil", r.Info()) + } + r.SetAccumRewardPerDeposit(u256.NewUint(9999)) + if r.AccumRewardPerDeposit().ToString() != "9999" { + t.Errorf("SetAccumRewardPerDeposit failed, got=%s", r.AccumRewardPerDeposit().ToString()) + } + r.SetLastHeight(111) + if r.GetLastHeight() != 111 { + t.Errorf("SetLastHeight failed, got=%d", r.GetLastHeight()) + } +} + +func TestCreateProject(t *testing.T) { + projectAddr := testutils.TestAddress("projectAddr") + fooBarWithpad := "gno.land/r/onbloc/foo*PAD*gno.land/r/onbloc/bar" + tests := []struct { + name string + projectName string + tokenPath string + projectOwner std.Address + totalAmount uint64 + conditions string + ratios string + tier30Ratio uint64 + tier90Ratio uint64 + tier180Ratio uint64 + startTime uint64 + expectedId string + expectedError string + duplicateTest bool + }{ + { + name: "normal project creation", + projectName: "Obl Protocol", + tokenPath: oblPath, + projectOwner: projectAddr, + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(70), + startTime: uint64(time.Now().Unix() + 10), + expectedId: "gno.land/r/onbloc/obl:124", + expectedError: "", + }, + { + name: "invalid ratio sum (less than 100%)", + projectName: "Invalid Ratio", + tokenPath: oblPath, + projectOwner: projectAddr, + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(60), + startTime: uint64(time.Now().Unix() + 10), + expectedError: "[GNOSWAP-LAUNCHPAD-007] invalid input data || invalid ratio, sum of all tiers(30:10, 90:20, 180:60) should be 100", + }, + { + name: "project name is empty", + projectName: "", + tokenPath: oblPath, + projectOwner: projectAddr, + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(70), + startTime: uint64(time.Now().Unix() + 10), + expectedError: "[GNOSWAP-LAUNCHPAD-007] invalid input data || project name cannot be empty", + }, + { + name: "project owner address is invalid", + projectName: "TTT", + tokenPath: oblPath, + projectOwner: "abcdef", + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(70), + startTime: uint64(time.Now().Unix() + 10), + expectedError: "[GNOSWAP-LAUNCHPAD-002] invalid address || recipient address(abcdef)", + }, + { + name: "invalid start time", + projectName: "Obl Protocol", + tokenPath: oblPath, + projectOwner: projectAddr, + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(70), + startTime: uint64(time.Now().Unix()), + expectedId: "gno.land/r/onbloc/obl:124", + expectedError: "[GNOSWAP-LAUNCHPAD-007] invalid input data || start time(1234567890) must be greater than now(1234567902)", + }, + { + name: "invalid token path", + projectName: "Obl Protocol", + tokenPath: "gno.land/r/test/token", + projectOwner: projectAddr, + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(70), + startTime: uint64(time.Now().Unix() + 50), + expectedId: "gno.land/r/onbloc/obl:124", + expectedError: "[GNOSWAP-LAUNCHPAD-007] invalid input data || tokenPath(gno.land/r/test/token) not registered", + }, + { + name: "duplicate project test - origin", + projectName: "Obl Protocol", + tokenPath: oblPath, + projectOwner: projectAddr, + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(70), + startTime: uint64(time.Now().Unix() + 50), + expectedId: "gno.land/r/onbloc/obl:130", + expectedError: "", + duplicateTest: true, + }, + { + name: "duplicate project test - duplicate", + projectName: "Obl Protocol", + tokenPath: oblPath, + projectOwner: projectAddr, + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(70), + startTime: uint64(time.Now().Unix() + 50), + expectedId: "gno.land/r/onbloc/obl:130", + expectedError: "[GNOSWAP-LAUNCHPAD-008] can not create same project in same block || project(gno.land/r/onbloc/obl:130) already exists", + duplicateTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetRealm(adminRealm) + obl.Approve(consts.LAUNCHPAD_ADDR, tt.totalAmount) + if !tt.duplicateTest { + std.TestSkipHeights(1) + } + + defer func() { + if r := recover(); r != nil { + if tt.expectedError == "" { + t.Errorf("unexpected error: %v", r) + } else { + uassert.Equal(t, r.(string), tt.expectedError) + } + } + }() + + projectId := CreateProject( + tt.projectName, + tt.tokenPath, + tt.projectOwner, + tt.totalAmount, + tt.conditions, + tt.ratios, + tt.tier30Ratio, + tt.tier90Ratio, + tt.tier180Ratio, + tt.startTime, + ) + + if tt.expectedError == "" { + uassert.Equal(t, projectId, tt.expectedId) + } + if !tt.duplicateTest { + std.TestSkipHeights(1) + } + }) + } +} + +func TestCreateProject_TierAndBalance(t *testing.T) { + projectAddr := testutils.TestAddress("projectAddr") + fooBarWithpad := "gno.land/r/onbloc/foo*PAD*gno.land/r/onbloc/bar" + tests := []struct { + name string + projectName string + tokenPath string + projectOwner std.Address + totalAmount uint64 + conditions string + ratios string + tier30Ratio uint64 + tier90Ratio uint64 + tier180Ratio uint64 + startTime uint64 + expectedId string + expectedError string + }{ + { + name: "normal project creation", + projectName: "Obl Protocol", + tokenPath: oblPath, + projectOwner: projectAddr, + totalAmount: uint64(1_000_000_000), + conditions: fooBarWithpad, + ratios: "1*PAD*2", + tier30Ratio: uint64(10), + tier90Ratio: uint64(20), + tier180Ratio: uint64(70), + startTime: uint64(time.Now().Unix() + 10), + expectedId: "gno.land/r/onbloc/obl:124", + expectedError: "[GNOSWAP-LAUNCHPAD-008] can not create same project in same block || project(gno.land/r/onbloc/obl:124) already exists", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetRealm(adminRealm) + obl.Approve(consts.LAUNCHPAD_ADDR, tt.totalAmount) + std.TestSkipHeights(1) + + defer func() { + if r := recover(); r != nil { + if tt.expectedError == "" { + t.Errorf("unexpected error: %v", r) + } else { + uassert.Equal(t, r.(string), tt.expectedError) + } + } + }() + + projectId := CreateProject( + tt.projectName, + tt.tokenPath, + tt.projectOwner, + tt.totalAmount, + tt.conditions, + tt.ratios, + tt.tier30Ratio, + tt.tier90Ratio, + tt.tier180Ratio, + tt.startTime, + ) + + if tt.expectedError == "" { + uassert.Equal(t, projectId, tt.expectedId) + } + std.TestSkipHeights(1) + }) + } +} + +func TestMakeConditions(t *testing.T) { + tests := []struct { + name string + conditionsToken string + conditionsAmount string + wantLen int + wantPanic bool + wantAmounts map[string]uint64 + }{ + { + name: "normal multiple input", + conditionsToken: "gno.land/r/gnoswap/v1/gns*PAD*gno.land/r/onbloc/foo*PAD*gno.land/r/onbloc/bar", + conditionsAmount: "100*PAD*200*PAD*300", + wantLen: 3, + wantPanic: false, + wantAmounts: map[string]uint64{ + "gno.land/r/gnoswap/v1/gns": 100, + "gno.land/r/onbloc/foo": 200, + "gno.land/r/onbloc/bar": 300, + }, + }, + { + name: "empty input", + conditionsToken: "", + conditionsAmount: "", + wantLen: 0, + wantPanic: true, + wantAmounts: nil, + }, + { + name: "single condition", + conditionsToken: "gno.land/r/gnoswap/v1/gns", + conditionsAmount: "100", + wantLen: 1, + wantPanic: false, + wantAmounts: map[string]uint64{ + "gno.land/r/gnoswap/v1/gns": 100, + }, + }, + { + name: "invalid amount format", + conditionsToken: "token1", + conditionsAmount: "invalid", + wantLen: 0, + wantPanic: true, + wantAmounts: nil, + }, + { + name: "token and amount count mismatch", + conditionsToken: "token1*PAD*token2", + conditionsAmount: "100", + wantLen: 0, + wantPanic: true, + wantAmounts: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := ProjectInput{ + Name: tt.name, + TokenPath: "gno.land/r/gnoswap/v1/gns", + Recipient: std.Address(consts.ADMIN), + DepositAmount: 1000000, + ConditionsToken: tt.conditionsToken, + ConditionsAmount: tt.conditionsAmount, + Tier30Ratio: 20, + Tier90Ratio: 30, + Tier180Ratio: 50, + StartTime: uint64(time.Now().Unix() + 100), + } + got, err := input.ParseConditions() + if !tt.wantPanic { + if tt.wantAmounts == nil { + if got != nil { + t.Errorf("makeConditions() = %v, want nil", got) + } + return + } + + count := 0 + for range got { + count++ + } + uassert.Equal(t, count, tt.wantLen) + + for token, wantAmount := range tt.wantAmounts { + condition, exists := got[token] + uassert.True(t, exists) + uassert.Equal(t, condition.minAmount, wantAmount) + } + } else { + uassert.Error(t, err) + } + }) + } +} + +func TestCalculateLeftReward(t *testing.T) { + tests := []struct { + name string + project Project + want uint64 + description string + }{ + { + name: "all tiers have remaining rewards", + project: Project{ + tiers: map[uint64]Tier{ + 30: { + tierAmount: 1000, + calculatedAmount: 400, + }, + 90: { + tierAmount: 2000, + calculatedAmount: 800, + }, + 180: { + tierAmount: 3000, + calculatedAmount: 1200, + }, + }, + }, + want: 3600, // (1000-400) + (2000-800) + (3000-1200) + description: "Should sum up all remaining rewards from each tier", + }, + { + name: "some tiers fully claimed", + project: Project{ + tiers: map[uint64]Tier{ + 30: { + tierAmount: 1000, + calculatedAmount: 1000, + }, + 90: { + tierAmount: 2000, + calculatedAmount: 800, + }, + 180: { + tierAmount: 3000, + calculatedAmount: 3000, + }, + }, + }, + want: 1200, // (1000-1000) + (2000-800) + (3000-3000) + description: "Should handle fully claimed tiers correctly", + }, + { + name: "all tiers fully claimed", + project: Project{ + tiers: map[uint64]Tier{ + 30: { + tierAmount: 1000, + calculatedAmount: 1000, + }, + 90: { + tierAmount: 2000, + calculatedAmount: 2000, + }, + 180: { + tierAmount: 3000, + calculatedAmount: 3000, + }, + }, + }, + want: 0, + description: "Should return 0 when all rewards are claimed", + }, + { + name: "no rewards calculated yet", + project: Project{ + tiers: map[uint64]Tier{ + 30: { + tierAmount: 1000, + calculatedAmount: 0, + }, + 90: { + tierAmount: 2000, + calculatedAmount: 0, + }, + 180: { + tierAmount: 3000, + calculatedAmount: 0, + }, + }, + }, + want: 6000, // 1000 + 2000 + 3000 + description: "Should return total tier amounts when no rewards calculated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := calculateLeftReward(tt.project) + uassert.Equal(t, got, tt.want) + }) + } +} + +func TestCreateTier(t *testing.T) { + // Constants for test + const ( + AVG_BLOCK_TIME_MS = 2000 + TIMESTAMP_DAY = 86400 + ) + + tests := []struct { + name string + projectId string + duration uint64 + amount uint64 + startHeight uint64 + startTime uint64 + collectWaitDuration uint64 + want Tier + description string + }{ + { + name: "30-day tier creation", + projectId: "project1", + duration: 30, + amount: 1000, + startHeight: 1000, + startTime: 1600000000, + collectWaitDuration: 100, + want: Tier{ + id: "project1:30", + collectWaitDuration: 100, + tierAmount: 1000, + ended: TimeInfo{ + height: 1000 + (30 * TIMESTAMP_DAY * 1000 / AVG_BLOCK_TIME_MS), + time: 1600000000 + (30 * TIMESTAMP_DAY), + }, + }, + description: "Should create a 30-day tier with correct end time and height", + }, + { + name: "90-day tier creation", + projectId: "project2", + duration: 90, + amount: 2000, + startHeight: 2000, + startTime: 1700000000, + collectWaitDuration: 200, + want: Tier{ + id: "project2:90", + collectWaitDuration: 200, + tierAmount: 2000, + ended: TimeInfo{ + height: 2000 + (90 * TIMESTAMP_DAY * 1000 / AVG_BLOCK_TIME_MS), + time: 1700000000 + (90 * TIMESTAMP_DAY), + }, + }, + description: "Should create a 90-day tier with correct end time and height", + }, + { + name: "180-day tier creation", + projectId: "project3", + duration: 180, + amount: 3000, + startHeight: 3000, + startTime: 1800000000, + collectWaitDuration: 300, + want: Tier{ + id: "project3:180", + collectWaitDuration: 300, + tierAmount: 3000, + ended: TimeInfo{ + height: 3000 + (180 * TIMESTAMP_DAY * 1000 / AVG_BLOCK_TIME_MS), + time: 1800000000 + (180 * TIMESTAMP_DAY), + }, + }, + description: "Should create a 180-day tier with correct end time and height", + }, + { + name: "zero duration tier", + projectId: "project4", + duration: 0, + amount: 1000, + startHeight: 4000, + startTime: 1900000000, + collectWaitDuration: 400, + want: Tier{ + id: "project4:0", + collectWaitDuration: 400, + tierAmount: 1000, + ended: TimeInfo{ + height: 4000, + time: 1900000000, + }, + }, + description: "Should handle zero duration correctly", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + projectTiersWithoutDeposit = make(map[string]bool) + + got := createTier( + tt.projectId, + tt.duration, + tt.amount, + tt.startHeight, + tt.startTime, + tt.collectWaitDuration, + ) + + uassert.Equal(t, got.id, tt.want.id) + uassert.Equal(t, got.collectWaitDuration, tt.want.collectWaitDuration) + uassert.Equal(t, got.tierAmount, tt.want.tierAmount) + uassert.Equal(t, got.ended.height, tt.want.ended.height) + uassert.Equal(t, got.ended.time, tt.want.ended.time) + uassert.True(t, projectTiersWithoutDeposit[got.id]) + }) + } +} + +func TestCalculateProjectParams(t *testing.T) { + currentTime := uint64(time.Now().Unix()) + baseHeight := uint64(std.GetHeight()) + + tests := []struct { + name string + input ProjectInput + now uint64 + want *ProjectCalculationResult + wantErr bool + }{ + { + name: "valid input with equal ratio distribution", + input: ProjectInput{ + DepositAmount: 3000, + Tier30Ratio: 33, + Tier90Ratio: 33, + Tier180Ratio: 34, + StartTime: currentTime + 1000, // 1000 seconds in future + }, + now: currentTime, + want: &ProjectCalculationResult{ + Tier30Amount: 990, // 3000 * 33 / 100 + Tier90Amount: 990, // 3000 * 33 / 100 + Tier180Amount: 1020, // 3000 * 34 / 100 + StartHeight: baseHeight + 500, // 1000 seconds = 500 blocks + }, + wantErr: false, + }, + { + name: "valid input with uneven ratio distribution", + input: ProjectInput{ + DepositAmount: 10000, + Tier30Ratio: 20, + Tier90Ratio: 30, + Tier180Ratio: 50, + StartTime: currentTime + 2000, // 2000 seconds in future + }, + now: currentTime, + want: &ProjectCalculationResult{ + Tier30Amount: 2000, // 10000 * 20 / 100 + Tier90Amount: 3000, // 10000 * 30 / 100 + Tier180Amount: 5000, // 10000 * 50 / 100 + StartHeight: baseHeight + 1000, // 2000 seconds = 1000 blocks + }, + wantErr: false, + }, + { + name: "start time in the past", + input: ProjectInput{ + DepositAmount: 1000, + Tier30Ratio: 30, + Tier90Ratio: 30, + Tier180Ratio: 40, + StartTime: currentTime - 1000, + }, + now: currentTime, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := calculateProjectParams(tt.input, tt.now) + + if tt.wantErr { + uassert.Error(t, err) + return + } + + uassert.NoError(t, err) + uassert.Equal(t, got.Tier30Amount, tt.want.Tier30Amount) + uassert.Equal(t, got.Tier90Amount, tt.want.Tier90Amount) + uassert.Equal(t, got.Tier180Amount, tt.want.Tier180Amount) + uassert.Equal(t, got.StartHeight, tt.want.StartHeight) + }) + } +} + +func TestGenerateProjectId(t *testing.T) { + setupTest(t) + + currentHeight := std.GetHeight() + + tests := []struct { + name string + tokenPath string + mockHeight int64 + want string + }{ + { + name: "normal token path", + tokenPath: "gno.land/r/gnoswap/gns", + mockHeight: currentHeight + 10, + want: "", + }, + { + name: "another token path", + tokenPath: "gno.land/r/onbloc/obl", + mockHeight: currentHeight + 5, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + heightBefore := std.GetHeight() + skip := tt.mockHeight - heightBefore + if skip > 0 { + std.TestSkipHeights(skip) + } + + got := generateProjectId(tt.tokenPath) + + want := ufmt.Sprintf("%s:%d", tt.tokenPath, std.GetHeight()) + + uassert.Equal(t, got, want) + }) + } +} + +func TestGenerateTierId(t *testing.T) { + setupTest(t) + + tests := []struct { + name string + projectId string + duration uint64 + want string + }{ + { + name: "30-day tier", + projectId: "gno.land/r/gnoswap/gns:100", + duration: 30, + want: "gno.land/r/gnoswap/gns:100:30", + }, + { + name: "90-day tier", + projectId: "gno.land/r/onbloc/obl:999", + duration: 90, + want: "gno.land/r/onbloc/obl:999:90", + }, + { + name: "180-day tier", + projectId: "myProject:1234", + duration: 180, + want: "myProject:1234:180", + }, + { + name: "zero-day tier", + projectId: "testProj:555", + duration: 0, + want: "testProj:555:0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := generateTierId(tt.projectId, tt.duration) + uassert.Equal(t, got, tt.want) + }) + } +} + +func TestCovertTimeToHeight(t *testing.T) { + setupTest(t) + + tests := []struct { + name string + timestamp uint64 + want uint64 + }{ + { + name: "3600 seconds => 1800 blocks", + timestamp: 3600, // 1hour + want: 1800, // 3600 * 1000 / 2000 + }, + { + name: "0 seconds => 0 blocks", + timestamp: 0, + want: 0, + }, + { + name: "small seconds => 1 block", + timestamp: 2, // 2 seconds + want: 1, // 2 * 1000 / 2000 = 1 + }, + { + name: "larger value", + timestamp: 86400, // 1day (seconds) + want: 43200, // 86400 * 1000 / 2000 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := convertTimeToHeight(tt.timestamp) + uassert.Equal(t, got, tt.want) + }) + } + + // division by zero test + t.Run("blocktime is zero => panic", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but got none") + } else { + uassert.Equal(t, r.(string), "division by zero") + } + }() + + std.TestSetRealm(adminRealm) + gns.SetAvgBlockTimeInMsByAdmin(0) + _ = convertTimeToHeight(1000) + }) +} + +func TestTransferLeftFromProjectByAdmin(t *testing.T) { + setupTest(t) + + validRecipient := testutils.TestAddress("recipient") + invalidRecipient := std.Address("invalidRecipient") + + projectId := "dummyProject:100" + dummyProject := Project{ + id: projectId, + name: "Dummy Project", + tokenPath: oblPath, + tiers: map[uint64]Tier{ + 30: {tierAmount: 1000, userCollectedAmount: 400}, + 90: {tierAmount: 2000, userCollectedAmount: 2000}, + 180: {tierAmount: 3000, userCollectedAmount: 2500}, + }, + ended: TimeInfo{ + height: 120, + time: uint64(time.Now().Unix() - 1), + }, + } + + tests := []struct { + name string + setupFunc func() + projectId string + recipient std.Address + mockHeight uint64 + expectedPanic string + }{ + { + name: "project not found", + setupFunc: func() { + projects = make(map[string]Project) + }, + projectId: "nonExistentProject", + recipient: validRecipient, + mockHeight: 200, + expectedPanic: "[GNOSWAP-LAUNCHPAD-003] requested data not found || projectId(nonExistentProject) not found", + }, + { + name: "invalid recipient address", + setupFunc: func() { + projects = make(map[string]Project) + projects[projectId] = dummyProject + }, + projectId: projectId, + recipient: invalidRecipient, + mockHeight: 200, + expectedPanic: "[GNOSWAP-LAUNCHPAD-013] invalid transfer condition || invalid recipient address(invalidRecipient)", + }, + { + name: "project not ended yet", + setupFunc: func() { + projects = make(map[string]Project) + p := dummyProject + p.ended.height = 300 + projects[projectId] = p + }, + projectId: projectId, + recipient: validRecipient, + mockHeight: 250, + expectedPanic: "[GNOSWAP-LAUNCHPAD-013] invalid transfer condition || project not ended yet(current:250, endHeight: 300)", + }, + { + name: "project already refunded", + setupFunc: func() { + projects = make(map[string]Project) + p := dummyProject + p.refund = RefundInfo{ + amount: 10, + height: 110, + time: uint64(time.Now().Unix()), + } + projects[projectId] = p + }, + projectId: projectId, + recipient: validRecipient, + mockHeight: 200, + expectedPanic: "[GNOSWAP-LAUNCHPAD-013] invalid transfer condition || project already refunded(height:110)", + }, + { + name: "success leftover transfer", + setupFunc: func() { + projects = make(map[string]Project) + p := dummyProject + p.ended.height = 120 + p.depositAmount = 6000 + p.stats.totalCollected = 4900 + p.refund = RefundInfo{} + projects[projectId] = p + }, + projectId: projectId, + recipient: validRecipient, + mockHeight: 250, + expectedPanic: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetRealm(adminRealm) + // set up + tt.setupFunc() + + // control block height + currHeight := std.GetHeight() + skip := int64(tt.mockHeight) - int64(currHeight) + if skip > 0 { + std.TestSkipHeights(skip) + } + + defer func() { + if r := recover(); r != nil { + if tt.expectedPanic == "" { + t.Errorf("unexpected panic: %v", r) + } else { + uassert.Equal(t, r.(string), tt.expectedPanic) + } + } else if tt.expectedPanic != "" { + t.Errorf("expected panic(%s) but got none", tt.expectedPanic) + } + }() + + // run + got := TransferLeftFromProjectByAdmin(tt.projectId, tt.recipient) + + if tt.expectedPanic == "" { + project := projects[tt.projectId] + // leftover = (1000-400) + (2000-2000) + (3000-2500) = 1100 + uassert.Equal(t, got, uint64(1100)) + uassert.Equal(t, project.refund.amount, uint64(1100)) + uassert.Equal(t, project.refund.height, tt.mockHeight) + } + }) + } +} diff --git a/contract/r/gnoswap/launchpad/reward.gno b/contract/r/gnoswap/launchpad/reward.gno new file mode 100644 index 000000000..f4e29e503 --- /dev/null +++ b/contract/r/gnoswap/launchpad/reward.gno @@ -0,0 +1,282 @@ +package launchpad + +import ( + "errors" + "std" + "time" + + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" + en "gno.land/r/gnoswap/v1/emission" + gs "gno.land/r/gnoswap/v1/gov/staker" +) + +// RewardCalculationResult represents the result of reward calculation +type RewardCalculationResult struct { + Reward uint64 + UpdatedTiers map[string]Tier +} + +// CollectProtocolFee collects protocol fee from gov/staker +// each project's recipient wallet will be rewarded +// ref: https://docs.gnoswap.io/contracts/launchpad/launchpad_reward.gno#collectprotocolfee +func CollectProtocolFee() { + caller := getPrevAddr() + gs.CollectRewardFromLaunchPad(caller) +} + +// CollectRewardByProjectId collects reward from entire deposit of certain project by caller +// Returns collected reward amount +// ref: https://docs.gnoswap.io/contracts/launchpad/launchpad_reward.gno#collectrewardbyprojectid +func CollectRewardByProjectId(projectId string) uint64 { + assertOnlyNotHalted() + + project, err := getProject(projectId) + if err != nil { + panic(err.Error()) + } + caller := getPrevAddr() + if _, exists := depositsByUserByProject[caller]; !exists { + panic(addDetailToError(errInvalidCaller, ufmt.Sprintf("caller(%s) not found", caller))) + } + depositIds, exists := depositsByUserByProject[caller][projectId] + if !exists { + panic(addDetailToError( + errNotExistDeposit, ufmt.Sprintf("depositId (%s) not found", projectId))) + } + en.MintAndDistributeGns() + + currentHeight := uint64(std.GetHeight()) + totalReward := uint64(0) // toUser + prevPkgPath := getPrevPkgPath() + + for _, depositId := range depositIds { + deposit, ok := deposits[depositId] + if !ok { + continue + } + + if !deposit.isClaimable(currentHeight) { + continue + } + if err = validateRewardCollection(deposit, currentHeight); err != nil { + panic(err.Error()) + } + + tier, err := project.Tier(convertTierTypeStrToUint64(deposit.Tier())) + if err != nil { + panic(err.Error()) + } + reward := tier.Reward() + rewardPerDeposit, err := reward.finalize(currentHeight, tier.ActualDepositAmount(), tier.TierAmountPerBlockX128()) + if err != nil { + panic(err.Error()) + } + reward.addRewardPerDeposit(rewardPerDeposit) + reward.SetLastHeight(currentHeight) + rewardAmount := reward.deductReward(depositId, currentHeight) + if rewardAmount == 0 { + continue + } + + tier.userCollectedAmount += rewardAmount + project.Stats().totalCollected += rewardAmount + + totalReward += rewardAmount + + std.Emit( + "CollectRewardByProjectId", + "prevAddr", caller.String(), + "prevRealm", prevPkgPath, + "projectId", projectId, + "depositId", depositId, + "amount", formatUint(rewardAmount), + ) + project.setTier(convertTierTypeStrToUint64(deposit.Tier()), tier) + + // Update deposit + deposit.SetRewardCollected(deposit.RewardCollected() + rewardAmount) + deposit.SetRewardCollectHeight(currentHeight) + deposit.SetRewardCollectTime(uint64(time.Now().Unix())) + deposits[depositId] = deposit + } + + projects[projectId] = project + + if totalReward > 0 { + tokenTeller := common.GetTokenTeller(project.tokenPath) + tokenTeller.Transfer(caller, totalReward) + } + + return totalReward +} + +// CollectRewardByDepositId collects reward from certain deposit by caller +// Returns collected reward amount +// ref: https://docs.gnoswap.io/contracts/launchpad/launchpad_reward.gno#collectrewardbydepositid +func CollectRewardByDepositId(depositId string) uint64 { + assertOnlyNotHalted() + + deposit, exists := deposits[depositId] + if !exists { + panic(ufmt.Sprintf("depositId(%s) not found", depositId)) + } + + currentHeight := uint64(std.GetHeight()) + if err := validateRewardCollection(deposit, currentHeight); err != nil { + panic(err.Error()) + } + + project, err := getProject(deposit.ProjectID()) + if err != nil { + panic(err.Error()) + } + + en.MintAndDistributeGns() + + tier, err := project.Tier(convertTierTypeStrToUint64(deposit.Tier())) + if err != nil { + panic(err.Error()) + } + + reward := tier.Reward() + if reward.GetLastHeight() == currentHeight { + panic(addDetailToError(errAlreadyCollected, ufmt.Sprintf("depositId(%s)", depositId))) + } + + rewardPerDeposit, err := reward.finalize(currentHeight, tier.ActualDepositAmount(), tier.TierAmountPerBlockX128()) + if err != nil { + panic(err.Error()) + } + reward.addRewardPerDeposit(rewardPerDeposit) + reward.SetLastHeight(currentHeight) + + rewardAmount := reward.deductReward(depositId, currentHeight) + if rewardAmount == 0 { + panic(addDetailToError(errAlreadyCollected, ufmt.Sprintf("depositId(%s)", depositId))) + } + + tier.SetReward(reward) + tier.userCollectedAmount += rewardAmount + project.Stats().totalCollected += rewardAmount + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "CollectRewardByDepositId", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "depositId", depositId, + "amount", formatUint(rewardAmount), + ) + + project.setTier(convertTierTypeStrToUint64(deposit.Tier()), tier) + + // Update deposit + deposit.SetRewardCollected(deposit.RewardCollected() + rewardAmount) + deposit.SetRewardCollectHeight(currentHeight) + deposit.SetRewardCollectTime(uint64(time.Now().Unix())) + deposits[depositId] = deposit + + projects[deposit.projectId] = project + + // Transfer reward + if rewardAmount > 0 { + tokenTeller := common.GetTokenTeller(project.tokenPath) + tokenTeller.Transfer(getPrevAddr(), rewardAmount) + } + + return rewardAmount +} + +// validateRewardCollection validates if reward can be collected +func validateRewardCollection(deposit Deposit, height uint64) error { + caller := getPrevAddr() + if err := validateRewardClaimable(deposit, height); err != nil { + return err + } + if err := validateRewardOwner(deposit.id, caller); err != nil { + return err + } + return nil +} + +// calculateTierRewards calculates rewards for each tier +func calculateTierRewards(tier Tier, currentHeight uint64, lastCalcHeight uint64) (*u256.Uint, uint64, error) { + if tier.tierAmountPerBlockX128 == nil { + return u256.Zero(), 0, nil + } + + if tier.tierAmountPerBlockX128.IsZero() { + return u256.Zero(), 0, nil + } + + endHeight := minU64(currentHeight, tier.ended.height) + sinceLast := endHeight - minU64(endHeight, lastCalcHeight) + + if sinceLast == 0 { + return u256.Zero(), 0, nil + } + + rewardX128 := new(u256.Uint).Mul(tier.tierAmountPerBlockX128.NilToZero().Clone(), u256.NewUint(sinceLast)) + reward := new(u256.Uint).Div(rewardX128, q128).Uint64() + + return rewardX128, reward, nil +} + +// calcDepositRatioX128 calculates the deposit ratio with Q128 precision +func calcDepositRatioX128(tierAmount uint64, amount uint64) *u256.Uint { + amountX128 := new(u256.Uint).Mul(u256.NewUint(amount), q128.Clone()) + amountX128x := new(u256.Uint).Mul(amountX128, u256.NewUint(1_000_000_000)) + + tierAmountX128 := new(u256.Uint).Mul(u256.NewUint(tierAmount), q128.Clone()) + + depositRatioX96 := new(u256.Uint).Div(amountX128x, tierAmountX128) + depositRatioX96 = new(u256.Uint).Mul(depositRatioX96, q128.Clone()) + depositRatioX96 = new(u256.Uint).Div(depositRatioX96, u256.NewUint(1_000_000_000)) + return depositRatioX96 +} + +// calcProjectTiersRewardPerBlockX96 calculates the reward per block with Q96 precision +func calcProjectTiersRewardPerBlockX128(tier Tier) *u256.Uint { + tierAmountX128 := new(u256.Uint).Mul(u256.NewUint(tier.tierAmount), q128) + return new(u256.Uint).Div(tierAmountX128, u256.NewUint(tier.ended.height-tier.started.height)) +} + +// calcRewardPerBlockX128 calculates the reward per block with Q128 precision +func calcRewardPerBlockX128(tierAmount uint64, duration uint64) *u256.Uint { + tierAmountX128 := new(u256.Uint).Mul(u256.NewUint(tierAmount), q128) + return new(u256.Uint).Div(tierAmountX128, u256.NewUint(duration)) +} + +// validateRewardOwner checks if the deposit is truly owned by `caller`. +func validateRewardOwner(depositId string, caller std.Address) error { + dep, ok := deposits[depositId] + if !ok { + return errors.New(addDetailToError(errNotExistDeposit, ufmt.Sprintf("depositId(%s)", depositId))) + } + projMap, userHasProject := depositsByUserByProject[caller] + if !userHasProject { + return errors.New(addDetailToError(errInvalidCaller, ufmt.Sprintf("caller(%s) not found", caller))) + } + depositIds, userHasDeposits := projMap[dep.projectId] + if !userHasDeposits { + return errors.New(addDetailToError(errNotExistDeposit, ufmt.Sprintf("caller(%s) no deposit for project(%s)", caller, dep.projectId))) + } + if !containsString(depositIds, depositId) { + return errors.New(addDetailToError(errInvalidOwner, ufmt.Sprintf("caller(%s) does not own depositId(%s)", caller, depositId))) + } + return nil +} + +// validateRewardClaimable checks if reward can be claimed +// e.g. deposit.claimableHeight, deposit.claimableTime +func validateRewardClaimable(deposit Deposit, currentHeight uint64) error { + if deposit.rewardCollectTime == 0 { + if currentHeight < deposit.claimableHeight { + return errors.New(addDetailToError(errNotYetClaimReward, ufmt.Sprintf("need block >= %d, current=%d", deposit.claimableHeight, currentHeight))) + } + } + return nil +} diff --git a/contract/r/gnoswap/launchpad/reward_calculation.gno b/contract/r/gnoswap/launchpad/reward_calculation.gno new file mode 100644 index 000000000..befe51214 --- /dev/null +++ b/contract/r/gnoswap/launchpad/reward_calculation.gno @@ -0,0 +1,248 @@ +// Copied from gov/staker, extract it out into acommon package? +package launchpad + +import ( + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" +) + +type StakerRewardInfo struct { + StartHeight uint64 // height when launchpad started staking + PriceDebt *u256.Uint // price debt per GNS stake, Q128 + Amount uint64 // amount of GNS staked + Claimed uint64 // amount of reward claimed so far +} + +func (self *StakerRewardInfo) Debug() string { + return ufmt.Sprintf("{ StartHeight: %d, PriceDebt: %d, Amount: %d, Claimed: %d }", self.StartHeight, self.PriceDebtUint64(), self.Amount, self.Claimed) +} + +func (self *StakerRewardInfo) PriceDebtUint64() uint64 { + return u256.Zero().Rsh(self.PriceDebt, 128).Uint64() +} + +type RewardState struct { + PriceAccumulation *u256.Uint // claimable Launchpad reward per xGNS stake, Q128 + TotalStake uint64 // total xGNS staked + + LastHeight uint64 // last height when reward was calculated + RewardPerBlock *u256.Uint // reward per block, = Tier.tierAmountPerBlockX128 + EndHeight uint64 + + TotalEmptyBlock uint64 + + info *avl.Tree // depositId -> StakerRewardInfo +} + +func NewRewardState(rewardPerBlock *u256.Uint, startHeight uint64, endHeight uint64) *RewardState { + return &RewardState{ + PriceAccumulation: u256.Zero(), + TotalStake: 0, + LastHeight: startHeight, + RewardPerBlock: rewardPerBlock, + EndHeight: endHeight, + info: avl.NewTree(), + } +} + +func (self *RewardState) PriceAccumulationUint64() uint64 { + return u256.Zero().Rsh(self.PriceAccumulation, 128).Uint64() +} + +func (self *RewardState) RewardPerBlockUint64() uint64 { + return u256.Zero().Rsh(self.RewardPerBlock, 128).Uint64() +} + +func (self *RewardState) Debug() string { + return ufmt.Sprintf("{ PriceAccumulation: %d, TotalStake: %d, LastHeight: %d, RewardPerBlock: %d, EndHeight: %d, info: len(%d) }", self.PriceAccumulationUint64(), self.TotalStake, self.LastHeight, self.RewardPerBlockUint64(), self.EndHeight, self.info.Size()) +} + +type RewardStates struct { + states *avl.Tree // projectId:tier string -> RewardState +} + +var rewardStates = RewardStates{ + states: avl.NewTree(), +} + +func (states RewardStates) Get(projectId string, tierStr string) (*RewardState, error) { + key := projectId + ":" + tierStr + statesI, exists := states.states.Get(key) + if !exists { + return nil, ufmt.Errorf("reward state not found for projectId %s and tierStr %s", projectId, tierStr) + } + return statesI.(*RewardState), nil +} + +func (states RewardStates) Set(projectId string, tierStr string, state *RewardState) { + key := projectId + ":" + tierStr + states.states.Set(key, state) +} + +func (states RewardStates) DeleteProject(projectId string) uint64 { + totalLeftover := uint64(0) + keys := []string{} + states.states.Iterate(projectId+":", projectId+";", func(key string, value interface{}) bool { + state := value.(*RewardState) + totalEmptyBlock := state.TotalEmptyBlock + if state.TotalStake == 0 { + totalEmptyBlock += state.EndHeight - state.LastHeight + } + totalEmptyBlockU256 := u256.NewUint(totalEmptyBlock) + totalEmptyRewards := u256.Zero().Mul(totalEmptyBlockU256, state.RewardPerBlock) + totalLeftover += totalEmptyRewards.Uint64() + keys = append(keys, key) + return false + }) + + for _, key := range keys { + states.states.Remove(key) + } + + return totalLeftover +} + +func (self *RewardState) Info(depositId string) StakerRewardInfo { + infoI, exists := self.info.Get(depositId) + if !exists { + panic(addDetailToError( + errNotExistDeposit, ufmt.Sprintf("(%s)", depositId))) + } + return infoI.(StakerRewardInfo) +} + +func (self *RewardState) CalculateReward(depositId string) uint64 { + info := self.Info(depositId) + stakerPrice := u256.Zero().Sub(self.PriceAccumulation, info.PriceDebt) + reward := stakerPrice.Mul(stakerPrice, u256.NewUint(info.Amount)) + reward = reward.Rsh(reward, 128) + return reward.Uint64() - info.Claimed +} + +// amount MUST be less than or equal to the amount of xGNS staked +// This function does not check it +func (self *RewardState) deductReward(depositId string, currentHeight uint64) uint64 { + if currentHeight < self.LastHeight { + panic(addDetailToError( + errInvalidRewardState, + ufmt.Sprintf("currentHeight %d is less than LastHeight %d", currentHeight, self.LastHeight))) + } + + deposit := deposits[depositId] + if deposit.claimableHeight > currentHeight { + panic(addDetailToError( + errInvalidRewardState, + ufmt.Sprintf("currentHeight %d is less than claimableHeight %d", currentHeight, deposit.claimableHeight))) + } + + info := self.Info(depositId) + if info.StartHeight > currentHeight { + panic(addDetailToError( + errInvalidRewardState, + ufmt.Sprintf("currentHeight %d is less than StartHeight %d", currentHeight, info.StartHeight))) + } + reward64 := self.CalculateReward(depositId) + + if reward64 == 0 { + return 0 + } + + info.Claimed += reward64 + self.info.Set(depositId, info) + + return reward64 +} + +// This function MUST be called as a part of AddStake or RemoveStake +// CurrentBalance / StakeChange / IsRemoveStake will be updated in those functions +func (self *RewardState) finalize(currentHeight uint64) { + if currentHeight <= self.LastHeight { + // Not started yet + return + } + if self.LastHeight >= self.EndHeight { + // already ended + return + } + if currentHeight > self.EndHeight { + currentHeight = self.EndHeight + } + + diff := currentHeight - self.LastHeight + delta := u256.NewUint(diff) + delta = delta.Mul(delta, self.RewardPerBlock) + + if self.TotalStake == uint64(0) { + self.TotalEmptyBlock += diff + self.LastHeight = currentHeight + return + } + + price := delta.Div(delta, u256.NewUint(self.TotalStake)) + self.PriceAccumulation = u256.Zero().Add(self.PriceAccumulation, price) + self.LastHeight = currentHeight +} + +func (self *RewardState) AddStake(currentHeight uint64, depositId string, amount uint64) { + if self.info.Has(depositId) { + panic(addDetailToError( + errAlreadyExistDeposit, + ufmt.Sprintf("depositId %s already exists", depositId))) + } + + self.finalize(currentHeight) + + self.TotalStake += amount + + if self.info.Has(depositId) { + info := self.Info(depositId) + info.PriceDebt.Add(info.PriceDebt, u256.NewUint(info.Amount)) + info.PriceDebt.Add(info.PriceDebt, u256.Zero().Mul(self.PriceAccumulation, u256.NewUint(amount))) + info.PriceDebt.Div(info.PriceDebt, u256.NewUint(self.TotalStake)) + info.Amount += amount + self.info.Set(depositId, info) + return + } + + info := StakerRewardInfo{ + StartHeight: currentHeight, + PriceDebt: self.PriceAccumulation.Clone(), + Amount: amount, + Claimed: 0, + } + + self.info.Set(depositId, info) +} + +func (self *RewardState) Claim(depositId string, currentHeight uint64) uint64 { + if !self.info.Has(depositId) { + panic(addDetailToError( + errNotExistDeposit, + ufmt.Sprintf("depositId %s", depositId))) + } + + self.finalize(currentHeight) + + reward := self.deductReward(depositId, currentHeight) + + return reward +} + +func (self *RewardState) RemoveStake(depositId string, amount uint64, currentHeight uint64) uint64 { + if !self.info.Has(depositId) { + panic(addDetailToError( + errNotExistDeposit, + ufmt.Sprintf("depositId %s", depositId))) + } + + self.finalize(currentHeight) + + reward := self.deductReward(depositId, currentHeight) + + self.info.Remove(depositId) + + self.TotalStake -= amount + + return reward +} diff --git a/contract/r/gnoswap/launchpad/reward_calculation_test.gno b/contract/r/gnoswap/launchpad/reward_calculation_test.gno new file mode 100644 index 000000000..be55c2a19 --- /dev/null +++ b/contract/r/gnoswap/launchpad/reward_calculation_test.gno @@ -0,0 +1,233 @@ +package launchpad + +import ( + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + + u256 "gno.land/p/gnoswap/uint256" +) + +func TestRewardState_BasicFlow(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) // 1000 * 2^128 + startHeight := uint64(100) + endHeight := uint64(1000) + state := NewRewardState(rewardPerBlock, startHeight, endHeight) + + currentHeight := uint64(150) + depositId := "deposit1" + stakeAmount := uint64(500) + state.AddStake(currentHeight, depositId, stakeAmount) + + uassert.Equal(t, state.TotalStake, stakeAmount, ufmt.Sprintf("TotalStake = %d, want %d", state.TotalStake, stakeAmount)) + + claimHeight := uint64(200) + reward := state.Claim(depositId, claimHeight) + uassert.NotEqual(t, reward, uint64(0), "Expected non-zero reward") + + removeHeight := uint64(250) + finalReward := state.RemoveStake(depositId, stakeAmount, removeHeight) + uassert.Equal(t, state.TotalStake, uint64(0), ufmt.Sprintf("TotalStake after removal = %d, want 0", state.TotalStake)) +} + +func TestRewardState_MultipleStakers(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + state.AddStake(150, "deposit1", 300) + state.AddStake(160, "deposit2", 200) + + uassert.Equal(t, state.TotalStake, uint64(500), ufmt.Sprintf("TotalStake = %d, want 500", state.TotalStake)) + + // check first staker's reward + reward1 := state.Claim("deposit1", 200) + uassert.NotEqual(t, reward1, uint64(0), "Expected non-zero reward for deposit1") + + // check 2nd staker's reward + reward2 := state.Claim("deposit2", 200) + uassert.NotEqual(t, reward2, uint64(0), "Expected non-zero reward for deposit2") + + // reward1 must be greater than reward2 + // because reward1 staked earlier than reward2 + if reward1 <= reward2 { + t.Error("Expected reward1 > reward2") + } +} + +func TestRewardState_EmptyBlocks(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + // call `finalize` while the block is empty + state.finalize(200) + uassert.Equal(t, state.TotalEmptyBlock, uint64(100), ufmt.Sprintf("TotalEmptyBlock = %d, want 100", state.TotalEmptyBlock)) + + // block cound must be stopped when staker added + state.AddStake(250, "deposit1", 100) + state.finalize(300) + + uassert.Equal(t, state.TotalEmptyBlock, uint64(150), ufmt.Sprintf("TotalEmptyBlock after staking = %d, want 150", state.TotalEmptyBlock)) +} + +func TestRewardState_EndHeightBehavior(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + startHeight := uint64(100) + endHeight := uint64(200) + state := NewRewardState(rewardPerBlock, startHeight, endHeight) + + // stake before endHeight + state.AddStake(150, "deposit1", 100) + + // claim after endHeight + reward := state.Claim("deposit1", 250) + + // must be no additional reward after endHeight + secondReward := state.Claim("deposit1", 300) + uassert.Equal(t, secondReward, uint64(0), ufmt.Sprintf("Got reward after endHeight: %d, want 0", secondReward)) +} + +func TestRewardState_PartialStakeRemoval(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + state.AddStake(150, "deposit1", 500) + + // checking reward after partial removal + reward1 := state.RemoveStake("deposit1", 200, 200) + reward2 := state.Claim("deposit1", 250) + + uassert.Equal(t, state.TotalStake, uint64(300), ufmt.Sprintf("TotalStake after partial removal = %d, want 300", state.TotalStake)) +} + +func TestRewardState_BoundaryCases(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + // stake before startHeight + state.AddStake(50, "deposit1", 100) + if state.LastHeight != 100 { + t.Errorf("LastHeight = %d, want 100", state.LastHeight) + } + + state.AddStake(150, "deposit2", 1) + state.AddStake(150, "deposit3", ^uint64(0)-1) + + reward1 := state.Claim("deposit2", 200) + reward2 := state.Claim("deposit2", 200) + uassert.Equal(t, reward2, uint64(0), ufmt.Sprintf("Second claim at same height should be 0, got %d", reward2)) +} + +func TestRewardState_MultipleStakesFromSameUser(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + // stake from the same user at different times + state.AddStake(150, "user1_deposit1", 200) + state.AddStake(200, "user1_deposit2", 300) + + reward1 := state.Claim("user1_deposit1", 250) + reward2 := state.Claim("user1_deposit2", 250) + + uassert.NotEqual(t, reward1, uint64(0), "First deposit should have rewards") + uassert.NotEqual(t, reward2, uint64(0), "Second deposit should have rewards") +} + +func TestRewardState_RewardCalculationAccuracy(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + state.AddStake(150, "deposit1", 100) + state.AddStake(150, "deposit2", 100) + + reward1 := state.Claim("deposit1", 200) + reward2 := state.Claim("deposit2", 200) + + // in a same condition, the same reward should be paid + uassert.Equal(t, reward1, reward2, "Equal stakes should receive equal rewards") +} + +func TestRewardState_ZeroRewardPerBlock(t *testing.T) { + // Test with zero rewards per block + rewardPerBlock := u256.Zero() + state := NewRewardState(rewardPerBlock, 100, 1000) + + state.AddStake(150, "deposit1", 100) + reward := state.Claim("deposit1", 200) + + uassert.Equal(t, reward, uint64(0), + "Expected zero reward when rewardPerBlock is zero") +} + +func TestRewardState_ConsecutiveStakeAndUnstake(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + // Repeatedly stake and unstake + state.AddStake(150, "deposit1", 100) + state.RemoveStake("deposit1", 100, 160) + + state.AddStake(170, "deposit2", 200) + state.RemoveStake("deposit2", 200, 180) + + state.AddStake(190, "deposit3", 300) + + uassert.Equal(t, state.TotalStake, uint64(300), + "Final stake amount should be correct after multiple stake/unstake operations") +} + +func TestRewardState_ClaimAtEndHeight(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + startHeight := uint64(100) + endHeight := uint64(200) + state := NewRewardState(rewardPerBlock, startHeight, endHeight) + + state.AddStake(150, "deposit1", 100) + + // Claim exactly at endHeight + reward1 := state.Claim("deposit1", endHeight) + + // Claim after endHeight + reward2 := state.Claim("deposit1", endHeight+50) + + uassert.NotEqual(t, reward1, uint64(0), + "Should receive rewards when claiming at endHeight") + uassert.Equal(t, reward2, uint64(0), + "Should not receive additional rewards after endHeight") +} + +func TestRewardState_StakeChangesDuringRewardPeriod(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + // First staker + state.AddStake(150, "deposit1", 1000) + + // Second staker joins + state.AddStake(160, "deposit2", 1000) + + // First staker removes half + reward1 := state.RemoveStake("deposit1", 500, 170) + + // Second staker claims + reward2 := state.Claim("deposit2", 170) + + uassert.Equal(t, reward1, uint64(15000)) + uassert.Equal(t, reward2, uint64(5000)) +} + +func TestRewardState_SmallIntervals(t *testing.T) { + rewardPerBlock := u256.NewUint(1000).Lsh(u256.NewUint(1000), 128) + state := NewRewardState(rewardPerBlock, 100, 1000) + + // Test with consecutive block heights + state.AddStake(150, "deposit1", 100) + reward1 := state.Claim("deposit1", 151) + reward2 := state.Claim("deposit1", 152) + + uassert.NotEqual(t, reward1, uint64(0), + "Should receive rewards for single block interval") + + uassert.Equal(t, reward1, reward2) + uassert.Equal(t, reward1, uint64(1000)) +} diff --git a/contract/r/gnoswap/launchpad/reward_test.gno b/contract/r/gnoswap/launchpad/reward_test.gno new file mode 100644 index 000000000..7cedb3c85 --- /dev/null +++ b/contract/r/gnoswap/launchpad/reward_test.gno @@ -0,0 +1,369 @@ +package launchpad + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestValidateRewardCollection(t *testing.T) { + currentHeight := uint64(100) + + tests := []struct { + name string + deposit Deposit + height uint64 + shouldError bool + errorMsg string + }{ + { + name: "invalid deposit Id", + deposit: Deposit{ + claimableHeight: 50, // Less than current height + rewardCollectTime: 0, + }, + height: currentHeight, + shouldError: true, + errorMsg: "[GNOSWAP-LAUNCHPAD-020] not exist deposit || depositId()", + }, + { + name: "Reward not yet claimable", + deposit: Deposit{ + claimableHeight: 150, // Greater than current height + rewardCollectTime: 0, + }, + height: currentHeight, + shouldError: true, + errorMsg: "reward not yet claimable", + }, + { + name: "Reward already collected", + deposit: Deposit{ + claimableHeight: 50, // Less than current height + rewardCollectTime: 100, + }, + height: currentHeight, + shouldError: true, + errorMsg: "[GNOSWAP-LAUNCHPAD-023] already collected", + }, + { + name: "Valid case", + deposit: Deposit{ + id: "depositId", + depositor: "g1wymu47drhr0kuq2098m792lytgtj2nyx77yrsm", + claimableHeight: 50, // Less than current height + rewardCollectTime: 0, + }, + height: currentHeight, + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !tt.shouldError { + updateDeposit(tt.deposit) + } + err := validateRewardCollection(tt.deposit, tt.height) + if tt.shouldError { + uassert.Error(t, err) + } else { + uassert.NoError(t, err) + } + }) + } +} + +func TestCalculateDepositRatioX128(t *testing.T) { + tests := []struct { + name string + tierAmount uint64 + amount uint64 + expected string + }{ + { + name: "Equal amounts", + tierAmount: 1000, + amount: 1000, + expected: "340282366920938463463374607431768211456", // 1.0 in Q128 + }, + { + name: "Half amount", + tierAmount: 1000, + amount: 500, + expected: "170141183460469231731687303715884105728", // 0.5 in Q128 + }, + { + name: "Double amount", + tierAmount: 500, + amount: 1000, + expected: "680564733841876926926749214863536422912", // 2.0 in Q128 + }, + { + name: "Zero amount", + tierAmount: 1000, + amount: 0, + expected: "0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ratio := calcDepositRatioX128(tt.tierAmount, tt.amount) + expected := u256.MustFromDecimal(tt.expected) + uassert.Equal(t, 0, ratio.Cmp(expected)) + }) + } +} + +func TestProcessDepositReward(t *testing.T) { + tests := []struct { + name string + deposit Deposit + rewardX128 *u256.Uint + tierAmount uint64 + expectedReward uint64 + shouldError bool + }{ + { + name: "Normal reward calculation", + deposit: Deposit{ + amount: 1000, + //rewardAmount: 0, + }, + rewardX128: u256.NewUint(1000).Mul(u256.NewUint(1000), q128), + tierAmount: 2000, + expectedReward: 500, // (1000/2000) * 1000 = 500 + shouldError: false, + }, + { + name: "Zero tier amount", + deposit: Deposit{ + amount: 1000, + //rewardAmount: 0, + }, + rewardX128: u256.NewUint(1000).Mul(u256.NewUint(1000), q128), + tierAmount: 0, + expectedReward: 0, + shouldError: true, + }, + { + name: "Reward is accumulated", + deposit: Deposit{ + amount: 1000, + //rewardAmount: 100, + }, + rewardX128: u256.NewUint(1000).Mul(u256.NewUint(1000), q128), + tierAmount: 2000, + expectedReward: 500, // 600, // Existing 100 + New reward 500 + shouldError: false, + }, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + state := NewRewardState(tt.rewardX128, 1000, 2000) + depositId := ufmt.Sprintf("depositId-%d", i) + state.AddStake(1000, depositId, tt.deposit.amount) + state.TotalStake = tt.tierAmount // force overriding total stake + state.finalize(1001) + + reward := state.CalculateReward(depositId) + uassert.Equal(t, tt.expectedReward, reward) + }) + } +} + +func TestCalculateTierRewards(t *testing.T) { + tests := []struct { + name string + tier Tier + currentHeight uint64 + lastCalcHeight uint64 + expectedRewardX128 string + expectedReward uint64 + shouldBeZero bool + }{ + { + name: "tierAmountPerBlockX96 is 0", + tier: Tier{ + tierAmountPerBlockX128: u256.Zero(), + ended: TimeInfo{height: 1000}, + }, + currentHeight: 500, + lastCalcHeight: 400, + shouldBeZero: true, + }, + { + name: "sinceLast is 0", + tier: Tier{ + tierAmountPerBlockX128: u256.NewUint(1000), + ended: TimeInfo{height: 1000}, + }, + currentHeight: 500, + lastCalcHeight: 500, + shouldBeZero: true, + }, + { + name: "Current height exceeds end height", + tier: Tier{ + tierAmountPerBlockX128: u256.NewUint(1000), + ended: TimeInfo{height: 450}, + }, + currentHeight: 500, + lastCalcHeight: 400, + expectedRewardX128: "50000", // 1000 * (450-400) + expectedReward: 0, // (50000 / Q96) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rewardX128, reward, err := calculateTierRewards(tt.tier, tt.currentHeight, tt.lastCalcHeight) + + uassert.NoError(t, err) + + if tt.shouldBeZero { + uassert.Equal(t, true, rewardX128.IsZero()) + uassert.Equal(t, uint64(0), reward) + } else { + expected := u256.MustFromDecimal(tt.expectedRewardX128) + uassert.Equal(t, 0, rewardX128.Cmp(expected)) + uassert.Equal(t, tt.expectedReward, reward) + } + }) + } +} + +func TestCollectRewardByDepositId_FailCases(t *testing.T) { + setup := func() { + projects = make(map[string]Project) + deposits = make(map[string]Deposit) + depositsByUserByProject = make(map[std.Address]map[string][]string) + } + + t.Run("ID not found", func(t *testing.T) { + setup() + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic but not") + } + }() + + CollectRewardByDepositId("non_existent_id") + }) + + t.Run("Claimable height before", func(t *testing.T) { + setup() + + projectId := "test_project" + depositId := "test_deposit" + + deposits[depositId] = Deposit{ + id: depositId, + projectId: projectId, + claimableHeight: uint64(std.GetHeight()) + 1000, + } + + uassert.PanicsWithMessage(t, "[GNOSWAP-LAUNCHPAD-024] not yet claim reward || need block >= 1123, current=123", func() { + CollectRewardByDepositId(depositId) + }) + }) +} + +func TestCollectRewardByDepositId_OwnerCheck(t *testing.T) { + userA := testutils.TestAddress("userA") + userB := testutils.TestAddress("userB") + + project := createTestProject(t) + projects[project.id] = project + + depositA := Deposit{ + id: "depA", + projectId: project.id, + tier: "30", + depositor: userA, + claimableHeight: 133, + claimableTime: uint64(time.Now().Unix()) + 1000, + rewardCollectTime: 0, + } + + updateDeposit(depositA) + + depositB := Deposit{ + id: "depB", + projectId: project.id, + tier: "30", + depositor: userB, + claimableHeight: 10, + claimableTime: uint64(time.Now().Unix()) + 1000, + rewardCollectTime: 0, + } + + updateDeposit(depositB) + + t.Run("UserA tries depositB => panic", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(userA)) + defer func() { + r := recover() + if r == nil { + t.Errorf("expected panic, got none") + } else { + uassert.Equal(t, "[GNOSWAP-LAUNCHPAD-026] invalid owner || caller(g1w4ek2ujpta047h6lta047h6lta047h6l55c22z) does not own depositId(depB)", r.(string)) + } + }() + CollectRewardByDepositId("depB") + }) + + t.Run("UserA depositA => not claimable yet", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(userA)) + // height=0, depositA.claimableHeight=10 + defer func() { + r := recover() + if r == nil { + t.Errorf("expected panic for not claimable, got none") + } else { + uassert.Equal(t, "[GNOSWAP-LAUNCHPAD-024] not yet claim reward || need block >= 133, current=123", r.(string)) + } + }() + CollectRewardByDepositId("depA") + }) + + std.TestSkipHeights(11) // now height=11 >= depositA.claimableHeight=10 + + depA := deposits["depA"] + depA.claimableTime = 0 + deposits["depA"] = depA + + t.Run("UserA depositA => claim success", func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(userA)) + + rewardPerDeposit := u256.NewUint(1) + rewardPerDepositX128 := u256.Zero().Lsh(rewardPerDeposit, 128) + rewardPerBlock := u256.NewUint(100) + rewardPerBlock = u256.Zero().Lsh(rewardPerBlock, 128) + mockReward := NewReward(rewardPerDepositX128, 123, 10000) + mockReward.info.Set("depA", &RewardInfo{ + startHeight: 123, + priceDebt: u256.Zero(), + depositAmount: 100, + claimed: 0, + }) + mockTier := &Tier{ + tierAmountPerBlockX128: rewardPerBlock, + totalDepositAmount: 100, + ended: TimeInfo{height: 10000}, + } + mockTier.SetReward(*mockReward) + project.setTier(TIER30, *mockTier) + + amount := CollectRewardByDepositId("depA") + uassert.Equal(t, uint64(100), amount) + }) +} diff --git a/contract/r/gnoswap/launchpad/type.gno b/contract/r/gnoswap/launchpad/type.gno new file mode 100644 index 000000000..bb6c45c51 --- /dev/null +++ b/contract/r/gnoswap/launchpad/type.gno @@ -0,0 +1,817 @@ +package launchpad + +import ( + "errors" + "std" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/gov/xgns" +) + +type ProjectStats struct { + totalDeposit uint64 // won't be decreased + actualDeposit uint64 // will be decreased if deposit collected 'CollectDepositGns()' + totalParticipant uint64 // accu, won't be decreased + actualParticipant uint64 // will be decreased if deposit collected 'CollectDepositGns()' + totalCollected uint64 // collect reward amount +} + +// NewProjectStats returns a pointer to a new ProjectStats with the given values. +func NewProjectStats( + totalDeposit, actualDeposit, totalParticipant, actualParticipant, totalCollected uint64, +) *ProjectStats { + return &ProjectStats{ + totalDeposit: totalDeposit, + actualDeposit: actualDeposit, + totalParticipant: totalParticipant, + actualParticipant: actualParticipant, + totalCollected: totalCollected, + } +} + +func (ps *ProjectStats) TotalDeposit() uint64 { + return ps.totalDeposit +} +func (ps *ProjectStats) SetTotalDeposit(v uint64) { + ps.totalDeposit = v +} +func (ps *ProjectStats) ActualDeposit() uint64 { + return ps.actualDeposit +} +func (ps *ProjectStats) SetActualDeposit(v uint64) { + ps.actualDeposit = v +} +func (ps *ProjectStats) TotalParticipant() uint64 { + return ps.totalParticipant +} +func (ps *ProjectStats) SetTotalParticipant(v uint64) { + ps.totalParticipant = v +} +func (ps *ProjectStats) ActualParticipant() uint64 { + return ps.actualParticipant +} +func (ps *ProjectStats) SetActualParticipant(v uint64) { + ps.actualParticipant = v +} +func (ps *ProjectStats) TotalCollected() uint64 { + return ps.totalCollected +} +func (ps *ProjectStats) SetTotalCollected(v uint64) { + ps.totalCollected = v +} + +type Project struct { + id string // 'tokenPath:createdHeight' + name string + tokenPath string + depositAmount uint64 + recipient std.Address // string + conditions map[string]Condition // tokenPath -> Condition + tiers map[uint64]Tier + tiersRatios map[uint64]uint64 + created TimeInfo + started TimeInfo + ended TimeInfo // same with tier 180's data + stats ProjectStats + refund RefundInfo +} + +// NewProject returns a pointer to a new Project with the given values. +func NewProject( + id, name, tokenPath string, + depositAmount uint64, + recipient std.Address, + conditions map[string]Condition, + tiers map[uint64]Tier, + tiersRatios map[uint64]uint64, + created, started, ended TimeInfo, + stats ProjectStats, + refund RefundInfo, +) *Project { + return &Project{ + id: id, + name: name, + tokenPath: tokenPath, + depositAmount: depositAmount, + recipient: recipient, + conditions: conditions, + tiers: tiers, + tiersRatios: tiersRatios, + created: created, + started: started, + ended: ended, + stats: stats, + refund: refund, + } +} + +func (p *Project) ID() string { + return p.id +} +func (p *Project) SetID(v string) { + p.id = v +} +func (p *Project) Name() string { + return p.name +} +func (p *Project) SetName(v string) { + p.name = v +} +func (p *Project) TokenPath() string { + return p.tokenPath +} +func (p *Project) SetTokenPath(v string) { + p.tokenPath = v +} +func (p *Project) DepositAmount() uint64 { + return p.depositAmount +} +func (p *Project) SetDepositAmount(v uint64) { + p.depositAmount = v +} +func (p *Project) Recipient() std.Address { + return p.recipient +} +func (p *Project) SetRecipient(v std.Address) { + p.recipient = v +} +func (p *Project) Conditions() map[string]Condition { + return p.conditions +} +func (p *Project) SetConditions(v map[string]Condition) { + p.conditions = v +} +func (p *Project) Tiers() map[uint64]Tier { + return p.tiers +} +func (p *Project) SetTiers(v map[uint64]Tier) { + p.tiers = v +} +func (p *Project) TiersRatios() map[uint64]uint64 { + return p.tiersRatios +} +func (p *Project) SetTiersRatios(v map[uint64]uint64) { + p.tiersRatios = v +} +func (p *Project) Created() TimeInfo { + return p.created +} +func (p *Project) SetCreated(v TimeInfo) { + p.created = v +} +func (p *Project) Started() TimeInfo { + return p.started +} +func (p *Project) SetStarted(v TimeInfo) { + p.started = v +} +func (p *Project) Ended() TimeInfo { + return p.ended +} +func (p *Project) SetEnded(v TimeInfo) { + p.ended = v +} +func (p *Project) Stats() *ProjectStats { + return &p.stats +} +func (p *Project) SetStats(v ProjectStats) { + p.stats = v +} +func (p *Project) Refund() RefundInfo { + return p.refund +} +func (p *Project) SetRefund(v RefundInfo) { + p.refund = v +} +func (p *Project) Tier(key uint64) (Tier, error) { + tier, exists := p.tiers[key] + if !exists { + return Tier{}, errors.New(addDetailToError( + errDataNotFound, ufmt.Sprintf("tier(%s) not found", key))) + } + return tier, nil +} +func (p *Project) isActivated(currentHeight uint64) bool { + return p.started.Height() <= currentHeight && currentHeight < p.ended.Height() +} +func (p *Project) checkDepositConditions(caller std.Address) { + conditions := p.Conditions() + if conditions == nil { + panic(addDetailToError( + errInvalidData, "conditions is nil")) + } + + for _, condition := range conditions { + // xGNS(or GNS) may have a zero condition + if condition.MinAmount() == 0 { + continue + } + + var balance uint64 + if condition.TokenPath() == consts.GOV_XGNS_PATH { + balance = xgns.BalanceOf(caller) + } else { + balance = common.BalanceOf(condition.TokenPath(), caller) + } + + if balance < condition.MinAmount() { + panic(addDetailToError( + errNotEnoughBalance, ufmt.Sprintf("insufficient balance(%d) for condition token(%s)", balance, condition.tokenPath))) + } + } +} +func (p *Project) setTier(key uint64, tier Tier) { + p.tiers[key] = tier +} +func (p *Project) calculateLeftReward() uint64 { + return p.DepositAmount() - p.Stats().totalCollected +} + +type Tier struct { + id string // '{projectId}:duration' // duartion == 30, 90, 180 + collectWaitDuration uint64 // block + tierAmount uint64 + tierAmountPerBlockX128 *u256.Uint + started TimeInfo // first deposit height + ended TimeInfo // + totalDepositAmount uint64 // accumulated deposit amount + actualDepositAmount uint64 // actual deposit amount + totalParticipant uint64 // accumulated participant + actualParticipant uint64 // actual participant + userCollectedAmount uint64 // total collected amount by user + calculatedAmount uint64 // total calculated amount + reward Reward +} + +// NewTier returns a pointer to a new Tier with the given values. +func NewTier( + id string, + collectWaitDuration uint64, + tierAmount uint64, + tierAmountPerBlockX128 *u256.Uint, + started, ended TimeInfo, + totalDepositAmount, actualDepositAmount uint64, + totalParticipant, actualParticipant uint64, + userCollectedAmount, calculatedAmount uint64, + reward Reward, +) *Tier { + return &Tier{ + id: id, + collectWaitDuration: collectWaitDuration, + tierAmount: tierAmount, + tierAmountPerBlockX128: tierAmountPerBlockX128, + started: started, + ended: ended, + totalDepositAmount: totalDepositAmount, + actualDepositAmount: actualDepositAmount, + totalParticipant: totalParticipant, + actualParticipant: actualParticipant, + userCollectedAmount: userCollectedAmount, + calculatedAmount: calculatedAmount, + reward: reward, + } +} + +func (t *Tier) ID() string { + return t.id +} +func (t *Tier) SetID(v string) { + t.id = v +} +func (t *Tier) CollectWaitDuration() uint64 { + return t.collectWaitDuration +} +func (t *Tier) SetCollectWaitDuration(v uint64) { + t.collectWaitDuration = v +} +func (t *Tier) TierAmount() uint64 { + return t.tierAmount +} +func (t *Tier) SetTierAmount(v uint64) { + t.tierAmount = v +} +func (t *Tier) TierAmountPerBlockX128() *u256.Uint { + return t.tierAmountPerBlockX128 +} +func (t *Tier) SetTierAmountPerBlockX128(v *u256.Uint) { + t.tierAmountPerBlockX128 = v.Clone() +} +func (t *Tier) Started() TimeInfo { + return t.started +} +func (t *Tier) SetStarted(v TimeInfo) { + t.started = v +} +func (t *Tier) Ended() TimeInfo { + return t.ended +} +func (t *Tier) SetEnded(v TimeInfo) { + t.ended = v +} +func (t *Tier) TotalDepositAmount() uint64 { + return t.totalDepositAmount +} +func (t *Tier) SetTotalDepositAmount(v uint64) { + t.totalDepositAmount = v +} +func (t *Tier) ActualDepositAmount() uint64 { + return t.actualDepositAmount +} +func (t *Tier) SetActualDepositAmount(v uint64) { + t.actualDepositAmount = v +} +func (t *Tier) TotalParticipant() uint64 { + return t.totalParticipant +} +func (t *Tier) SetTotalParticipant(v uint64) { + t.totalParticipant = v +} +func (t *Tier) ActualParticipant() uint64 { + return t.actualParticipant +} +func (t *Tier) SetActualParticipant(v uint64) { + t.actualParticipant = v +} +func (t *Tier) UserCollectedAmount() uint64 { + return t.userCollectedAmount +} +func (t *Tier) SetUserCollectedAmount(v uint64) { + t.userCollectedAmount = v +} +func (t *Tier) CalculatedAmount() uint64 { + return t.calculatedAmount +} +func (t *Tier) SetCalculatedAmount(v uint64) { + t.calculatedAmount = v +} +func (t *Tier) Reward() Reward { + return t.reward +} +func (t *Tier) SetReward(v Reward) { + t.reward = v +} +func (t *Tier) isActivated(currentHeight uint64) bool { + return t.started.Height() <= currentHeight && currentHeight < t.ended.Height() +} +func (t *Tier) isFirstDeposit() bool { + return t.totalParticipant == 0 +} +func (t *Tier) updateStarted(height, time uint64) { + t.started.SetHeight(height) + t.started.SetTime(time) + t.reward.SetLastHeight(height) +} +func (t *Tier) rewardPerBlockUint64() uint64 { + return u256.Zero().Rsh(t.tierAmountPerBlockX128, 128).Uint64() +} +func (t *Tier) isEnded(currentHeight uint64) bool { + return t.ended.Height() < currentHeight +} +func (t *Tier) calculateLeftReward() uint64 { + return t.tierAmount - t.userCollectedAmount +} + +type RewardInfo struct { + priceDebt *u256.Uint // price debt per GNS stake, Q128 + depositAmount uint64 // amount of GNS staked + rewardAmount uint64 // calculated, not collected + claimed uint64 // amount of reward claimed so far + startHeight uint64 // height when launchpad started staking + EndHeight uint64 // end height of reward calculation + LastHeight uint64 // last height when reward was calculated +} + +// NewRewardInfo returns a pointer to a new RewardInfo with the given values. +func NewRewardInfo( + priceDebt *u256.Uint, + depositAmount, rewardAmount, claimed, startHeight, endHeight, lastHeight uint64, +) *RewardInfo { + return &RewardInfo{ + priceDebt: priceDebt, + depositAmount: depositAmount, + rewardAmount: rewardAmount, + claimed: claimed, + startHeight: startHeight, + EndHeight: endHeight, + LastHeight: lastHeight, + } +} + +func (r *RewardInfo) PriceDebt() *u256.Uint { + return r.priceDebt +} +func (r *RewardInfo) SetPriceDebt(v *u256.Uint) { + r.priceDebt = v +} +func (r *RewardInfo) DepositAmount() uint64 { + return r.depositAmount +} +func (r *RewardInfo) SetDepositAmount(v uint64) { + r.depositAmount = v +} +func (r *RewardInfo) RewardAmount() uint64 { + return r.rewardAmount +} +func (r *RewardInfo) SetRewardAmount(v uint64) { + r.rewardAmount = v +} +func (r *RewardInfo) Claimed() uint64 { + return r.claimed +} +func (r *RewardInfo) SetClaimed(v uint64) { + r.claimed = v +} +func (r *RewardInfo) StartHeight() uint64 { + return r.startHeight +} +func (r *RewardInfo) SetStartHeight(v uint64) { + r.startHeight = v +} +func (r *RewardInfo) GetEndHeight() uint64 { + return r.EndHeight +} +func (r *RewardInfo) SetEndHeight(v uint64) { + r.EndHeight = v +} +func (r *RewardInfo) GetLastHeight() uint64 { + return r.LastHeight +} +func (r *RewardInfo) SetLastHeight(v uint64) { + r.LastHeight = v +} +func (r *RewardInfo) calculateReward(accumRewardPerDeposit *u256.Uint) uint64 { + actualRewardPerDeposit := u256.Zero().Sub(accumRewardPerDeposit, r.PriceDebt()) + reward := u256.Zero().Mul(actualRewardPerDeposit, u256.NewUint(r.DepositAmount())) + reward = u256.Zero().Rsh(reward, 128) + return reward.Uint64() - r.Claimed() +} + +type Reward struct { + accumRewardPerDeposit *u256.Uint // claimable Launchpad reward per GNS stake, Q128 + LastHeight uint64 // last height when reward was calculated + EndHeight uint64 // end height of reward calculation + info *avl.Tree // depositId -> RewardInfo +} + +// NewReward returns a pointer to a new Reward with the given values. +func NewReward( + accumRewardPerDeposit *u256.Uint, + lastHeight, endHeight uint64, +) *Reward { + return &Reward{ + accumRewardPerDeposit: accumRewardPerDeposit, + LastHeight: lastHeight, + EndHeight: endHeight, + info: avl.NewTree(), + } +} + +func (r *Reward) AccumRewardPerDeposit() *u256.Uint { + return r.accumRewardPerDeposit +} +func (r *Reward) SetAccumRewardPerDeposit(v *u256.Uint) { + r.accumRewardPerDeposit = v +} +func (r *Reward) GetLastHeight() uint64 { + return r.LastHeight +} +func (r *Reward) SetLastHeight(v uint64) { + r.LastHeight = v +} +func (r *Reward) GetEndHeight() uint64 { + return r.EndHeight +} +func (r *Reward) SetEndHeight(v uint64) { + r.EndHeight = v +} +func (r *Reward) Info() *avl.Tree { + return r.info +} +func (r *Reward) SetInfo(v *avl.Tree) { + r.info = v +} +func (r *Reward) accumulationRewardUint64() uint64 { + return u256.Zero().Rsh(r.accumRewardPerDeposit, 128).Uint64() +} +func (r *Reward) InfoOf(depositId string) *RewardInfo { + infoI, exists := r.info.Get(depositId) + if !exists { + panic(addDetailToError( + errNotExistDeposit, ufmt.Sprintf("(%s)", depositId))) + } + + return infoI.(*RewardInfo) +} +func (r *Reward) calculateRewardPerDeposit(rewardPerBlock *u256.Uint, totalStaked uint64) (*u256.Uint, error) { + // blockDuration * rewardPerBlock / totalStaked + currentHeight := uint64(std.GetHeight()) + lastUpdateHeight := r.GetLastHeight() + if currentHeight == lastUpdateHeight { + return u256.Zero(), nil + } + if currentHeight < lastUpdateHeight { + return nil, errors.New(addDetailToError( + errInvalidHeight, ufmt.Sprintf("currentHeight(%d) <= lastUpdateHeight(%d)", currentHeight, lastUpdateHeight))) + } + if rewardPerBlock.IsZero() { + return nil, errors.New(addDetailToError( + errNoLeftReward, ufmt.Sprintf("rewardPerBlock(%d)", rewardPerBlock))) + } + if totalStaked == 0 { + return u256.Zero(), nil + } + if currentHeight > r.EndHeight { + currentHeight = r.EndHeight + } + + blockDuration := currentHeight - lastUpdateHeight + totalReward := u256.Zero().Mul(u256.NewUint(blockDuration), rewardPerBlock) + rewardPerDeposit := u256.Zero().Div(totalReward, u256.NewUint(totalStaked)) + return rewardPerDeposit, nil +} + +func (r *Reward) addRewardPerDeposit(rewardPerDeposit *u256.Uint) { + if rewardPerDeposit.IsZero() { + return + } + r.SetAccumRewardPerDeposit(u256.Zero().Add(r.AccumRewardPerDeposit(), rewardPerDeposit)) +} + +func (r *Reward) finalize(currentHeight, totalStaked uint64, rewardPerBlock *u256.Uint) (*u256.Uint, error) { + if currentHeight < r.LastHeight { + // Not started yet + return nil, errors.New(addDetailToError( + errInvalidHeight, ufmt.Sprintf("currentHeight(%d) <= LastHeight(%d)", currentHeight, r.LastHeight))) + } + if r.LastHeight > r.EndHeight { + // already ended + return nil, errors.New(addDetailToError( + errInvalidHeight, ufmt.Sprintf("LastHeight(%d) >= EndHeight(%d)", r.LastHeight, r.EndHeight))) + } + if currentHeight > r.EndHeight { + currentHeight = r.EndHeight + } + return r.calculateRewardPerDeposit(rewardPerBlock, totalStaked) +} +func (r *Reward) deductReward(depositId string, currentHeight uint64) uint64 { + if currentHeight < r.LastHeight { + panic(addDetailToError( + errInvalidRewardState, + ufmt.Sprintf("currentHeight %d is less than LastHeight %d", currentHeight, r.LastHeight))) + } + + deposit := deposits[depositId] + if currentHeight < deposit.claimableHeight { + panic(addDetailToError( + errInvalidRewardState, + ufmt.Sprintf("currentHeight %d is less than claimableHeight %d", currentHeight, deposit.claimableHeight))) + } + + info := r.InfoOf(depositId) + if currentHeight < info.StartHeight() { + panic(addDetailToError( + errInvalidRewardState, + ufmt.Sprintf("currentHeight %d is less than StartHeight %d", currentHeight, info.StartHeight()))) + } + reward64 := info.calculateReward(r.AccumRewardPerDeposit()) + + if reward64 == 0 { + return 0 + } + info.SetClaimed(info.Claimed() + reward64) + r.info.Set(depositId, info) + + return reward64 +} + +type Deposit struct { + id string // 'projectId:tier:depositor:height' + projectId string + tier string // 30, 60, 180 // instead of tierId + depositor std.Address + amount uint64 + depositHeight uint64 + depositTime uint64 + depositCollectHeight uint64 // withdraw deposited gns height + depositCollectTime uint64 // withdraw deposited gns timestamp + claimableHeight uint64 // claimable reward block height + claimableTime uint64 // claimable reward timestamp + rewardAmount uint64 // calculated, not collected + rewardCollected uint64 // accu, collected + rewardCollectHeight uint64 // last collected height + rewardCollectTime uint64 // last collected time +} + +// NewDeposit returns a pointer to a new Deposit with the given values. +func NewDeposit( + id, projectId, tier string, + depositor std.Address, + amount, depositHeight, depositTime uint64, + depositCollectHeight, depositCollectTime uint64, + claimableHeight, claimableTime uint64, + rewardAmount, rewardCollected uint64, + rewardCollectHeight, rewardCollectTime uint64, +) *Deposit { + return &Deposit{ + id: id, + projectId: projectId, + tier: tier, + depositor: depositor, + amount: amount, + depositHeight: depositHeight, + depositTime: depositTime, + depositCollectHeight: depositCollectHeight, + depositCollectTime: depositCollectTime, + claimableHeight: claimableHeight, + claimableTime: claimableTime, + rewardAmount: rewardAmount, + rewardCollected: rewardCollected, + rewardCollectHeight: rewardCollectHeight, + rewardCollectTime: rewardCollectTime, + } +} + +func (d *Deposit) ID() string { + return d.id +} +func (d *Deposit) SetID(v string) { + d.id = v +} +func (d *Deposit) ProjectID() string { + return d.projectId +} +func (d *Deposit) SetProjectID(v string) { + d.projectId = v +} +func (d *Deposit) Tier() string { + return d.tier +} +func (d *Deposit) SetTier(v string) { + d.tier = v +} +func (d *Deposit) Depositor() std.Address { + return d.depositor +} +func (d *Deposit) SetDepositor(v std.Address) { + d.depositor = v +} +func (d *Deposit) Amount() uint64 { + return d.amount +} +func (d *Deposit) SetAmount(v uint64) { + d.amount = v +} +func (d *Deposit) DepositHeight() uint64 { + return d.depositHeight +} +func (d *Deposit) SetDepositHeight(v uint64) { + d.depositHeight = v +} +func (d *Deposit) DepositTime() uint64 { + return d.depositTime +} +func (d *Deposit) SetDepositTime(v uint64) { + d.depositTime = v +} +func (d *Deposit) DepositCollectHeight() uint64 { + return d.depositCollectHeight +} +func (d *Deposit) SetDepositCollectHeight(v uint64) { + d.depositCollectHeight = v +} +func (d *Deposit) DepositCollectTime() uint64 { + return d.depositCollectTime +} +func (d *Deposit) SetDepositCollectTime(v uint64) { + d.depositCollectTime = v +} +func (d *Deposit) ClaimableHeight() uint64 { + return d.claimableHeight +} +func (d *Deposit) SetClaimableHeight(v uint64) { + d.claimableHeight = v +} +func (d *Deposit) ClaimableTime() uint64 { + return d.claimableTime +} +func (d *Deposit) SetClaimableTime(v uint64) { + d.claimableTime = v +} +func (d *Deposit) RewardAmount() uint64 { + return d.rewardAmount +} +func (d *Deposit) SetRewardAmount(v uint64) { + d.rewardAmount = v +} +func (d *Deposit) RewardCollected() uint64 { + return d.rewardCollected +} +func (d *Deposit) SetRewardCollected(v uint64) { + d.rewardCollected = v +} +func (d *Deposit) RewardCollectHeight() uint64 { + return d.rewardCollectHeight +} +func (d *Deposit) SetRewardCollectHeight(v uint64) { + d.rewardCollectHeight = v +} +func (d *Deposit) RewardCollectTime() uint64 { + return d.rewardCollectTime +} +func (d *Deposit) SetRewardCollectTime(v uint64) { + d.rewardCollectTime = v +} +func (d *Deposit) isClaimable(currentHeight uint64) bool { + return d.claimableHeight <= currentHeight +} + +type TimeInfo struct { + height uint64 + time uint64 +} + +// NewTimeInfo returns a pointer to a new TimeInfo with the given values. +func NewTimeInfo(height, t uint64) TimeInfo { + return TimeInfo{ + height: height, + time: t, + } +} + +func (ti *TimeInfo) Height() uint64 { + return ti.height +} +func (ti *TimeInfo) SetHeight(v uint64) { + ti.height = v +} +func (ti *TimeInfo) Time() uint64 { + return ti.time +} +func (ti *TimeInfo) SetTime(v uint64) { + ti.time = v +} + +type RefundInfo struct { + amount uint64 + height uint64 + time uint64 +} + +// NewRefundInfo returns a pointer to a new RefundInfo with the given values. +func NewRefundInfo(amount, height, t uint64) RefundInfo { + return RefundInfo{ + amount: amount, + height: height, + time: t, + } +} + +func (ri *RefundInfo) Amount() uint64 { + return ri.amount +} +func (ri *RefundInfo) SetAmount(v uint64) { + ri.amount = v +} +func (ri *RefundInfo) Height() uint64 { + return ri.height +} +func (ri *RefundInfo) SetHeight(v uint64) { + ri.height = v +} +func (ri *RefundInfo) Time() uint64 { + return ri.time +} +func (ri *RefundInfo) SetTime(v uint64) { + ri.time = v +} + +type Condition struct { + tokenPath string + minAmount uint64 +} + +// NewCondition returns a pointer to a new Condition with the given values. +func NewCondition(tokenPath string, minAmount uint64) Condition { + return Condition{ + tokenPath: tokenPath, + minAmount: minAmount, + } +} + +func (c *Condition) TokenPath() string { + return c.tokenPath +} +func (c *Condition) SetTokenPath(v string) { + c.tokenPath = v +} +func (c *Condition) MinAmount() uint64 { + return c.minAmount +} +func (c *Condition) SetMinAmount(v uint64) { + c.minAmount = v +} diff --git a/contract/r/gnoswap/launchpad/util.gno b/contract/r/gnoswap/launchpad/util.gno new file mode 100644 index 000000000..5f4525e8a --- /dev/null +++ b/contract/r/gnoswap/launchpad/util.gno @@ -0,0 +1,168 @@ +package launchpad + +import ( + "std" + "strconv" + + "gno.land/p/demo/json" + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/gns" +) + +func strToInt(str string) int { + res, err := strconv.Atoi(str) + if err != nil { + panic(err.Error()) + } + + return res +} + +func strToU256U64(str string) uint64 { + strValue := u256.MustFromDecimal(str) + return strValue.Uint64() +} + +func contains(slice []string, str string) bool { + for _, v := range slice { + if v == str { + return true + } + } + return false +} + +func marshal(data *json.Node) string { + b, err := json.Marshal(data) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func minU64(x, y uint64) uint64 { + if x < y { + return x + } + return y +} + +// containsString returns true if the string is in the slice. +func containsString(arr []string, val string) bool { + for _, v := range arr { + if v == val { + return true + } + } + return false +} + +// GetOrigPkgAddr returns the original package address. +// In position contract, original package address is the position address. +func GetOrigPkgAddr() std.Address { + return consts.LAUNCHPAD_ADDR +} + +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// getPrev returns the address and package path of the previous realm. +func getPrev() (string, string) { + prev := std.PrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrevPkgPath returns the package path of the previous realm. +func getPrevPkgPath() string { + return std.PrevRealm().PkgPath() +} + +// formatUint returns the string representation of the uint64 value. +func formatUint(v interface{}) string { + switch v := v.(type) { + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +// formatInt returns the string representation of the int64 value. +func formatInt(value int64) string { + return strconv.FormatInt(value, 10) +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertOnlyAdmin panics if the caller is not an admin. +func assertOnlyAdmin() { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(addDetailToError(errInvalidAddress, err.Error())) + } +} + +// assertOnlyValidAddress panics if the address is invalid. +func assertOnlyValidAddress(addr std.Address) { + if !addr.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("(%s)", addr), + )) + } +} + +// assertValidAmount panics if the amount is zero. +func assertValidAmount(amount uint64) { + if amount < MINIMUM_DEPOSIT_AMOUNT { + panic(addDetailToError( + errInvalidAmount, ufmt.Sprintf("amount(%d) should greater than minimum deposit amount(%d)", amount, MINIMUM_DEPOSIT_AMOUNT))) + } + if (amount % MINIMUM_DEPOSIT_AMOUNT) != 0 { + panic(addDetailToError( + errInvalidAmount, ufmt.Sprintf("amount(%d) must be a multiple of 1_000_000", amount))) + } +} + +// assertValidProjectId panics if the project ID is invalid. +func assertValidProjectId(projectId string) { + if projectId == "" { + panic(addDetailToError( + errInvalidProjectId, ufmt.Sprintf("(%s)", projectId))) + } +} + +// convertTimeToHeight converts the given timestamp to block height based on the average block time. +func convertTimeToHeight(timestamp uint64) uint64 { + avgBlockTimeMs := uint64(gns.GetAvgBlockTimeInMs()) + if avgBlockTimeMs == 0 { + panic(addDetailToError( + errInvalidAvgBlockTime, + ufmt.Sprintf("%d", avgBlockTimeMs))) + } + return timestamp * 1000 / avgBlockTimeMs +} diff --git a/contract/r/gnoswap/pool/_helper_test.gno b/contract/r/gnoswap/pool/_helper_test.gno new file mode 100644 index 000000000..ce2a8063f --- /dev/null +++ b/contract/r/gnoswap/pool/_helper_test.gno @@ -0,0 +1,611 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + "gno.land/r/onbloc/foo" + "gno.land/r/onbloc/obl" + "gno.land/r/onbloc/qux" + "gno.land/r/onbloc/usdc" + + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + pn "gno.land/r/gnoswap/v1/position" +) + +const ( + ugnotDenom string = "ugnot" + ugnotPath string = "ugnot" + wugnotPath string = "gno.land/r/demo/wugnot" + gnsPath string = "gno.land/r/gnoswap/v1/gns" + barPath string = "gno.land/r/onbloc/bar" + bazPath string = "gno.land/r/onbloc/baz" + fooPath string = "gno.land/r/onbloc/foo" + oblPath string = "gno.land/r/onbloc/obl" + quxPath string = "gno.land/r/onbloc/qux" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + maxApprove uint64 = 18446744073709551615 + max_timeout int64 = 9999999999 + + maxSqrtPriceLimitX96 string = "1461446703485210103287273052203988822378723970341" +) + +const ( + // define addresses to use in tests + addr01 = testutils.TestAddress("addr01") + addr02 = testutils.TestAddress("addr02") +) + +var ( + fooToken = common.GetToken(fooPath) + barToken = common.GetToken(barPath) + bazToken = common.GetToken(bazPath) +) + +var ( + admin = consts.ADMIN + alice = testutils.TestAddress("alice") + pool = consts.POOL_ADDR + protocolFee = consts.PROTOCOL_FEE_ADDR + router = consts.ROUTER_ADDR + adminRealm = std.NewUserRealm(admin) + posRealm = std.NewCodeRealm(consts.POSITION_PATH) + rouRealm = std.NewCodeRealm(consts.ROUTER_PATH) + + // addresses used in tests + addrUsedInTest = []std.Address{addr01, addr02} +) + +func InitialisePoolTest(t *testing.T) { + t.Helper() + + ugnotFaucet(t, admin, 100_000_000_000_000) + ugnotDeposit(t, admin, 100_000_000_000_000) + + std.TestSetOrigCaller(admin) + TokenApprove(t, gnsPath, admin, pool, maxApprove) + poolPath := GetPoolPath(wugnotPath, gnsPath, fee3000) + if !DoesPoolPathExist(poolPath) { + CreatePool(wugnotPath, gnsPath, fee3000, "79228162514264337593543950336") + } + + //2. create position + std.TestSetOrigCaller(alice) + TokenFaucet(t, wugnotPath, alice) + TokenFaucet(t, gnsPath, alice) + TokenApprove(t, wugnotPath, alice, pool, uint64(1000)) + TokenApprove(t, gnsPath, alice, pool, uint64(1000)) + MintPosition(t, + wugnotPath, + gnsPath, + fee3000, + int32(1020), + int32(5040), + "1000", + "1000", + "0", + "0", + max_timeout, + alice, + alice, + ) +} + +func TokenFaucet(t *testing.T, tokenPath string, to std.Address) { + t.Helper() + std.TestSetOrigCaller(admin) + defaultAmount := uint64(5_000_000_000) + + switch tokenPath { + case wugnotPath: + wugnotTransfer(t, to, defaultAmount) + case gnsPath: + gnsTransfer(t, to, defaultAmount) + case barPath: + barTransfer(t, to, defaultAmount) + case bazPath: + bazTransfer(t, to, defaultAmount) + case fooPath: + fooTransfer(t, to, defaultAmount) + case oblPath: + oblTransfer(t, to, defaultAmount) + case quxPath: + quxTransfer(t, to, defaultAmount) + default: + panic("token not found") + } +} + +func TokenBalance(t *testing.T, tokenPath string, owner std.Address) uint64 { + t.Helper() + switch tokenPath { + case wugnotPath: + return wugnot.BalanceOf(common.AddrToUser(owner)) + case gnsPath: + return gns.BalanceOf(owner) + case barPath: + return bar.BalanceOf(owner) + case bazPath: + return baz.BalanceOf(owner) + case fooPath: + return foo.BalanceOf(owner) + case oblPath: + return obl.BalanceOf(owner) + case quxPath: + return qux.BalanceOf(owner) + default: + panic("token not found") + } +} + +func TokenAllowance(t *testing.T, tokenPath string, owner, spender std.Address) uint64 { + t.Helper() + switch tokenPath { + case wugnotPath: + return wugnot.Allowance(common.AddrToUser(owner), common.AddrToUser(spender)) + case gnsPath: + return gns.Allowance(owner, spender) + case barPath: + return bar.Allowance(owner, spender) + case bazPath: + return baz.Allowance(owner, spender) + case fooPath: + return foo.Allowance(owner, spender) + case oblPath: + return obl.Allowance(owner, spender) + case quxPath: + return qux.Allowance(owner, spender) + default: + panic("token not found") + } +} + +func TokenApprove(t *testing.T, tokenPath string, owner, spender std.Address, amount uint64) { + t.Helper() + switch tokenPath { + case wugnotPath: + wugnotApprove(t, owner, spender, amount) + case gnsPath: + gnsApprove(t, owner, spender, amount) + case barPath: + barApprove(t, owner, spender, amount) + case bazPath: + bazApprove(t, owner, spender, amount) + case fooPath: + fooApprove(t, owner, spender, amount) + case oblPath: + oblApprove(t, owner, spender, amount) + case quxPath: + quxApprove(t, owner, spender, amount) + default: + panic("token not found") + } +} + +func MintPosition(t *testing.T, + token0 string, + token1 string, + fee uint32, + tickLower int32, + tickUpper int32, + amount0Desired string, // *u256.Uint + amount1Desired string, // *u256.Uint + amount0Min string, // *u256.Uint + amount1Min string, // *u256.Uint + deadline int64, + mintTo std.Address, + caller std.Address, +) (uint64, string, string, string) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(caller)) + + return pn.Mint( + token0, + token1, + fee, + tickLower, + tickUpper, + amount0Desired, + amount1Desired, + amount0Min, + amount1Min, + deadline, + mintTo, + caller) +} + +func MintPositionAll(t *testing.T, caller std.Address) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(caller)) + TokenApprove(t, gnsPath, caller, pool, maxApprove) + TokenApprove(t, gnsPath, caller, router, maxApprove) + TokenApprove(t, wugnotPath, caller, pool, maxApprove) + TokenApprove(t, wugnotPath, caller, router, maxApprove) + + params := []struct { + tickLower int32 + tickUpper int32 + liquidity uint64 + zeroToOne bool + }{ + { + tickLower: -300, + tickUpper: -240, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -240, + tickUpper: -180, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -180, + tickUpper: -120, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -120, + tickUpper: -60, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -60, + tickUpper: 0, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: 0, + tickUpper: 60, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 60, + tickUpper: 120, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 120, + tickUpper: 180, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 180, + tickUpper: 240, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 240, + tickUpper: 300, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: -360, + tickUpper: -300, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -420, + tickUpper: -360, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -480, + tickUpper: -420, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -540, + tickUpper: -480, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -600, + tickUpper: -540, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: 300, + tickUpper: 360, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 360, + tickUpper: 420, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 420, + tickUpper: 480, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 480, + tickUpper: 540, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 540, + tickUpper: 600, + liquidity: 10, + zeroToOne: false, + }, + } + + for _, p := range params { + MintPosition(t, + wugnotPath, + gnsPath, + fee3000, + p.tickLower, + p.tickUpper, + "100", + "100", + "0", + "0", + max_timeout, + caller, + caller) + } + +} + +func wugnotApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + wugnot.Approve(common.AddrToUser(spender), amount) +} + +func gnsApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + gns.Approve(spender, amount) +} + +func barApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + bar.Approve(spender, amount) +} + +func bazApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + baz.Approve(spender, amount) +} + +func fooApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + foo.Approve(spender, amount) +} + +func oblApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + obl.Approve(spender, amount) +} + +func quxApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + qux.Approve(spender, amount) +} + +func wugnotTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + wugnot.Transfer(common.AddrToUser(to), amount) +} + +func gnsTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + gns.Transfer(to, amount) +} + +func barTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + bar.Transfer(to, amount) +} + +func bazTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + baz.Transfer(to, amount) +} + +func fooTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + foo.Transfer(to, amount) +} + +func oblTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + obl.Transfer(to, amount) +} + +func quxTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + qux.Transfer(to, amount) +} + +// ---------------------------------------------------------------------------- +// ugnot + +func ugnotTransfer(t *testing.T, from, to std.Address, amount uint64) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(from)) + std.TestSetOrigSend(std.Coins{{ugnotDenom, int64(amount)}}, nil) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(from, to, std.Coins{{ugnotDenom, int64(amount)}}) +} + +func ugnotBalanceOf(t *testing.T, addr std.Address) uint64 { + t.Helper() + + banker := std.GetBanker(std.BankerTypeRealmIssue) + coins := banker.GetCoins(addr) + if len(coins) == 0 { + return 0 + } + + return uint64(coins.AmountOf(ugnotDenom)) +} + +func ugnotMint(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.IssueCoin(addr, denom, amount) + std.TestIssueCoins(addr, std.Coins{{denom, int64(amount)}}) +} + +func ugnotBurn(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.RemoveCoin(addr, denom, amount) +} + +func ugnotFaucet(t *testing.T, to std.Address, amount uint64) { + t.Helper() + faucetAddress := admin + std.TestSetOrigCaller(faucetAddress) + + if ugnotBalanceOf(t, faucetAddress) < amount { + newCoins := std.Coins{{ugnotDenom, int64(amount)}} + ugnotMint(t, faucetAddress, newCoins[0].Denom, newCoins[0].Amount) + std.TestSetOrigSend(newCoins, nil) + } + ugnotTransfer(t, faucetAddress, to, amount) +} + +func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(addr)) + wugnotAddr := consts.WUGNOT_ADDR + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(addr, wugnotAddr, std.Coins{{ugnotDenom, int64(amount)}}) + wugnot.Deposit() +} + +// resetObject resets the object state(clear or make it default values) +func resetObject(t *testing.T) { + pools = avl.NewTree() + slot0FeeProtocol = 0 + poolCreationFee = 100_000_000 + withdrawalFeeBPS = 100 +} + +func burnTokens(t *testing.T) { + t.Helper() + + // burn tokens + for _, addr := range addrUsedInTest { + burnFoo(addr) + burnBar(addr) + burnBaz(addr) + burnQux(addr) + burnObl(addr) + burnUsdc(addr) + } +} + +func burnFoo(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + foo.Burn(addr, foo.BalanceOf(addr)) +} + +func burnBar(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + bar.Burn(addr, bar.BalanceOf(addr)) +} + +func burnBaz(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + baz.Burn(addr, baz.BalanceOf(addr)) +} + +func burnQux(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + qux.Burn(addr, qux.BalanceOf(addr)) +} + +func burnObl(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + obl.Burn(addr, obl.BalanceOf(addr)) +} + +func burnUsdc(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + usdc.Burn(addr, usdc.BalanceOf(addr)) +} + +func TestBeforeResetObject(t *testing.T) { + // make some data + pools = avl.NewTree() + pools.Set("gno.land/r/gnoswap/v1/gns:gno.land/r/onbloc/usdc", &Pool{ + token0Path: "gno.land/r/gnoswap/v1/gns", + token1Path: "gno.land/r/onbloc/usdc", + }) + + slot0FeeProtocol = 1 + poolCreationFee = 100_000_000 + withdrawalFeeBPS = 100 + + // transfer some tokens + std.TestSetRealm(std.NewUserRealm(admin)) + foo.Transfer(addr01, 100_000_000) + bar.Transfer(addr01, 100_000_000) + + uassert.Equal(t, foo.BalanceOf(addr01), uint64(100_000_000)) + uassert.Equal(t, bar.BalanceOf(addr01), uint64(100_000_000)) +} + +func TestResetObject(t *testing.T) { + resetObject(t) + uassert.Equal(t, pools.Size(), 0) + uassert.Equal(t, slot0FeeProtocol, uint8(0)) + uassert.Equal(t, poolCreationFee, uint64(100_000_000)) + uassert.Equal(t, withdrawalFeeBPS, uint64(100)) +} + +func TestBurnTokens(t *testing.T) { + burnTokens(t) + + uassert.Equal(t, foo.BalanceOf(addr01), uint64(0)) // 100_000_000 -> 0 + uassert.Equal(t, bar.BalanceOf(addr01), uint64(0)) // 100_000_000 -> 0 +} diff --git a/contract/r/gnoswap/pool/api.gno b/contract/r/gnoswap/pool/api.gno new file mode 100644 index 000000000..32c15a7db --- /dev/null +++ b/contract/r/gnoswap/pool/api.gno @@ -0,0 +1,325 @@ +package pool + +import ( + b64 "encoding/base64" + + "gno.land/p/demo/json" + u256 "gno.land/p/gnoswap/uint256" + + "std" + "strconv" + "strings" + "time" + + "gno.land/p/demo/ufmt" +) + +type RpcPool struct { + PoolPath string `json:"poolPath"` + + Token0Path string `json:"token0Path"` + Token1Path string `json:"token1Path"` + + Token0Balance string `json:"token0Balance"` + Token1Balance string `json:"token1Balance"` + + Fee uint32 `json:"fee"` + + TickSpacing int32 `json:"tickSpacing"` + + MaxLiquidityPerTick string `json:"maxLiquidityPerTick"` + + Slot0SqrtPriceX96 string `json:"sqrtPriceX96"` + Slot0Tick int32 `json:"tick"` + Slot0FeeProtocol uint8 `json:"feeProtocol"` + Slot0Unlocked bool `json:"unlocked"` + + FeeGrowthGlobal0X128 string `json:"feeGrowthGlobal0X128"` + FeeGrowthGlobal1X128 string `json:"feeGrowthGlobal1X128"` + + Token0ProtocolFee string `json:"token0ProtocolFee"` + Token1ProtocolFee string `json:"token1ProtocolFee"` + + Liquidity string `json:"liquidity"` + + Ticks RpcTicks `json:"ticks"` + + TickBitmaps RpcTickBitmaps `json:"tickBitmaps"` + + Positions []RpcPosition `json:"positions"` +} + +type RpcTickBitmaps map[int16]string // tick(wordPos) => bitmap(tickWord ^ mask) + +type RpcTicks map[int32]RpcTickInfo // tick => RpcTickInfo + +type RpcTickInfo struct { + LiquidityGross string `json:"liquidityGross"` + LiquidityNet string `json:"liquidityNet"` + + FeeGrowthOutside0X128 string `json:"feeGrowthOutside0X128"` + FeeGrowthOutside1X128 string `json:"feeGrowthOutside1X128"` + + TickCumulativeOutside int64 `json:"tickCumulativeOutside"` + + SecondsPerLiquidityOutsideX string `json:"secondsPerLiquidityOutsideX"` + SecondsOutside uint32 `json:"secondsOutside"` + + Initialized bool `json:"initialized"` +} + +type RpcPosition struct { + Owner string `json:"owner"` + + TickLower int32 `json:"tickLower"` + TickUpper int32 `json:"tickUpper"` + + Liquidity string `json:"liquidity"` + + Token0Owed string `json:"token0Owed"` + Token1Owed string `json:"token1Owed"` +} + +func ApiGetPools() string { + rpcPools := []RpcPool{} + pools.Iterate("", "", func(poolPath string, value interface{}) bool { + rpcPool := rpcMakePool(poolPath) + rpcPools = append(rpcPools, rpcPool) + + return false + }) + + responses := json.ArrayNode("", []*json.Node{}) + for _, pool := range rpcPools { + poolNode := json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", pool.PoolPath), + "token0Path": json.StringNode("token0Path", pool.Token0Path), + "token1Path": json.StringNode("token1Path", pool.Token1Path), + "token0Balance": json.StringNode("token0Balance", pool.Token0Balance), + "token1Balance": json.StringNode("token1Balance", pool.Token1Balance), + "fee": json.NumberNode("fee", float64(pool.Fee)), + "tickSpacing": json.NumberNode("tickSpacing", float64(pool.TickSpacing)), + "maxLiquidityPerTick": json.StringNode("maxLiquidityPerTick", pool.MaxLiquidityPerTick), + "sqrtPriceX96": json.StringNode("sqrtPriceX96", pool.Slot0SqrtPriceX96), + "tick": json.NumberNode("tick", float64(pool.Slot0Tick)), + "feeProtocol": json.NumberNode("feeProtocol", float64(pool.Slot0FeeProtocol)), + "unlocked": json.BoolNode("unlocked", pool.Slot0Unlocked), + "feeGrowthGlobal0X128": json.StringNode("feeGrowthGlobal0X128", pool.FeeGrowthGlobal0X128), + "feeGrowthGlobal1X128": json.StringNode("feeGrowthGlobal1X128", pool.FeeGrowthGlobal1X128), + "token0ProtocolFee": json.StringNode("token0ProtocolFee", pool.Token0ProtocolFee), + "token1ProtocolFee": json.StringNode("token1ProtocolFee", pool.Token1ProtocolFee), + "liquidity": json.StringNode("liquidity", pool.Liquidity), + "ticks": json.ObjectNode("ticks", makeTicksJson(pool.Ticks)), + "tickBitmaps": json.ObjectNode("tickBitmaps", makeRpcTickBitmapsJson(pool.TickBitmaps)), + "positions": json.ArrayNode("positions", makeRpcPositionsArray(pool.Positions)), + }) + responses.AppendArray(poolNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": makeStatNode(), + "response": responses, + }) + + return marshal(node) +} + +func ApiGetPool(poolPath string) string { + if !pools.Has(poolPath) { + return "" + } + rpcPool := rpcMakePool(poolPath) + + responseNode := json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", rpcPool.PoolPath), + "token0Path": json.StringNode("token0Path", rpcPool.Token0Path), + "token1Path": json.StringNode("token1Path", rpcPool.Token1Path), + "token0Balance": json.StringNode("token0Balance", rpcPool.Token0Balance), + "token1Balance": json.StringNode("token1Balance", rpcPool.Token1Balance), + "fee": json.NumberNode("fee", float64(rpcPool.Fee)), + "tickSpacing": json.NumberNode("tickSpacing", float64(rpcPool.TickSpacing)), + "maxLiquidityPerTick": json.StringNode("maxLiquidityPerTick", rpcPool.MaxLiquidityPerTick), + "sqrtPriceX96": json.StringNode("sqrtPriceX96", rpcPool.Slot0SqrtPriceX96), + "tick": json.NumberNode("tick", float64(rpcPool.Slot0Tick)), + "feeProtocol": json.NumberNode("feeProtocol", float64(rpcPool.Slot0FeeProtocol)), + "unlocked": json.BoolNode("unlocked", rpcPool.Slot0Unlocked), + "feeGrowthGlobal0X128": json.StringNode("feeGrowthGlobal0X128", rpcPool.FeeGrowthGlobal0X128), + "feeGrowthGlobal1X128": json.StringNode("feeGrowthGlobal1X128", rpcPool.FeeGrowthGlobal1X128), + "token0ProtocolFee": json.StringNode("token0ProtocolFee", rpcPool.Token0ProtocolFee), + "token1ProtocolFee": json.StringNode("token1ProtocolFee", rpcPool.Token1ProtocolFee), + "liquidity": json.StringNode("liquidity", rpcPool.Liquidity), + "ticks": json.ObjectNode("ticks", makeTicksJson(rpcPool.Ticks)), + "tickBitmaps": json.ObjectNode("tickBitmaps", makeRpcTickBitmapsJson(rpcPool.TickBitmaps)), + "positions": json.ArrayNode("positions", makeRpcPositionsArray(rpcPool.Positions)), + }) + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": makeStatNode(), + "response": responseNode, + }) + + return marshal(node) +} + +func makeStatNode() *json.Node { + return json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) +} + +func rpcMakePool(poolPath string) RpcPool { + rpcPool := RpcPool{} + pool := GetPoolFromPoolPath(poolPath) + + rpcPool.PoolPath = poolPath + + rpcPool.Token0Path = pool.token0Path + rpcPool.Token1Path = pool.token1Path + + rpcPool.Token0Balance = pool.balances.token0.ToString() + rpcPool.Token1Balance = pool.balances.token1.ToString() + + rpcPool.Fee = pool.fee + + rpcPool.TickSpacing = pool.tickSpacing + + rpcPool.MaxLiquidityPerTick = pool.maxLiquidityPerTick.ToString() + + rpcPool.Slot0SqrtPriceX96 = pool.slot0.sqrtPriceX96.ToString() + rpcPool.Slot0Tick = pool.slot0.tick + rpcPool.Slot0FeeProtocol = pool.slot0.feeProtocol + rpcPool.Slot0Unlocked = pool.slot0.unlocked + + rpcPool.FeeGrowthGlobal0X128 = pool.feeGrowthGlobal0X128.ToString() + rpcPool.FeeGrowthGlobal1X128 = pool.feeGrowthGlobal1X128.ToString() + + rpcPool.Token0ProtocolFee = pool.protocolFees.token0.ToString() + rpcPool.Token1ProtocolFee = pool.protocolFees.token1.ToString() + + rpcPool.Liquidity = pool.liquidity.ToString() + + rpcPool.Ticks = RpcTicks{} + pool.ticks.Iterate("", "", func(tickStr string, iTickInfo interface{}) bool { + tick, _ := strconv.Atoi(tickStr) + tickInfo := iTickInfo.(TickInfo) + + rpcPool.Ticks[int32(tick)] = RpcTickInfo{ + LiquidityGross: tickInfo.liquidityGross.ToString(), + LiquidityNet: tickInfo.liquidityNet.ToString(), + FeeGrowthOutside0X128: tickInfo.feeGrowthOutside0X128.ToString(), + FeeGrowthOutside1X128: tickInfo.feeGrowthOutside1X128.ToString(), + TickCumulativeOutside: tickInfo.tickCumulativeOutside, + SecondsPerLiquidityOutsideX: tickInfo.secondsPerLiquidityOutsideX128.ToString(), + SecondsOutside: tickInfo.secondsOutside, + Initialized: tickInfo.initialized, + } + + return false + }) + + rpcPool.TickBitmaps = RpcTickBitmaps{} + pool.tickBitmaps.Iterate("", "", func(tickStr string, iTickBitmap interface{}) bool { + tick, _ := strconv.Atoi(tickStr) + pool.setTickBitmap(int16(tick), iTickBitmap.(*u256.Uint)) + + return false + }) + + rpcPositions := []RpcPosition{} + pool.positions.Iterate("", "", func(posKey string, iPositionInfo interface{}) bool { + owner, tickLower, tickUpper := posKeyDivide(posKey) + posInfo := iPositionInfo.(PositionInfo) + + rpcPositions = append(rpcPositions, RpcPosition{ + Owner: owner, + TickLower: tickLower, + TickUpper: tickUpper, + Liquidity: posInfo.liquidity.ToString(), + Token0Owed: posInfo.tokensOwed0.ToString(), + Token1Owed: posInfo.tokensOwed1.ToString(), + }) + + return false + }) + + rpcPool.Positions = rpcPositions + + return rpcPool +} + +func posKeyDivide(posKey string) (string, int32, int32) { + // base64 decode + kDec, _ := b64.StdEncoding.DecodeString(posKey) + posKey = string(kDec) + + res := strings.Split(posKey, "__") + if len(res) != 3 { + panic(addDetailToError( + errInvalidPositionKey, + ufmt.Sprintf("_RPC_api.gno__posKeyDivide() || invalid posKey(%s)", posKey), + )) + } + + owner, _tickLower, _tickUpper := res[0], res[1], res[2] + + tickLower, _ := strconv.Atoi(_tickLower) + tickUpper, _ := strconv.Atoi(_tickUpper) + + return owner, int32(tickLower), int32(tickUpper) +} + +func makeTicksJson(ticks RpcTicks) map[string]*json.Node { + ticksJson := map[string]*json.Node{} + + for tick, tickInfo := range ticks { + ticksJson[strconv.Itoa(int(tick))] = json.ObjectNode("", map[string]*json.Node{ + "liquidityGross": json.StringNode("liquidityGross", tickInfo.LiquidityGross), + "liquidityNet": json.StringNode("liquidityNet", tickInfo.LiquidityNet), + "feeGrowthOutside0X128": json.StringNode("feeGrowthOutside0X128", tickInfo.FeeGrowthOutside0X128), + "feeGrowthOutside1X128": json.StringNode("feeGrowthOutside1X128", tickInfo.FeeGrowthOutside1X128), + "tickCumulativeOutside": json.NumberNode("tickCumulativeOutside", float64(tickInfo.TickCumulativeOutside)), + "secondsPerLiquidityOutsideX": json.StringNode("secondsPerLiquidityOutsideX", tickInfo.SecondsPerLiquidityOutsideX), + "secondsOutside": json.NumberNode("secondsOutside", float64(tickInfo.SecondsOutside)), + "initialized": json.BoolNode("initialized", tickInfo.Initialized), + }) + } + + return ticksJson +} + +func makeRpcTickBitmapsJson(tickBitmaps RpcTickBitmaps) map[string]*json.Node { + tickBitmapsJson := map[string]*json.Node{} + + for tick, tickBitmap := range tickBitmaps { + tickBitmapsJson[strconv.Itoa(int(tick))] = json.StringNode("", tickBitmap) + } + + return tickBitmapsJson +} + +func makeRpcPositionsArray(positions []RpcPosition) []*json.Node { + positionsJson := make([]*json.Node, len(positions)) + + for i, pos := range positions { + positionsJson[i] = json.ObjectNode("", map[string]*json.Node{ + "owner": json.StringNode("owner", pos.Owner), + "tickLower": json.NumberNode("tickLower", float64(pos.TickLower)), + "tickUpper": json.NumberNode("tickUpper", float64(pos.TickUpper)), + "liquidity": json.StringNode("liquidity", pos.Liquidity), + "token0Owed": json.StringNode("token0Owed", pos.Token0Owed), + "token1Owed": json.StringNode("token1Owed", pos.Token1Owed), + }) + } + + return positionsJson +} + +func marshal(node *json.Node) string { + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} diff --git a/contract/r/gnoswap/pool/api_test.gno b/contract/r/gnoswap/pool/api_test.gno new file mode 100644 index 000000000..ddda9cb8c --- /dev/null +++ b/contract/r/gnoswap/pool/api_test.gno @@ -0,0 +1,172 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/json" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +func TestInitTwoPools(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + SetPoolCreationFeeByAdmin(0) + + // foo:bar + CreatePool("gno.land/r/onbloc/bar", "gno.land/r/onbloc/foo", uint32(500), common.TickMathGetSqrtRatioAtTick(-10000).ToString()) + + // bar:baz + CreatePool(barPath, bazPath, fee500, "130621891405341611593710811006") // tick 10000 + + uassert.Equal(t, pools.Size(), 2) +} + +func TestApiGetPools(t *testing.T) { + getPools := ApiGetPools() + + root, err := json.Unmarshal([]byte(getPools)) + if err != nil { + panic(err.Error()) + } + + response, err := root.GetKey("response") + if err != nil { + panic(err.Error()) + } + + uassert.Equal(t, 2, response.Size()) // should be same as the number of pools +} + +func TestApiGetPool(t *testing.T) { + t.Run("existing pool", func(t *testing.T) { + getPool := ApiGetPool("gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500") + uassert.Equal(t, getPool, `{"stat":{"height":123,"timestamp":1234567890},"response":{"poolPath":"gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500","token0Path":"gno.land/r/onbloc/bar","token1Path":"gno.land/r/onbloc/foo","token0Balance":"0","token1Balance":"0","fee":500,"tickSpacing":10,"maxLiquidityPerTick":"1917569901783203986719870431555990","sqrtPriceX96":"48055510970269007215549348797","tick":-10000,"feeProtocol":0,"unlocked":true,"feeGrowthGlobal0X128":"0","feeGrowthGlobal1X128":"0","token0ProtocolFee":"0","token1ProtocolFee":"0","liquidity":"0","ticks":{},"tickBitmaps":{},"positions":[]}}`) + }) + + t.Run("non-existing pool", func(t *testing.T) { + getPool := ApiGetPool("gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:3000") + uassert.Equal(t, getPool, "") + }) +} + +func TestMakeStatNode(t *testing.T) { + t.Run("default block height and timestamp", func(t *testing.T) { + statNode := makeStatNode() + + statHeight, _ := statNode.GetKey("height") + uassert.Equal(t, statHeight.String(), "123") + + statTimestamp, _ := statNode.GetKey("timestamp") + uassert.Equal(t, statTimestamp.String(), "1234567890") + }) + + t.Run("increase block height and timestamp", func(t *testing.T) { + std.TestSkipHeights(1) + + statNode := makeStatNode() + + statHeight, _ := statNode.GetKey("height") + uassert.Equal(t, statHeight.String(), "124") + }) +} + +func TestRpcMakePool(t *testing.T) { + t.Run("existing pool", func(t *testing.T) { + rpcPool := rpcMakePool("gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500") + uassert.Equal(t, rpcPool.PoolPath, "gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500") + uassert.Equal(t, rpcPool.Token0Path, "gno.land/r/onbloc/bar") + uassert.Equal(t, rpcPool.Token1Path, "gno.land/r/onbloc/foo") + uassert.Equal(t, rpcPool.Fee, uint32(500)) + uassert.Equal(t, rpcPool.TickSpacing, int32(10)) + uassert.Equal(t, rpcPool.MaxLiquidityPerTick, "1917569901783203986719870431555990") + uassert.Equal(t, rpcPool.Slot0SqrtPriceX96, "48055510970269007215549348797") + uassert.Equal(t, rpcPool.Slot0Tick, int32(-10000)) + uassert.Equal(t, rpcPool.Slot0FeeProtocol, uint8(0)) + uassert.Equal(t, rpcPool.Slot0Unlocked, true) + uassert.Equal(t, rpcPool.FeeGrowthGlobal0X128, "0") + uassert.Equal(t, rpcPool.FeeGrowthGlobal1X128, "0") + uassert.Equal(t, rpcPool.Token0ProtocolFee, "0") + uassert.Equal(t, rpcPool.Token1ProtocolFee, "0") + uassert.Equal(t, rpcPool.Liquidity, "0") + + if len(rpcPool.Ticks) != 0 { + t.Errorf("expected 0 ticks, got %d", len(rpcPool.Ticks)) + } + + if len(rpcPool.TickBitmaps) != 0 { + t.Errorf("expected 0 tickBitmaps, got %d", len(rpcPool.TickBitmaps)) + } + }) + + t.Run("non-existing pool", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic, got %v", r) + } + }() + + rpcMakePool("gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:3000") + }) +} + +func TestMakeTicksJson(t *testing.T) { + t.Run("existing tick", func(t *testing.T) { + ticks := RpcTicks{ + 10000: { + LiquidityGross: "100000000", + LiquidityNet: "100000000", + FeeGrowthOutside0X128: "100000000", + FeeGrowthOutside1X128: "100000000", + TickCumulativeOutside: 100000000, + SecondsPerLiquidityOutsideX: "100000000", + SecondsOutside: 100000000, + Initialized: true, + }, + } + + ticksJson := makeTicksJson(ticks) + uassert.Equal(t, ticksJson["10000"].String(), `{"liquidityGross":"100000000","liquidityNet":"100000000","feeGrowthOutside0X128":"100000000","feeGrowthOutside1X128":"100000000","tickCumulativeOutside":100000000,"secondsPerLiquidityOutsideX":"100000000","secondsOutside":100000000,"initialized":true}`) + }) + + t.Run("non-existing tick", func(t *testing.T) { + ticksJson := makeTicksJson(RpcTicks{}) + uassert.Equal(t, len(ticksJson), 0) + }) +} + +func TestMakeRpcTickBitmapsJson(t *testing.T) { + t.Run("existing tickBitmap", func(t *testing.T) { + tickBitmaps := RpcTickBitmaps{ + 10000: "100000000", + 10001: "100000001", + } + + tickBitmapsJson := makeRpcTickBitmapsJson(tickBitmaps) + uassert.Equal(t, tickBitmapsJson["10000"].String(), `"100000000"`) + uassert.Equal(t, tickBitmapsJson["10001"].String(), `"100000001"`) + }) + + t.Run("non-existing tickBitmap", func(t *testing.T) { + tickBitmapsJson := makeRpcTickBitmapsJson(RpcTickBitmaps{}) + uassert.Equal(t, len(tickBitmapsJson), 0) + }) +} + +func TestMakeRpcPositionsArray(t *testing.T) { + t.Run("existing positions", func(t *testing.T) { + positions := []RpcPosition{ + {Owner: "gno.land/r/onbloc/bar", TickLower: 10000, TickUpper: 10001, Liquidity: "100000000", Token0Owed: "100000000", Token1Owed: "100000000"}, + } + + positionsJson := makeRpcPositionsArray(positions) + uassert.Equal(t, positionsJson[0].String(), `{"owner":"gno.land/r/onbloc/bar","tickLower":10000,"tickUpper":10001,"liquidity":"100000000","token0Owed":"100000000","token1Owed":"100000000"}`) + }) + + t.Run("non-existing positions", func(t *testing.T) { + positionsJson := makeRpcPositionsArray([]RpcPosition{}) + uassert.Equal(t, len(positionsJson), 0) + }) +} diff --git a/contract/r/gnoswap/pool/doc.gno b/contract/r/gnoswap/pool/doc.gno new file mode 100644 index 000000000..3692bf3fe --- /dev/null +++ b/contract/r/gnoswap/pool/doc.gno @@ -0,0 +1,2 @@ +// Package pool handles liquidity management, including creating new liquidity positions and withdrawing liquidity. +package pool diff --git a/contract/r/gnoswap/pool/errors.gno b/contract/r/gnoswap/pool/errors.gno new file mode 100644 index 000000000..c98286a4d --- /dev/null +++ b/contract/r/gnoswap/pool/errors.gno @@ -0,0 +1,52 @@ +package pool + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +// Error Messages +// These error messages are intended to inform you when an error occurs while using Contract and what it is. +var ( + errNoPermission = errors.New("[GNOSWAP-POOL-001] caller has no permission") + errNotRegistered = errors.New("[GNOSWAP-POOL-002] not registered token") + errAlreadyRegistered = errors.New("[GNOSWAP-POOL-003] already registered token") + errOutOfRange = errors.New("[GNOSWAP-POOL-005] out of range for numeric value") + errInvalidInput = errors.New("[GNOSWAP-POOL-006] invalid input data") + errInvalidPositionKey = errors.New("[GNOSWAP-POOL-007] invalid position key") + errDataNotFound = errors.New("[GNOSWAP-POOL-008] requested data not found") + errLiquidityCalculation = errors.New("[GNOSWAP-POOL-009] invalid liquidity calculated") + errZeroLiquidity = errors.New("[GNOSWAP-POOL-010] zero liquidity") + errDuplicateTokenInPool = errors.New("[GNOSWAP-POOL-011] same token used in single pool") + errTokenSortOrder = errors.New("[GNOSWAP-POOL-012] tokens must be in lexicographical order") + errPoolAlreadyExists = errors.New("[GNOSWAP-POOL-013] pool already created") + errInvalidSwapAmount = errors.New("[GNOSWAP-POOL-014] invalid swap amount") + errInvalidProtocolFeePct = errors.New("[GNOSWAP-POOL-015] invalid protocol fee percentage") + errInvalidWithdrawalFeePct = errors.New("[GNOSWAP-POOL-016] invalid withdrawal fee percentage") + errLockedPool = errors.New("[GNOSWAP-POOL-017] can't swap while pool is locked") + errPriceOutOfRange = errors.New("[GNOSWAP-POOL-018] swap price out of range") + errNotEnoughBalance = errors.New("[GNOSWAP-POOL-019] not enough balance to transfer") + errMustBeNegative = errors.New("[GNOSWAP-POOL-020] negative value expected") + errTransferFailed = errors.New("[GNOSWAP-POOL-021] token transfer failed") + errInvalidTickAndTickSpacing = errors.New("[GNOSWAP-POOL-022] invalid tick and tick spacing requested") + errInvalidAddress = errors.New("[GNOSWAP-POOL-023] invalid address") + errInvalidTickRange = errors.New("[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper") + errUnderflow = errors.New("[GNOSWAP-POOL-025] underflow") + errOverFlow = errors.New("[GNOSWAP-POOL-026] overflow") + errBalanceUpdateFailed = errors.New("[GNOSWAP-POOL-027] balance update failed") + errTickLowerInvalid = errors.New("[GNOSWAP-POOL-028] tickLower is invalid") + errTickUpperInvalid = errors.New("[GNOSWAP-POOL-029] tickUpper is invalid") + errTickLowerGtTickUpper = errors.New("[GNOSWAP-POOL-030] tickLower is greater than tickUpper") +) + +// addDetailToError adds detail to an error message +// +// input +// err: error - the error message that is one of above list +// detail: string - the detail to add to the error message +// Returns: string - the final error message +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/pool/getter.gno b/contract/r/gnoswap/pool/getter.gno new file mode 100644 index 000000000..6c22f1d15 --- /dev/null +++ b/contract/r/gnoswap/pool/getter.gno @@ -0,0 +1,128 @@ +package pool + +func PoolGetPoolList() []string { + poolPaths := []string{} + pools.Iterate("", "", func(poolPath string, _ interface{}) bool { + poolPaths = append(poolPaths, poolPath) + + return false + }) + + return poolPaths +} + +func PoolGetToken0Path(poolPath string) string { + return mustGetPool(poolPath).Token0Path() +} + +func PoolGetToken1Path(poolPath string) string { + return mustGetPool(poolPath).Token1Path() +} + +func PoolGetFee(poolPath string) uint32 { + return mustGetPool(poolPath).Fee() +} + +func PoolGetBalanceToken0(poolPath string) string { + return mustGetPool(poolPath).BalanceToken0().ToString() +} + +func PoolGetBalanceToken1(poolPath string) string { + return mustGetPool(poolPath).BalanceToken1().ToString() +} + +func PoolGetTickSpacing(poolPath string) int32 { + return mustGetPool(poolPath).TickSpacing() +} + +func PoolGetMaxLiquidityPerTick(poolPath string) string { + return mustGetPool(poolPath).MaxLiquidityPerTick().ToString() +} + +func PoolGetSlot0SqrtPriceX96(poolPath string) string { + return mustGetPool(poolPath).Slot0SqrtPriceX96().ToString() +} + +func PoolGetSlot0Tick(poolPath string) int32 { + return mustGetPool(poolPath).Slot0Tick() +} + +func PoolGetSlot0FeeProtocol(poolPath string) uint8 { + return mustGetPool(poolPath).Slot0FeeProtocol() +} + +func PoolGetSlot0Unlocked(poolPath string) bool { + return mustGetPool(poolPath).Slot0Unlocked() +} + +func PoolGetFeeGrowthGlobal0X128(poolPath string) string { + return mustGetPool(poolPath).FeeGrowthGlobal0X128().ToString() +} + +func PoolGetFeeGrowthGlobal1X128(poolPath string) string { + return mustGetPool(poolPath).FeeGrowthGlobal1X128().ToString() +} + +func PoolGetProtocolFeesToken0(poolPath string) string { + return mustGetPool(poolPath).ProtocolFeesToken0().ToString() +} + +func PoolGetProtocolFeesToken1(poolPath string) string { + return mustGetPool(poolPath).ProtocolFeesToken1().ToString() +} + +func PoolGetLiquidity(poolPath string) string { + return mustGetPool(poolPath).Liquidity().ToString() +} + +func PoolGetPositionLiquidity(poolPath, key string) string { + return mustGetPool(poolPath).PositionLiquidity(key).ToString() +} + +func PoolGetPositionFeeGrowthInside0LastX128(poolPath, key string) string { + return mustGetPool(poolPath).PositionFeeGrowthInside0LastX128(key).ToString() +} + +func PoolGetPositionFeeGrowthInside1LastX128(poolPath, key string) string { + return mustGetPool(poolPath).PositionFeeGrowthInside1LastX128(key).ToString() +} + +func PoolGetPositionTokensOwed0(poolPath, key string) string { + return mustGetPool(poolPath).PositionTokensOwed0(key).ToString() +} + +func PoolGetPositionTokensOwed1(poolPath, key string) string { + return mustGetPool(poolPath).PositionTokensOwed1(key).ToString() +} + +func PoolGetTickLiquidityGross(poolPath string, tick int32) string { + return mustGetPool(poolPath).GetTickLiquidityGross(tick).ToString() +} + +func PoolGetTickLiquidityNet(poolPath string, tick int32) string { + return mustGetPool(poolPath).GetTickLiquidityNet(tick).ToString() +} + +func PoolGetTickFeeGrowthOutside0X128(poolPath string, tick int32) string { + return mustGetPool(poolPath).GetTickFeeGrowthOutside0X128(tick).ToString() +} + +func PoolGetTickFeeGrowthOutside1X128(poolPath string, tick int32) string { + return mustGetPool(poolPath).GetTickFeeGrowthOutside1X128(tick).ToString() +} + +func PoolGetTickCumulativeOutside(poolPath string, tick int32) int64 { + return mustGetPool(poolPath).GetTickCumulativeOutside(tick) +} + +func PoolGetTickSecondsPerLiquidityOutsideX128(poolPath string, tick int32) string { + return mustGetPool(poolPath).GetTickSecondsPerLiquidityOutsideX128(tick).ToString() +} + +func PoolGetTickSecondsOutside(poolPath string, tick int32) uint32 { + return mustGetPool(poolPath).GetTickSecondsOutside(tick) +} + +func PoolGetTickInitialized(poolPath string, tick int32) bool { + return mustGetPool(poolPath).GetTickInitialized(tick) +} diff --git a/contract/r/gnoswap/pool/getter_test.gno b/contract/r/gnoswap/pool/getter_test.gno new file mode 100644 index 000000000..e0727669b --- /dev/null +++ b/contract/r/gnoswap/pool/getter_test.gno @@ -0,0 +1,491 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/avl" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestInitData(t *testing.T) { + resetObject(t) + + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + fee: 3000, + tickSpacing: 10, + maxLiquidityPerTick: u256.NewUint(1000000), + slot0: Slot0{ + sqrtPriceX96: u256.NewUint(1000000000000000000), + tick: 5, + feeProtocol: 6, + unlocked: false, + }, + feeGrowthGlobal0X128: u256.NewUint(1000000000000000000), + feeGrowthGlobal1X128: u256.NewUint(2000000000000000000), + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + liquidity: u256.NewUint(1000000), + } + + mockTicks := avl.NewTree() + mockTicks.Set("0", TickInfo{ + liquidityGross: u256.NewUint(1000000), + liquidityNet: i256.NewInt(2000000), + feeGrowthOutside0X128: u256.NewUint(3000000), + feeGrowthOutside1X128: u256.NewUint(4000000), + tickCumulativeOutside: 5, + secondsPerLiquidityOutsideX128: u256.NewUint(6000000), + secondsOutside: 7, + initialized: true, + }) + mockPool.ticks = mockTicks + + mockPositions := avl.NewTree() + mockPositions.Set("test_position", PositionInfo{ + liquidity: u256.NewUint(1000000), + feeGrowthInside0LastX128: u256.NewUint(2000000), + feeGrowthInside1LastX128: u256.NewUint(3000000), + tokensOwed0: u256.NewUint(4000000), + tokensOwed1: u256.NewUint(5000000), + }) + mockPool.positions = mockPositions + + pools.Set("token0:token1:3000", mockPool) +} + +func TestPoolGetters(t *testing.T) { + t.Run("get pool list", func(t *testing.T) { + poolList := PoolGetPoolList() + if len(poolList) != 1 || poolList[0] != "token0:token1:3000" { + t.Errorf("Expected pool list [%s], got %v", "token0:token1:3000", poolList) + } + }) + + t.Run("valid pool path", func(t *testing.T) { + validPoolPath := "token0:token1:3000" + + t.Run("get token0 path", func(t *testing.T) { + token0Path := PoolGetToken0Path(validPoolPath) + if token0Path != "token0" { + t.Errorf("Expected token0 path [%s], got %s", "token0", token0Path) + } + }) + + t.Run("get token1 path", func(t *testing.T) { + token1Path := PoolGetToken1Path(validPoolPath) + if token1Path != "token1" { + t.Errorf("Expected token1 path [%s], got %s", "token1", token1Path) + } + }) + + t.Run("get fee", func(t *testing.T) { + fee := PoolGetFee(validPoolPath) + if fee != 3000 { + t.Errorf("Expected fee [%d], got %d", 3000, fee) + } + }) + + t.Run("get balance token0", func(t *testing.T) { + balanceToken0 := PoolGetBalanceToken0(validPoolPath) + if balanceToken0 != "1000" { + t.Errorf("Expected balance token0 [%s], got %s", "1000", balanceToken0) + } + }) + + t.Run("get balance token1", func(t *testing.T) { + balanceToken1 := PoolGetBalanceToken1(validPoolPath) + if balanceToken1 != "2000" { + t.Errorf("Expected balance token1 [%s], got %s", "2000", balanceToken1) + } + }) + + t.Run("get tick spacing", func(t *testing.T) { + tickSpacing := PoolGetTickSpacing(validPoolPath) + if tickSpacing != 10 { + t.Errorf("Expected tick spacing [%d], got %d", 10, tickSpacing) + } + }) + + t.Run("get max liquidity per tick", func(t *testing.T) { + maxLiquidityPerTick := PoolGetMaxLiquidityPerTick(validPoolPath) + if maxLiquidityPerTick != "1000000" { + t.Errorf("Expected max liquidity per tick [%s], got %s", "1000000", maxLiquidityPerTick) + } + }) + + t.Run("get slot0 sqrt price x96", func(t *testing.T) { + sqrtPriceX96 := PoolGetSlot0SqrtPriceX96(validPoolPath) + if sqrtPriceX96 != "1000000000000000000" { + t.Errorf("Expected sqrt price x96 [%s], got %s", "1000000000000000000", sqrtPriceX96) + } + }) + + t.Run("get slot0 tick", func(t *testing.T) { + tick := PoolGetSlot0Tick(validPoolPath) + if tick != 5 { + t.Errorf("Expected tick [%d], got %d", 5, tick) + } + }) + + t.Run("get slot0 fee protocol", func(t *testing.T) { + feeProtocol := PoolGetSlot0FeeProtocol(validPoolPath) + if feeProtocol != 6 { + t.Errorf("Expected fee protocol [%d], got %d", 6, feeProtocol) + } + }) + + t.Run("get slot0 unlocked", func(t *testing.T) { + unlocked := PoolGetSlot0Unlocked(validPoolPath) + if unlocked != false { + t.Errorf("Expected unlocked [%t], got %t", false, unlocked) + } + }) + + t.Run("get fee growth global0x128", func(t *testing.T) { + feeGrowthGlobal0X128 := PoolGetFeeGrowthGlobal0X128(validPoolPath) + if feeGrowthGlobal0X128 != "1000000000000000000" { + t.Errorf("Expected fee growth global 0 x128 [%s], got %s", "1000000000000000000", feeGrowthGlobal0X128) + } + }) + + t.Run("get fee growth global1x128", func(t *testing.T) { + feeGrowthGlobal1X128 := PoolGetFeeGrowthGlobal1X128(validPoolPath) + if feeGrowthGlobal1X128 != "2000000000000000000" { + t.Errorf("Expected fee growth global 1 x128 [%s], got %s", "2000000000000000000", feeGrowthGlobal1X128) + } + }) + + t.Run("get protocol fees token0", func(t *testing.T) { + protocolFeesToken0 := PoolGetProtocolFeesToken0(validPoolPath) + if protocolFeesToken0 != "1000" { + t.Errorf("Expected protocol fees token0 [%s], got %s", "1000", protocolFeesToken0) + } + }) + + t.Run("get protocol fees token1", func(t *testing.T) { + protocolFeesToken1 := PoolGetProtocolFeesToken1(validPoolPath) + if protocolFeesToken1 != "2000" { + t.Errorf("Expected protocol fees token1 [%s], got %s", "2000", protocolFeesToken1) + } + }) + + t.Run("get liquidity", func(t *testing.T) { + liquidity := PoolGetLiquidity(validPoolPath) + if liquidity != "1000000" { + t.Errorf("Expected liquidity [%s], got %s", "1000000", liquidity) + } + }) + }) + + t.Run("invalid pool path - should panic", func(t *testing.T) { + invalidPoolPath := "invalid/pool/path" + + t.Run("get token0 path", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetToken0Path(invalidPoolPath) + }) + }) + + t.Run("get token1 path", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetToken1Path(invalidPoolPath) + }) + }) + + t.Run("get fee", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetFee(invalidPoolPath) + }) + }) + + t.Run("get balance token0", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetBalanceToken0(invalidPoolPath) + }) + }) + + t.Run("get balance token1", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetBalanceToken1(invalidPoolPath) + }) + }) + + t.Run("get tick spacing", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetTickSpacing(invalidPoolPath) + }) + }) + + t.Run("get liquidity", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetLiquidity(invalidPoolPath) + }) + }) + + t.Run("get max liquidity per tick", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetMaxLiquidityPerTick(invalidPoolPath) + }) + }) + + t.Run("get slot0 sqrt price x96", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetSlot0SqrtPriceX96(invalidPoolPath) + }) + }) + + t.Run("get slot0 tick", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetSlot0Tick(invalidPoolPath) + }) + }) + + t.Run("get slot0 fee protocol", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetSlot0FeeProtocol(invalidPoolPath) + }) + }) + + t.Run("get slot0 unlocked", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetSlot0Unlocked(invalidPoolPath) + }) + }) + + t.Run("get fee growth global0x128", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetFeeGrowthGlobal0X128(invalidPoolPath) + }) + }) + + t.Run("get fee growth global1x128", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetFeeGrowthGlobal1X128(invalidPoolPath) + }) + }) + + t.Run("get protocol fees token0", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetProtocolFeesToken0(invalidPoolPath) + }) + }) + + t.Run("get protocol fees token1", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetProtocolFeesToken1(invalidPoolPath) + }) + }) + + t.Run("get liquidity", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetLiquidity(invalidPoolPath) + }) + }) + }) +} + +func TestPositionGetters(t *testing.T) { + t.Run("valid position", func(t *testing.T) { + validPoolPath := "token0:token1:3000" + validPositionKey := "test_position" + + t.Run("get position liquidity", func(t *testing.T) { + liquidity := PoolGetPositionLiquidity(validPoolPath, validPositionKey) + if liquidity != "1000000" { + t.Errorf("Expected liquidity [%s], got %s", "1000000", liquidity) + } + }) + + t.Run("get position fee growth inside0 last x128", func(t *testing.T) { + feeGrowthInside0LastX128 := PoolGetPositionFeeGrowthInside0LastX128(validPoolPath, validPositionKey) + if feeGrowthInside0LastX128 != "2000000" { + t.Errorf("Expected fee growth inside0 last x128 [%s], got %s", "2000000", feeGrowthInside0LastX128) + } + }) + + t.Run("get position fee growth inside1 last x128", func(t *testing.T) { + feeGrowthInside1LastX128 := PoolGetPositionFeeGrowthInside1LastX128(validPoolPath, validPositionKey) + if feeGrowthInside1LastX128 != "3000000" { + t.Errorf("Expected fee growth inside1 last x128 [%s], got %s", "3000000", feeGrowthInside1LastX128) + } + }) + + t.Run("get position tokens owed0", func(t *testing.T) { + tokensOwed0 := PoolGetPositionTokensOwed0(validPoolPath, validPositionKey) + if tokensOwed0 != "4000000" { + t.Errorf("Expected tokens owed0 [%s], got %s", "4000000", tokensOwed0) + } + }) + + t.Run("get position tokens owed1", func(t *testing.T) { + tokensOwed1 := PoolGetPositionTokensOwed1(validPoolPath, validPositionKey) + if tokensOwed1 != "5000000" { + t.Errorf("Expected tokens owed1 [%s], got %s", "5000000", tokensOwed1) + } + }) + }) + + t.Run("invalid position", func(t *testing.T) { + validPoolPath := "token0:token1:3000" // pool must valid + invalidPositionKey := "invalid/position/key" // but position must not + + t.Run("get position liquidity", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetPositionLiquidity(validPoolPath, invalidPositionKey) + }) + }) + + t.Run("get position fee growth inside0 last x128", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetPositionFeeGrowthInside0LastX128(validPoolPath, invalidPositionKey) + }) + }) + + t.Run("get position fee growth inside1 last x128", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetPositionFeeGrowthInside1LastX128(validPoolPath, invalidPositionKey) + }) + }) + + t.Run("get position tokens owed0", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetPositionTokensOwed0(validPoolPath, invalidPositionKey) + }) + }) + + t.Run("get position tokens owed1", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetPositionTokensOwed1(validPoolPath, invalidPositionKey) + }) + }) + }) +} + +func TestTickGetters(t *testing.T) { + t.Run("valid tick", func(t *testing.T) { + validPoolPath := "token0:token1:3000" + validTick := int32(0) + + t.Run("get tick liquidity gross", func(t *testing.T) { + liquidityGross := PoolGetTickLiquidityGross(validPoolPath, validTick) + if liquidityGross != "1000000" { + t.Errorf("Expected liquidity gross [%s], got %s", "1000000", liquidityGross) + } + }) + + t.Run("get tick liquidity net", func(t *testing.T) { + liquidityNet := PoolGetTickLiquidityNet(validPoolPath, validTick) + if liquidityNet != "2000000" { + t.Errorf("Expected liquidity net [%s], got %s", "2000000", liquidityNet) + } + }) + + t.Run("get tick fee growth outside0 x128", func(t *testing.T) { + feeGrowthOutside0X128 := PoolGetTickFeeGrowthOutside0X128(validPoolPath, validTick) + if feeGrowthOutside0X128 != "3000000" { + t.Errorf("Expected fee growth outside0 x128 [%s], got %s", "3000000", feeGrowthOutside0X128) + } + }) + + t.Run("get tick fee growth outside1 x128", func(t *testing.T) { + feeGrowthOutside1X128 := PoolGetTickFeeGrowthOutside1X128(validPoolPath, validTick) + if feeGrowthOutside1X128 != "4000000" { + t.Errorf("Expected fee growth outside1 x128 [%s], got %s", "4000000", feeGrowthOutside1X128) + } + }) + + t.Run("get tick cumulative outside", func(t *testing.T) { + tickCumulativeOutside := PoolGetTickCumulativeOutside(validPoolPath, validTick) + if tickCumulativeOutside != 5 { + t.Errorf("Expected tick cumulative outside [%d], got %d", 5, tickCumulativeOutside) + } + }) + + t.Run("get tick seconds per liquidity outside x128", func(t *testing.T) { + secondsPerLiquidityOutsideX128 := PoolGetTickSecondsPerLiquidityOutsideX128(validPoolPath, validTick) + if secondsPerLiquidityOutsideX128 != "6000000" { + t.Errorf("Expected seconds per liquidity outside x128 [%s], got %s", "6000000", secondsPerLiquidityOutsideX128) + } + }) + + t.Run("get tick seconds outside", func(t *testing.T) { + secondsOutside := PoolGetTickSecondsOutside(validPoolPath, validTick) + if secondsOutside != 7 { + t.Errorf("Expected seconds outside [%d], got %d", 7, secondsOutside) + } + }) + + t.Run("get tick initialized", func(t *testing.T) { + initialized := PoolGetTickInitialized(validPoolPath, validTick) + if initialized != true { + t.Errorf("Expected initialized [%t], got %t", true, initialized) + } + }) + }) + + t.Run("invalid tick", func(t *testing.T) { + validPoolPath := "token0:token1:3000" // pool must valid + invalidTick := int32(-1) // but tick must not + + t.Run("get tick liquidity gross", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetTickLiquidityGross(validPoolPath, invalidTick) + }) + }) + + t.Run("get tick liquidity net", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetTickLiquidityNet(validPoolPath, invalidTick) + }) + }) + + t.Run("get tick fee growth outside0 x128", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetTickFeeGrowthOutside0X128(validPoolPath, invalidTick) + }) + }) + + t.Run("get tick cumulative outside", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetTickCumulativeOutside(validPoolPath, invalidTick) + }) + }) + + t.Run("get tick seconds per liquidity outside x128", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetTickSecondsPerLiquidityOutsideX128(validPoolPath, invalidTick) + }) + }) + + t.Run("get tick seconds outside", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetTickSecondsOutside(validPoolPath, invalidTick) + }) + }) + + t.Run("get tick initialized", func(t *testing.T) { + shouldPanic(t, func() { + PoolGetTickInitialized(validPoolPath, invalidTick) + }) + }) + }) +} + +func shouldPanic(t *testing.T, fn func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic") + } + }() + fn() +} diff --git a/contract/r/gnoswap/pool/gno.mod b/contract/r/gnoswap/pool/gno.mod new file mode 100644 index 000000000..3c88fce34 --- /dev/null +++ b/contract/r/gnoswap/pool/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/pool diff --git a/contract/r/gnoswap/pool/liquidity_math.gno b/contract/r/gnoswap/pool/liquidity_math.gno new file mode 100644 index 000000000..82a8b4317 --- /dev/null +++ b/contract/r/gnoswap/pool/liquidity_math.gno @@ -0,0 +1,61 @@ +package pool + +import ( + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/demo/ufmt" +) + +// liquidityMathAddDelta calculates the new liquidity by applying the delta liquidity to the current liquidity. +// If delta liquidity is negative, it subtracts the absolute value of delta liquidity from the current liquidity. +// If delta liquidity is positive, it adds the absolute value of delta liquidity to the current liquidity. +// +// Parameters: +// - x: The current liquidity as a uint256 value. +// - y: The delta liquidity as a signed int256 value. +// +// Returns: +// - The new liquidity as a uint256 value. +// +// Notes: +// - If `x` or `y` is nil, the function panics with an appropriate error message. +// - If `y` is negative, its absolute value is subtracted from `x`. +// - The result must be less than `x`. Otherwise, the function panics to prevent underflow. +// +// - If `y` is positive, it is added to `x`. +// - The result must be greater than or equal to `x`. Otherwise, the function panics to prevent overflow. +// +// - The function ensures correctness by validating the results of the arithmetic operations. +func liquidityMathAddDelta(x *u256.Uint, y *i256.Int) *u256.Uint { + if x == nil || y == nil { + panic(addDetailToError( + errInvalidInput, + "x or y is nil", + )) + } + + var z *u256.Uint + + // Subtract or add based on the sign of y + if y.Lt(i256.Zero()) { + absDelta := y.Abs() + z = new(u256.Uint).Sub(x, absDelta) + if z.Gte(x) { + panic(addDetailToError( + errLiquidityCalculation, + ufmt.Sprintf("Condition failed: (z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + )) + } + } else { + z = new(u256.Uint).Add(x, y.Abs()) + if z.Lt(x) { + panic(addDetailToError( + errLiquidityCalculation, + ufmt.Sprintf("Condition failed: (z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + )) + } + } + + return z +} diff --git a/contract/r/gnoswap/pool/liquidity_math_test.gno b/contract/r/gnoswap/pool/liquidity_math_test.gno new file mode 100644 index 000000000..6c38f3df7 --- /dev/null +++ b/contract/r/gnoswap/pool/liquidity_math_test.gno @@ -0,0 +1,76 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestLiquidityMathAddDelta(t *testing.T) { + tests := []struct { + name string + fn func() + wantPanic string + }{ + { + name: "x is nil", + fn: func() { + var y *i256.Int + y = i256.MustFromDecimal("100") + liquidityMathAddDelta(nil, y) + }, + wantPanic: addDetailToError(errInvalidInput, "x or y is nil"), + }, + { + name: "y is nil", + fn: func() { + var x *u256.Uint + x = u256.MustFromDecimal("100") + liquidityMathAddDelta(x, nil) + }, + wantPanic: addDetailToError(errInvalidInput, "x or y is nil"), + }, + { + name: "underflow panic with sub delta", + fn: func() { + x := u256.NewUint(0) + y := i256.MustFromDecimal("-100") + liquidityMathAddDelta(x, y) + }, + wantPanic: addDetailToError( + errLiquidityCalculation, + ufmt.Sprintf("Condition failed: (z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), + }, + { + name: "overflow panic with add delta", + fn: func() { + x := u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") // 2^256 - 1 + y := i256.MustFromDecimal("100") + liquidityMathAddDelta(x, y) + }, + wantPanic: addDetailToError( + errLiquidityCalculation, + ufmt.Sprintf("Condition failed: (z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Errorf("%s: expected panic but fot none", tt.name) + return + } + if r.(string) != tt.wantPanic { + t.Errorf("%s: got panic %v, want %v", tt.name, r, tt.wantPanic) + } + }() + + tt.fn() + }) + } +} diff --git a/contract/r/gnoswap/pool/pool.gno b/contract/r/gnoswap/pool/pool.gno new file mode 100644 index 000000000..99ddf5139 --- /dev/null +++ b/contract/r/gnoswap/pool/pool.gno @@ -0,0 +1,517 @@ +package pool + +import ( + "gno.land/p/demo/ufmt" + "std" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +// Mint adds liquidity to a pool by minting a new position. +// +// This function mints a liquidity position within the specified tick range in a pool. +// It verifies caller permissions, validates inputs, and updates the pool's state. Additionally, +// it transfers the required amounts of token0 and token1 from the caller to the pool. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the newly created position. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - liquidityAmount: string, the amount of liquidity to add, provided as a decimal string. +// - positionCaller: std.Address, the address of the entity calling the function (e.g., the position owner). +// +// Returns: +// - string: The amount of token 0 transferred to the pool as a string. +// - string: The amount of token 1 transferred to the pool as a string. +// +// Panic Conditions: +// - The system is halted (`common.IsHalted()`). +// - Caller lacks permission to mint a position when `common.GetLimitCaller()` is enforced. +// - The provided `liquidityAmount` is zero. +// - Any failure during token transfers or position modifications. +// +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#mint +func Mint( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + tickLower int32, + tickUpper int32, + liquidityAmount string, + positionCaller std.Address, +) (string, string) { + assertOnlyNotHalted() + if common.GetLimitCaller() { + assertOnlyPositionContract() + } + + liquidity := u256.MustFromDecimal(liquidityAmount) + if liquidity.IsZero() { + panic(errZeroLiquidity) + } + + pool := GetPool(token0Path, token1Path, fee) + liquidityDelta := safeConvertToInt128(liquidity) + position := newModifyPositionParams(recipient, tickLower, tickUpper, liquidityDelta) + _, amount0, amount1 := pool.modifyPosition(position) + + if amount0.Gt(u256.Zero()) { + pool.safeTransferFrom(positionCaller, consts.POOL_ADDR, pool.token0Path, amount0, true) + } + + if amount1.Gt(u256.Zero()) { + pool.safeTransferFrom(positionCaller, consts.POOL_ADDR, pool.token1Path, amount1, false) + } + + return amount0.ToString(), amount1.ToString() +} + +// Burn removes liquidity from a position in the pool. +// +// This function allows the caller to burn (remove) a specified amount of liquidity from a position. +// It calculates the amounts of token0 and token1 released when liquidity is removed and updates +// the position's owed token amounts. The actual transfer of tokens back to the caller happens +// during a separate `Collect()` operation. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - liquidityAmount: string, the amount of liquidity to remove, provided as a decimal string (uint128). +// +// Returns: +// - string: The amount of token0 owed after removing liquidity, as a string. +// - string: The amount of token1 owed after removing liquidity, as a string. +// +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#burn +func Burn( + token0Path string, + token1Path string, + fee uint32, + tickLower int32, + tickUpper int32, + liquidityAmount string, // uint128 +) (string, string) { // uint256 x2 + assertOnlyNotHalted() + if common.GetLimitCaller() { + assertOnlyPositionContract() + } + pool := GetPool(token0Path, token1Path, fee) + + caller := getPrevAddr() + liqAmount := u256.MustFromDecimal(liquidityAmount) + liqAmountInt256 := safeConvertToInt128(liqAmount) + liqDelta := i256.Zero().Neg(liqAmountInt256) + posParams := newModifyPositionParams(caller, tickLower, tickUpper, liqDelta) + position, amount0, amount1 := pool.modifyPosition(posParams) + + if amount0.Gt(u256.Zero()) || amount1.Gt(u256.Zero()) { + amount0 = toUint128(amount0) + amount1 = toUint128(amount1) + position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0) + position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, amount1) + } + + positionKey := getPositionKey(caller, tickLower, tickUpper) + pool.setPosition(positionKey, position) + + // actual token transfer happens in Collect() + return amount0.ToString(), amount1.ToString() +} + +// Collect handles the collection of tokens (token0 and token1) from a liquidity position. +// +// This function allows the caller to collect a specified amount of tokens owed to a position +// in a liquidity pool. It calculates the collectible amount based on three constraints: +// the requested amount, the tokens owed to the position, and the pool's available balance. +// The collected tokens are transferred to the specified recipient. +// +// Parameters: +// - token0Path: string, the path or identifier for token 0. +// - token1Path: string, the path or identifier for token 1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the collected tokens. +// - tickLower: int32, the lower tick boundary of the position. +// - tickUpper: int32, the upper tick boundary of the position. +// - amount0Requested: string, the requested amount of token 0 to collect (decimal string). +// - amount1Requested: string, the requested amount of token 1 to collect (decimal string). +// +// Returns: +// - string: The actual amount of token 0 collected, as a string. +// - string: The actual amount of token 1 collected, as a string. +// +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#collect +func Collect( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + tickLower int32, + tickUpper int32, + amount0Requested string, + amount1Requested string, +) (string, string) { + assertOnlyNotHalted() + if common.GetLimitCaller() { + assertOnlyPositionContract() + } + + pool := GetPool(token0Path, token1Path, fee) + positionKey := getPositionKey(getPrevAddr(), tickLower, tickUpper) + position := pool.mustGetPosition(positionKey) + + var amount0, amount1 *u256.Uint + + // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 + amount0Req := u256.MustFromDecimal(amount0Requested) + amount0 = collectToken(amount0Req, position.tokensOwed0, pool.BalanceToken0()) + amount1Req := u256.MustFromDecimal(amount1Requested) + amount1 = collectToken(amount1Req, position.tokensOwed1, pool.BalanceToken1()) + + if amount0.Gt(u256.Zero()) { + position.tokensOwed0 = new(u256.Uint).Sub(position.tokensOwed0, amount0) + pool.balances.token0 = new(u256.Uint).Sub(pool.balances.token0, amount0) + token0 := common.GetTokenTeller(pool.token0Path) + checkTransferError(token0.Transfer(recipient, amount0.Uint64())) + } + if amount1.Gt(u256.Zero()) { + position.tokensOwed1 = new(u256.Uint).Sub(position.tokensOwed1, amount1) + pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) + token1 := common.GetTokenTeller(pool.token1Path) + checkTransferError(token1.Transfer(recipient, amount1.Uint64())) + } + + pool.setPosition(positionKey, position) + + return amount0.ToString(), amount1.ToString() +} + +// collectToken calculates the actual amount of tokens that can be collected. +// +// This function determines the smallest possible value among the requested amount (`amountReq`), +// the tokens owed (`tokensOwed`), and the pool's available balance (`poolBalance`). It ensures +// the collected amount does not exceed any of these constraints. +// +// Parameters: +// - amountReq: *u256.Uint, the amount of tokens requested for collection. +// - tokensOwed: *u256.Uint, the total amount of tokens owed to the position. +// - poolBalance: *u256.Uint, the current balance of tokens available in the pool. +// +// Returns: +// - amount: *u256.Uint, the actual amount that can be collected (minimum of the three inputs). +func collectToken( + amountReq, tokensOwed, poolBalance *u256.Uint, +) (amount *u256.Uint) { + // find smallest of three amounts + amount = u256Min(amountReq, tokensOwed) + amount = u256Min(amount, poolBalance) + return amount.Clone() +} + +// SetFeeProtocolByAdmin sets the fee protocol for all pools +// Also it will be applied to new created pools +func SetFeeProtocolByAdmin( + feeProtocol0 uint8, + feeProtocol1 uint8, +) { + assertOnlyAdmin() + setFeeProtocolInternal(feeProtocol0, feeProtocol1, "SetFeeProtocolByAdmin") +} + +// SetFeeProtocol sets the fee protocol for all pools +// Only governance contract can execute this function via proposal +// Also it will be applied to new created pools +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#setfeeprotocol +func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { + assertOnlyGovernance() + setFeeProtocolInternal(feeProtocol0, feeProtocol1, "SetFeeProtocol") +} + +// setFeeProtocolInternal updates the protocol fee for all pools and emits an event. +// +// This function is an internal utility used to set the protocol fee for token0 and token1 in a compact +// format. The fee values are stored as a single `uint8` byte where: +// - Lower 4 bits represent the fee for token0 (feeProtocol0). +// - Upper 4 bits represent the fee for token1 (feeProtocol1). +// +// It also emits an event to log the changes, including the previous and new fee protocol values. +// +// Parameters: +// - feeProtocol0: uint8, protocol fee for token0 (must be 0 or between 4 and 10 inclusive). +// - feeProtocol1: uint8, protocol fee for token1 (must be 0 or between 4 and 10 inclusive). +// - eventName: string, the name of the event to emit (e.g., "SetFeeProtocolByAdmin"). +// +// Notes: +// - This function is called by higher-level functions like `SetFeeProtocolByAdmin` or `SetFeeProtocol`. +// - It does not validate caller permissions; validation must be performed by the calling function. +func setFeeProtocolInternal(feeProtocol0, feeProtocol1 uint8, eventName string) { + oldFee := slot0FeeProtocol + newFee := setFeeProtocol(feeProtocol0, feeProtocol1) + + feeProtocol0Old := oldFee % 16 + feeProtocol1Old := oldFee >> 4 + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + eventName, + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevFeeProtocol0", formatUint(uint64(feeProtocol0Old)), + "prevFeeProtocol1", formatUint(uint64(feeProtocol1Old)), + "feeProtocol0", formatUint(uint64(feeProtocol0)), + "feeProtocol1", formatUint(uint64(feeProtocol1)), + "newFee", formatUint(uint64(newFee)), + ) +} + +// setFeeProtocol updates the protocol fee configuration for all managed pools. +// +// This function combines the protocol fee values for token0 and token1 into a single `uint8` value, +// where: +// - Lower 4 bits store feeProtocol0 (for token0). +// - Upper 4 bits store feeProtocol1 (for token1). +// +// The updated fee protocol is applied uniformly to all pools managed by the system. +// +// Parameters: +// - feeProtocol0: protocol fee for token0 (must be 0 or between 4 and 10 inclusive). +// - feeProtocol1: protocol fee for token1 (must be 0 or between 4 and 10 inclusive). +// +// Returns: +// - newFee (uint8): the combined fee protocol value. +// +// Example: +// If feeProtocol0 = 4 and feeProtocol1 = 5: +// +// newFee = 4 + (5 << 4) +// // Results in: 0x54 (84 in decimal) +// // Binary: 0101 0100 +// // ^^^^ ^^^^ +// // fee1=5 fee0=4 +// +// Notes: +// - This function ensures that all pools under management are updated to use the same fee protocol. +// - Caller restrictions (e.g., admin or governance) are not enforced in this function. +// - Ensure the system is not halted before updating fees. +func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { + assertOnlyNotHalted() + if err := validateFeeProtocol(feeProtocol0, feeProtocol1); err != nil { + panic(addDetailToError( + err, + ufmt.Sprintf("expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), + )) + } + // combine both protocol fee into a single byte: + // - feePrtocol0 occupies the lower 4 bits + // - feeProtocol1 is shifted the lower 4 positions to occupy the upper 4 bits + newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) + + // Update slot0 for each pool + pools.Iterate("", "", func(poolPath string, iPool interface{}) bool { + pool := iPool.(*Pool) + pool.slot0.feeProtocol = newFee + + return false + }) + + // update slot0 + slot0FeeProtocol = newFee + return newFee +} + +// validateFeeProtocol validates the fee protocol values for token0 and token1. +// +// This function checks whether the provided fee protocol values (`feeProtocol0` and `feeProtocol1`) +// are valid using the `isValidFeeProtocolValue` function. If either value is invalid, it returns +// an error indicating that the protocol fee percentage is invalid. +// +// Parameters: +// - feeProtocol0: uint8, the fee protocol value for token0. +// - feeProtocol1: uint8, the fee protocol value for token1. +// +// Returns: +// - error: Returns `errInvalidProtocolFeePct` if either `feeProtocol0` or `feeProtocol1` is invalid. +// Returns `nil` if both values are valid. +func validateFeeProtocol(feeProtocol0, feeProtocol1 uint8) error { + if !isValidFeeProtocolValue(feeProtocol0) || !isValidFeeProtocolValue(feeProtocol1) { + return errInvalidProtocolFeePct + } + return nil +} + +// isValidFeeProtocolValue checks if a fee protocol value is within acceptable range. +// valid values are either 0 or between 4 and 10 inclusive. +func isValidFeeProtocolValue(value uint8) bool { + return value == 0 || (value >= 4 && value <= 10) +} + +// CollectProtocolByAdmin collects protocol fees for the given pool that accumulated while it was being used for swap +// Returns collected amount0, amount1 in string +func CollectProtocolByAdmin( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + amount0Requested string, + amount1Requested string, +) (string, string) { + assertOnlyAdmin() + amount0, amount1 := collectProtocol( + token0Path, + token1Path, + fee, + recipient, + amount0Requested, + amount1Requested, + ) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "CollectProtocolByAdmin", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "token0Path", token0Path, + "token1Path", token1Path, + "fee", formatUint(fee), + "recipient", recipient.String(), + "collectedAmount0", amount0, + "collectedAmount1", amount1, + ) + + return amount0, amount1 +} + +// CollectProtocol collects protocol fees for the given pool that accumulated while it was being used for swap +// Only governance contract can execute this function via proposal +// Returns collected amount0, amount1 in string +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#collectprotocol +func CollectProtocol( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + amount0Requested string, // uint128 + amount1Requested string, // uint128 +) (string, string) { // uint128 x2 + assertOnlyGovernance() + amount0, amount1 := collectProtocol( + token0Path, + token1Path, + fee, + recipient, + amount0Requested, + amount1Requested, + ) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "CollectProtocol", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "token0Path", token0Path, + "token1Path", token1Path, + "fee", formatUint(fee), + "recipient", recipient.String(), + "internal_amount0", amount0, + "internal_amount1", amount1, + ) + + return amount0, amount1 +} + +// collectProtocol collects protocol fees for token0 and token1 from the specified pool. +// +// This function allows the collection of accumulated protocol fees for token0 and token1. It ensures +// the requested amounts do not exceed the available protocol fees in the pool and transfers the +// collected amounts to the specified recipient. +// +// Parameters: +// - token0Path: string, the path or identifier for token0. +// - token1Path: string, the path or identifier for token1. +// - fee: uint32, the fee tier of the pool. +// - recipient: std.Address, the address to receive the collected protocol fees. +// - amount0Requested: string, the requested amount of token0 to collect (decimal string). +// - amount1Requested: string, the requested amount of token1 to collect (decimal string). +// +// Returns: +// - string: The actual amount of token0 collected, as a string. +// - string: The actual amount of token1 collected, as a string. +func collectProtocol( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + amount0Requested string, + amount1Requested string, +) (string, string) { + assertOnlyRegistered(token0Path) + assertOnlyRegistered(token1Path) + assertOnlyNotHalted() + + pool := GetPool(token0Path, token1Path, fee) + + amount0Req := u256.MustFromDecimal(amount0Requested) + amount1Req := u256.MustFromDecimal(amount1Requested) + + amount0 := u256Min(amount0Req, pool.ProtocolFeesToken0()) + amount1 := u256Min(amount1Req, pool.ProtocolFeesToken1()) + + amount0, amount1 = pool.saveProtocolFees(amount0.Clone(), amount1.Clone()) + uAmount0 := amount0.Uint64() + uAmount1 := amount1.Uint64() + + token0Teller := common.GetTokenTeller(pool.token0Path) + checkTransferError(token0Teller.Transfer(recipient, uAmount0)) + newBalanceToken0, err := updatePoolBalance(pool.BalanceToken0(), pool.BalanceToken1(), amount0, true) + if err != nil { + panic(err) + } + pool.balances.token0 = newBalanceToken0 + + token1Teller := common.GetTokenTeller(pool.token1Path) + checkTransferError(token1Teller.Transfer(recipient, uAmount1)) + newBalanceToken1, err := updatePoolBalance(pool.BalanceToken0(), pool.BalanceToken1(), amount1, false) + if err != nil { + panic(err) + } + pool.balances.token1 = newBalanceToken1 + + return amount0.ToString(), amount1.ToString() +} + +// saveProtocolFees updates the protocol fee balances after collection. +// +// Parameters: +// - amount0: amount of token0 fees to collect +// - amount1: amount of token1 fees to collect +// +// Returns the adjusted amounts that will actually be collected for both tokens. +func (p *Pool) saveProtocolFees(amount0, amount1 *u256.Uint) (*u256.Uint, *u256.Uint) { + cond01 := amount0.Gt(u256.Zero()) + cond02 := amount0.Eq(p.ProtocolFeesToken0()) + if cond01 && cond02 { + amount0 = new(u256.Uint).Sub(amount0, u256.One()) + } + + cond11 := amount1.Gt(u256.Zero()) + cond12 := amount1.Eq(p.ProtocolFeesToken1()) + if cond11 && cond12 { + amount1 = new(u256.Uint).Sub(amount1, u256.One()) + } + + p.protocolFees.token0 = new(u256.Uint).Sub(p.ProtocolFeesToken0(), amount0) + p.protocolFees.token1 = new(u256.Uint).Sub(p.ProtocolFeesToken1(), amount1) + + // return rest fee + return amount0, amount1 +} diff --git a/contract/r/gnoswap/pool/pool_manager.gno b/contract/r/gnoswap/pool/pool_manager.gno new file mode 100644 index 000000000..52c0763b8 --- /dev/null +++ b/contract/r/gnoswap/pool/pool_manager.gno @@ -0,0 +1,357 @@ +package pool + +import ( + "std" + "strconv" + "strings" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + en "gno.land/r/gnoswap/v1/emission" + pf "gno.land/r/gnoswap/v1/protocol_fee" + + "gno.land/r/gnoswap/v1/gns" +) + +var ( + feeAmountTickSpacing = avl.NewTree() // feeBps(uint32) -> tickSpacing(int32) + pools = avl.NewTree() // poolPath -> *Pool + + slot0FeeProtocol uint8 = 0 +) + +func init() { + setFeeAmountTickSpacing(100, 1) // 0.01% + setFeeAmountTickSpacing(500, 10) // 0.05% + setFeeAmountTickSpacing(3000, 60) // 0.3% + setFeeAmountTickSpacing(10000, 200) // 1% +} + +// createPoolParams holds the essential parameters for creating a new pool. +type createPoolParams struct { + token0Path, token1Path string + fee uint32 + sqrtPriceX96 *u256.Uint + tickSpacing int32 +} + +// newPoolParams defines the essential parameters for creating a new pool. +func newPoolParams( + token0Path string, + token1Path string, + fee uint32, + sqrtPriceX96 string, +) *createPoolParams { + price := u256.MustFromDecimal(sqrtPriceX96) + tickSpacing := GetFeeAmountTickSpacing(fee) + return &createPoolParams{ + token0Path: token0Path, + token1Path: token1Path, + fee: fee, + sqrtPriceX96: price, + tickSpacing: tickSpacing, + } +} + +func (p *createPoolParams) updateWithWrapping() *createPoolParams { + token0Path, token1Path := p.wrap() + + if !p.isInOrder() { + token0Path, token1Path = token1Path, token0Path + + oldTick := common.TickMathGetTickAtSqrtRatio(p.sqrtPriceX96) + if oldTick > consts.MAX_TICK || oldTick < consts.MIN_TICK { + panic(addDetailToError( + errInvalidTickRange, + ufmt.Sprintf("expected tick(%d) to be within range", oldTick), + )) + } + newPrice := common.TickMathGetSqrtRatioAtTick(oldTick * int32(-1)) + if newPrice.IsZero() { + panic(addDetailToError( + errOverFlow, + ufmt.Sprintf("expected newPrice(%s) to be non-zero", newPrice.ToString()), + )) + } + p.sqrtPriceX96 = new(u256.Uint).Set(newPrice) + } + + return newPoolParams(token0Path, token1Path, p.fee, p.sqrtPriceX96.ToString()) +} + +func (p *createPoolParams) isSameTokenPath() bool { + return p.token0Path == p.token1Path +} + +// isInOrder checks if token paths are in lexicographical (or, alphabetical) order +func (p *createPoolParams) isInOrder() bool { + if strings.Compare(p.token0Path, p.token1Path) < 0 { + return true + } + + return false +} + +func (p *createPoolParams) wrap() (string, string) { + if p.token0Path == consts.GNOT { + p.token0Path = consts.WRAPPED_WUGNOT + } + if p.token1Path == consts.GNOT { + p.token1Path = consts.WRAPPED_WUGNOT + } + + return p.token0Path, p.token1Path +} + +func (p *createPoolParams) Token0Path() string { + return p.token0Path +} + +func (p *createPoolParams) Token1Path() string { + return p.token1Path +} + +func (p *createPoolParams) Fee() uint32 { + return p.fee +} + +func (p *createPoolParams) TickSpacing() int32 { + return p.tickSpacing +} + +func (p *createPoolParams) SqrtPriceX96() *u256.Uint { + return p.sqrtPriceX96 +} + +// CreatePool creates a new concentrated liquidity pool with the given parameters. +// It mints and distributes GNS tokens, validates the input parameters, and creates a new pool. +// If GNOT is used as one of the tokens, it is automatically wrapped to WUGNOT. +// The function ensures that token0Path is lexicographically smaller than token1Path. +// ref: https://docs.gnoswap.io/contracts/pool/pool_manager.gno#createpool +func CreatePool( + token0Path string, + token1Path string, + fee uint32, + sqrtPriceX96 string, +) { + assertOnlyNotHalted() + + poolInfo := newPoolParams(token0Path, token1Path, fee, sqrtPriceX96) + poolInfo = poolInfo.updateWithWrapping() + if poolInfo.isSameTokenPath() { + panic(addDetailToError( + errDuplicateTokenInPool, + ufmt.Sprintf( + "expected token0Path(%s) != token1Path(%s)", + poolInfo.token0Path, poolInfo.token1Path, + ), + )) + } + en.MintAndDistributeGns() + + // wrap first + token0Path, token1Path = poolInfo.wrap() + poolPath := GetPoolPath(token0Path, token1Path, fee) + + // reinitialize poolInfo with wrapped tokens + poolInfo = newPoolParams(token0Path, token1Path, fee, sqrtPriceX96) + + // then check if token0Path == token1Path + if poolInfo.isSameTokenPath() { + panic(addDetailToError( + errDuplicateTokenInPool, + ufmt.Sprintf( + "expected token0Path(%s) != token1Path(%s)", + token0Path, token1Path, + ), + )) + } + + if !poolInfo.isInOrder() { + panic(addDetailToError( + errTokenSortOrder, + ufmt.Sprintf("expected token0Path(%s) < token1Path(%s)", token0Path, token1Path), + )) + } + + prevAddr, prevRealm := getPrevAsString() + + // check whether the pool already exist + pool, exist := pools.Get(poolPath) + if exist { + panic(addDetailToError( + errPoolAlreadyExists, + ufmt.Sprintf("expected poolPath(%s) not to exist", poolPath), + )) + } + + if poolCreationFee > 0 { + gns.TransferFrom(std.PrevRealm().Addr(), consts.PROTOCOL_FEE_ADDR, poolCreationFee) + pf.AddToProtocolFee(consts.GNS_PATH, poolCreationFee) + + std.Emit( + "PoolCreationFee", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "poolPath", poolPath, + "feeTokenPath", consts.GNS_PATH, + "feeAmount", formatUint(poolCreationFee), + ) + } + + pool = newPool(poolInfo) + pools.Set(poolPath, pool) + + currentTick := common.TickMathGetTickAtSqrtRatio(sqrtPriceX96) + + std.Emit( + "CreatePool", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "token0Path", token0Path, + "token1Path", token1Path, + "fee", formatUint(uint64(fee)), + "sqrtPriceX96", sqrtPriceX96, + "poolPath", poolPath, + "tick", formatInt(int64(currentTick)), + "tickSpacing", formatInt(int64(poolInfo.TickSpacing())), + ) +} + +// DoesPoolPathExist checks if a pool exists for the given poolPath. +// The poolPath is a unique identifier for a pool, combining token paths and fee. +func DoesPoolPathExist(poolPath string) bool { + return pools.Has(poolPath) +} + +// GetPool retrieves a pool instance based on the provided token paths and fee tier. +// +// This function determines the pool path by combining the paths of token0 and token1 along with the fee tier, +// and then retrieves the corresponding pool instance using that path. +// +// Parameters: +// - token0Path (string): The unique path for token0. +// - token1Path (string): The unique path for token1. +// - fee (uint32): The fee tier for the pool, expressed in basis points (e.g., 3000 for 0.3%). +// +// Returns: +// - *Pool: A pointer to the Pool instance corresponding to the provided tokens and fee tier. +// +// Notes: +// - The order of token paths (token0Path and token1Path) matters and should match the pool's configuration. +// - Ensure that the tokens and fee tier provided are valid and registered in the system. +// +// Example: +// pool := GetPool("gno.land/r/demo/wugnot", "gno.land/r/gnoswap/v1/gns", 3000) +func GetPool(token0Path, token1Path string, fee uint32) *Pool { + poolPath := GetPoolPath(token0Path, token1Path, fee) + return GetPoolFromPoolPath(poolPath) +} + +// GetPoolFromPoolPath retrieves a pool instance based on the provided pool path. +// +// This function checks if a pool exists for the given poolPath in the `pools` mapping. +// If the pool exists, it returns the pool instance. Otherwise, it panics with a descriptive error. +// +// Parameters: +// - poolPath (string): The unique identifier or path for the pool. +// +// Returns: +// - *Pool: A pointer to the Pool instance corresponding to the given poolPath. +// +// Panics: +// - If the `poolPath` does not exist in the `pools` mapping, it panics with an error message +// indicating that the expected poolPath was not found. +// +// Notes: +// - Ensure that the `poolPath` provided is valid and corresponds to an existing pool in the `pools` mapping. +// +// Example: +// pool := GetPoolFromPoolPath("path/to/pool") +func GetPoolFromPoolPath(poolPath string) *Pool { + iPool, exist := pools.Get(poolPath) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), + )) + } + return iPool.(*Pool) +} + +// GetPoolPath generates a unique pool path string based on the token paths and fee tier. +// +// This function ensures that the token paths are registered and sorted in alphabetical order +// before combining them with the fee tier to create a unique identifier for the pool. +// +// Parameters: +// - token0Path (string): The unique identifier or path for token0. +// - token1Path (string): The unique identifier or path for token1. +// - fee (uint32): The fee tier for the pool, expressed in basis points (e.g., 3000 for 0.3%). +// +// Returns: +// - string: A unique pool path string in the format "token0Path:token1Path:fee". +// +// Notes: +// - The function validates that both `token0Path` and `token1Path` are registered in the system +// using `common.MustRegistered`. +// - The token paths are sorted alphabetically to ensure consistent pool path generation, regardless +// of the input order. +// - This sorting guarantees that the pool path remains deterministic for the same pair of tokens and fee. +// +// Example: +// poolPath := GetPoolPath("path/to/token0", "path/to/token1", 3000) +// // Output: "path/to/token0:path/to/token1:3000" +func GetPoolPath(token0Path, token1Path string, fee uint32) string { + assertOnlyRegistered(token0Path) + assertOnlyRegistered(token1Path) + + // all the token paths in the pool are sorted in alphabetical order. + if strings.Compare(token1Path, token0Path) < 0 { + token0Path, token1Path = token1Path, token0Path + } + return ufmt.Sprintf("%s:%s:%d", token0Path, token1Path, fee) +} + +// GetFeeAmountTickSpacing retrieves the tick spacing associated with a given fee amount. +// The tick spacing determines the minimum distance between ticks in the pool. +// +// Parameters: +// - fee (uint32): The fee tier in basis points (e.g., 3000 for 0.3%) +// +// Returns: +// - int32: The tick spacing value for the given fee tier +// +// Panics: +// - If the fee amount is not registered in feeAmountTickSpacing +func GetFeeAmountTickSpacing(fee uint32) int32 { + feeStr := strconv.FormatUint(uint64(fee), 10) + iTickSpacing, exist := feeAmountTickSpacing.Get(feeStr) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("expected feeAmountTickSpacing(%s) to exist", feeStr), + )) + } + + return iTickSpacing.(int32) +} + +// setFeeAmountTickSpacing associates a tick spacing value with a fee amount. +// This is typically called during initialization to set up supported fee tiers. +// +// Parameters: +// - fee (uint32): The fee tier in basis points (e.g., 3000 for 0.3%) +// - tickSpacing (int32): The minimum tick spacing for this fee tier +// +// Note: Smaller tick spacing allows for more granular price points but increases +// computational overhead. Higher fee tiers typically use larger tick spacing. +func setFeeAmountTickSpacing(fee uint32, tickSpacing int32) { + feeStr := strconv.FormatUint(uint64(fee), 10) + feeAmountTickSpacing.Set(feeStr, tickSpacing) +} diff --git a/contract/r/gnoswap/pool/pool_manager_test.gno b/contract/r/gnoswap/pool/pool_manager_test.gno new file mode 100644 index 000000000..4022c858c --- /dev/null +++ b/contract/r/gnoswap/pool/pool_manager_test.gno @@ -0,0 +1,239 @@ +package pool + +import ( + "std" + "strings" + "testing" + + "gno.land/p/gnoswap/consts" +) + +func TestNewPoolParams(t *testing.T) { + params := newPoolParams( + "token0", + "token1", + 500, // 0.05% fee + "1000000", // example sqrt price + ) + + if params.Token0Path() != "token0" { + t.Errorf("Expected token0Path to be 'token0', got %s", params.Token0Path()) + } + + if params.TickSpacing() != 10 { // 500 fee should have 10 tick spacing + t.Errorf("Expected tick spacing 10, got %d", params.TickSpacing()) + } + + if !params.isInOrder() { + t.Errorf("Expected token0Path(token0) < token1Path(token1)") + } + + params = newPoolParams( + consts.GNOT, + "token1", + 500, + "1000000", + ) + token0, token1 := params.wrap() + if token0 != consts.WRAPPED_WUGNOT { + t.Errorf("Expected GNOT to be wrapped to WUGNOT") + } + + params = newPoolParams( + "token0", + "token0", + 500, + "1000000", + ) + if !params.isSameTokenPath() { + t.Errorf("Expected token0Path(token0) == token1Path(token0)") + } +} + +func TestGetPoolPath(t *testing.T) { + path := GetPoolPath("gno.land/r/onbloc/bar", "gno.land/r/onbloc/foo", 500) + expected := "gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500" + if path != expected { + t.Errorf("Expected path %s, got %s", expected, path) + } + + path = GetPoolPath("gno.land/r/onbloc/foo", "gno.land/r/onbloc/bar", 500) + if path != expected { + t.Errorf("Expected tokens to be sorted, expected %s, got %s", expected, path) + } +} + +func TestTickSpacingMap(t *testing.T) { + tests := []struct { + fee uint32 + tickSpacing int32 + }{ + {100, 1}, // 0.01% + {500, 10}, // 0.05% + {3000, 60}, // 0.3% + {10000, 200}, // 1% + } + + for _, tt := range tests { + spacing := GetFeeAmountTickSpacing(tt.fee) + if spacing != tt.tickSpacing { + t.Errorf("For fee %d, expected tick spacing %d, got %d", + tt.fee, tt.tickSpacing, spacing) + } + } +} + +func TestCreatePool(t *testing.T) { + tests := []struct { + name string + token0Path string + token1Path string + fee uint32 + sqrtPrice string + shouldPanic bool + panicMsg string + inOrder bool + }{ + { + name: "success - normal token pair", + token0Path: barPath, + token1Path: fooPath, + fee: 3000, + sqrtPrice: "4295128740", + }, + { + name: "fail - same tokens", + token0Path: barPath, + token1Path: barPath, + fee: 3000, + sqrtPrice: "4295128740", + shouldPanic: true, + panicMsg: "[GNOSWAP-POOL-011] same token used in single pool || expected token0Path(gno.land/r/onbloc/bar) != token1Path(gno.land/r/onbloc/bar", + }, + { + name: "success - when tokens not in order, reverse order and create pool", + token0Path: fooPath, + token1Path: bazPath, + fee: 3000, + sqrtPrice: "4295128740", + shouldPanic: false, + panicMsg: "[GNOSWAP-POOL-012] tokens must be in lexicographical order || expected token0Path(gno.land/r/onbloc/foo) < token1Path(gno.land/r/onbloc/baz)", + inOrder: true, + }, + { + name: "fail - pool already exists", + token0Path: barPath, + token1Path: fooPath, + fee: 3000, + sqrtPrice: "4295128740", + shouldPanic: true, + panicMsg: "[GNOSWAP-POOL-013] pool already created || expected poolPath(gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:3000) not to exist", + }, + } + + std.TestSetRealm(std.NewUserRealm(consts.ADMIN)) + SetPoolCreationFeeByAdmin(0) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + defer func() { + r := recover() + if r == nil { + t.Errorf("expected panic but got none") + return + } + errMsg := r.(string) + if !strings.Contains(errMsg, tt.panicMsg) { + t.Errorf("expected panic message containing %q but got %q", tt.panicMsg, errMsg) + } + }() + } + + CreatePool(tt.token0Path, tt.token1Path, tt.fee, tt.sqrtPrice) + + if !tt.shouldPanic { + // verify pool was created correctly + poolPath := GetPoolPath(tt.token0Path, tt.token1Path, tt.fee) + pool := mustGetPool(poolPath) + + // check if GNOT was properly wrapped + expectedToken0 := tt.token0Path + expectedToken1 := tt.token1Path + if tt.inOrder { + expectedToken0, expectedToken1 = expectedToken1, expectedToken0 + } + + if expectedToken0 == consts.GNOT { + expectedToken0 = consts.WRAPPED_WUGNOT + } + if expectedToken1 == consts.GNOT { + expectedToken1 = consts.WRAPPED_WUGNOT + } + + if pool.token0Path != expectedToken0 || pool.token1Path != expectedToken1 { + t.Errorf("incorrect token paths in pool. got %s,%s want %s,%s", + pool.token0Path, pool.token1Path, expectedToken0, expectedToken1) + } + } + }) + } + + resetObject(t) +} + +func TestGetPool(t *testing.T) { + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) + shouldPanic bool + expected string + }{ + { + name: "Panic - unregisterd poolPath ", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + }, + action: func(t *testing.T) { + GetPool(barPath, fooPath, fee500) + }, + shouldPanic: true, + expected: "[GNOSWAP-POOL-008] requested data not found || expected poolPath(gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500) to exist", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tt.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + case error: + if r.(error).Error() != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r.(error).Error(), tt.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + } + }() + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + } + }) + } +} diff --git a/contract/r/gnoswap/pool/pool_test.gno b/contract/r/gnoswap/pool/pool_test.gno new file mode 100644 index 000000000..ffed6d435 --- /dev/null +++ b/contract/r/gnoswap/pool/pool_test.gno @@ -0,0 +1,165 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" +) + +func TestMint(t *testing.T) { + token0Path := "test_token0" + token1Path := "test_token1" + fee := uint32(3000) + recipient := testutils.TestAddress("recipient") + tickLower := int32(-100) + tickUpper := int32(100) + liquidityAmount := "100000" + + t.Run("unauthorized caller mint should fail", func(t *testing.T) { + unauthorized := testutils.TestAddress("unauthorized") + defer func() { + if r := recover(); r == nil { + t.Error("unauthorized caller mint should fail") + } + }() + + Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, liquidityAmount, unauthorized) + }) + + t.Run("mint with 0 liquidity should fail", func(t *testing.T) { + authorized := consts.POSITION_ADDR + defer func() { + if r := recover(); r == nil { + t.Error("mint with 0 liquidity should fail") + } + }() + + Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, "0", authorized) + }) +} + +func TestBurn(t *testing.T) { + // Setup + originalGetPool := GetPool + defer func() { + GetPool = originalGetPool + }() + + // Mock data + mockCaller := consts.POSITION_ADDR + mockPosition := PositionInfo{ + liquidity: u256.NewUint(1000), + tokensOwed0: u256.NewUint(0), + tokensOwed1: u256.NewUint(0), + } + mockPool := &Pool{ + positions: avl.NewTree(), + } + + GetPool = func(token0Path, token1Path string, fee uint32) *Pool { + return mockPool + } + + tests := []struct { + name string + liquidityAmount string + tickLower int32 + tickUpper int32 + expectedAmount0 string + expectedAmount1 string + expectPanic bool + }{ + { + name: "successful burn", + liquidityAmount: "500", + tickLower: -100, + tickUpper: 100, + expectedAmount0: "100", + expectedAmount1: "200", + }, + { + name: "zero liquidity", + liquidityAmount: "0", + tickLower: -100, + tickUpper: 100, + expectPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == "successful burn" { + t.Skip("skipping until find better way to test this") + } + + // setup position for this test + posKey := getPositionKey(mockCaller, tt.tickLower, tt.tickUpper) + mockPool.positions.Set(posKey, mockPosition) + + if tt.expectPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but got none") + } + }() + } + + amount0, amount1 := Burn( + "token0", + "token1", + 3000, + tt.tickLower, + tt.tickUpper, + tt.liquidityAmount, + ) + + if !tt.expectPanic { + if amount0 != tt.expectedAmount0 { + t.Errorf("expected amount0 %s, got %s", tt.expectedAmount0, amount0) + } + if amount1 != tt.expectedAmount1 { + t.Errorf("expected amount1 %s, got %s", tt.expectedAmount1, amount1) + } + + newPosition := mockPool.mustGetPosition(posKey) + if newPosition.tokensOwed0.IsZero() { + t.Error("expected tokensOwed0 to be updated") + } + if newPosition.tokensOwed1.IsZero() { + t.Error("expected tokensOwed1 to be updated") + } + } + }) + } +} + +func TestSetFeeProtocolInternal(t *testing.T) { + tests := []struct { + name string + feeProtocol0 uint8 + feeProtocol1 uint8 + eventName string + }{ + { + name: "set fee protocol by admin", + feeProtocol0: 4, + feeProtocol1: 5, + eventName: "SetFeeProtocolByAdmin", + }, + } + + for _, tt := range tests { + t.Run("set fee protocol by admin", func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + SetFeeProtocolByAdmin(tt.feeProtocol0, tt.feeProtocol1) + uassert.Equal(t, tt.feeProtocol0, pool.Slot0FeeProtocol()%16) + uassert.Equal(t, tt.feeProtocol1, pool.Slot0FeeProtocol()>>4) + }) + } +} diff --git a/contract/r/gnoswap/pool/pool_transfer.gno b/contract/r/gnoswap/pool/pool_transfer.gno new file mode 100644 index 000000000..bd8d2e827 --- /dev/null +++ b/contract/r/gnoswap/pool/pool_transfer.gno @@ -0,0 +1,203 @@ +package pool + +import ( + "std" + + "gno.land/p/demo/ufmt" + + "gno.land/r/gnoswap/v1/common" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +// safeTransfer performs a token transfer out of the pool while ensuring +// the pool has sufficient balance and updating internal accounting. +// This function is typically used during swaps and liquidity removals. +// +// Important requirements: +// - The amount must be negative (representing an outflow from the pool) +// - The pool must have sufficient balance for the transfer +// - The transfer amount must fit within uint64 range +// +// Parameters: +// - to: destination address for the transfer +// - tokenPath: path identifier of the token to transfer +// - amount: amount to transfer (must be negative) +// - isToken0: true if transferring token0, false for token1 +// +// The function will: +// 1. Validate the amount is negative +// 2. Check pool has sufficient balance +// 3. Execute the transfer +// 4. Update pool's internal balance +// +// Panics if any validation fails or if the transfer fails +func (p *Pool) safeTransfer( + to std.Address, + tokenPath string, + amount *i256.Int, + isToken0 bool, +) { + if amount.Gt(i256.Zero()) { + panic(ufmt.Sprintf( + "%v. got: %s", errMustBeNegative, amount.ToString(), + )) + } + + absAmount := amount.Abs() + + token0 := p.BalanceToken0() + token1 := p.BalanceToken1() + + if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { + panic(err) + } + amountUint64 := safeConvertToUint64(absAmount) + + token := common.GetTokenTeller(tokenPath) + checkTransferError(token.Transfer(to, amountUint64)) + + newBalance, err := updatePoolBalance(token0, token1, absAmount, isToken0) + if err != nil { + panic(err) + } + + if isToken0 { + p.balances.token0 = newBalance + } else { + p.balances.token1 = newBalance + } +} + +// safeTransferFrom securely transfers tokens into the pool while ensuring balance consistency. +// +// This function performs the following steps: +// 1. Validates and converts the transfer amount to `uint64` using `safeConvertToUint64`. +// 2. Executes the token transfer using `TransferFrom` via the token teller contract. +// 3. Verifies that the destination balance reflects the correct amount after transfer. +// 4. Updates the pool's internal balances (`token0` or `token1`) and validates the updated state. +// +// Parameters: +// - from (std.Address): Source address for the token transfer. +// - to (std.Address): Destination address, typically the pool address. +// - tokenPath (string): Path identifier for the token being transferred. +// - amount (*u256.Uint): The amount of tokens to transfer (must be a positive value). +// - isToken0 (bool): A flag indicating whether the token being transferred is token0 (`true`) or token1 (`false`). +// +// Panics: +// - If the `amount` exceeds the uint64 range during conversion. +// - If the token transfer (`TransferFrom`) fails. +// - If the destination balance after the transfer does not match the expected amount. +// - If the pool's internal balances (`token0` or `token1`) overflow or become inconsistent. +// +// Notes: +// - The function assumes that the sender (`from`) has approved the pool to spend the specified tokens. +// - The balance consistency check ensures that no tokens are lost or double-counted during the transfer. +// - Pool balance updates are performed atomically to ensure internal consistency. +// +// Example: +// p.safeTransferFrom( +// +// sender, poolAddress, "path/to/token0", u256.MustFromDecimal("1000"), true +// +// ) +func (p *Pool) safeTransferFrom( + from, to std.Address, + tokenPath string, + amount *u256.Uint, + isToken0 bool, +) { + amountUint64 := safeConvertToUint64(amount) + + token := common.GetToken(tokenPath) + beforeBalance := token.BalanceOf(to) + + teller := common.GetTokenTeller(tokenPath) + checkTransferError(teller.TransferFrom(from, to, amountUint64)) + + afterBalance := token.BalanceOf(to) + if (beforeBalance + amountUint64) != afterBalance { + panic(ufmt.Sprintf( + "%v. beforeBalance(%d) + amount(%d) != afterBalance(%d)", + errTransferFailed, beforeBalance, amountUint64, afterBalance, + )) + } + + // update pool balances + if isToken0 { + beforeToken0 := p.balances.token0.Clone() + p.balances.token0 = new(u256.Uint).Add(p.balances.token0, amount) + if p.balances.token0.Lt(beforeToken0) { + panic(ufmt.Sprintf( + "%v. token0(%s) < beforeToken0(%s)", + errBalanceUpdateFailed, p.balances.token0.ToString(), beforeToken0.ToString(), + )) + } + } else { + beforeToken1 := p.balances.token1.Clone() + p.balances.token1 = new(u256.Uint).Add(p.balances.token1, amount) + if p.balances.token1.Lt(beforeToken1) { + panic(ufmt.Sprintf( + "%v. token1(%s) < beforeToken1(%s)", + errBalanceUpdateFailed, p.balances.token1.ToString(), beforeToken1.ToString(), + )) + } + } +} + +// validatePoolBalance checks if the pool has sufficient balance of either token0 and token1 +// before proceeding with a transfer. This prevents the pool won't go into a negative balance. +func validatePoolBalance(token0, token1, amount *u256.Uint, isToken0 bool) error { + if isToken0 { + if token0.Lt(amount) { + return ufmt.Errorf( + "%v. token0(%s) >= amount(%s)", + errTransferFailed, token0.ToString(), amount.ToString(), + ) + } + return nil + } + if token1.Lt(amount) { + return ufmt.Errorf( + "%v. token1(%s) >= amount(%s)", + errTransferFailed, token1.ToString(), amount.ToString(), + ) + } + return nil +} + +// updatePoolBalance calculates the new balance after a transfer and validate. +// It ensures the resulting balance won't be negative or overflow. +func updatePoolBalance( + token0, token1, amount *u256.Uint, + isToken0 bool, +) (*u256.Uint, error) { + var overflow bool + var newBalance *u256.Uint + + if isToken0 { + newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) + if isBalanceOverflowOrNegative(overflow, newBalance) { + return nil, ufmt.Errorf( + "%v. cannot decrease, token0(%s) - amount(%s)", + errBalanceUpdateFailed, token0.ToString(), amount.ToString(), + ) + } + return newBalance, nil + } + + newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) + if isBalanceOverflowOrNegative(overflow, newBalance) { + return nil, ufmt.Errorf( + "%v. cannot decrease, token1(%s) - amount(%s)", + errBalanceUpdateFailed, token1.ToString(), amount.ToString(), + ) + } + return newBalance, nil +} + +// isBalanceOverflowOrNegative checks if the balance calculation resulted in an overflow or negative value. +func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { + return overflow || newBalance.Lt(u256.Zero()) +} diff --git a/contract/r/gnoswap/pool/pool_transfer_test.gno b/contract/r/gnoswap/pool/pool_transfer_test.gno new file mode 100644 index 000000000..325e3d131 --- /dev/null +++ b/contract/r/gnoswap/pool/pool_transfer_test.gno @@ -0,0 +1,297 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" +) + +func TestTransferAndVerify(t *testing.T) { + // Setup common test data + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(1000), + }, + } + + t.Run("validatePoolBalance", func(t *testing.T) { + tests := []struct { + name string + amount *u256.Uint + isToken0 bool + expectedError bool + }{ + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: true, + expectedError: false, + }, + { + name: "must panic for insufficient token0 balance", + amount: u256.NewUint(1500), + isToken0: true, + expectedError: true, + }, + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: false, + expectedError: false, + }, + { + name: "must panic for insufficient token1 balance", + amount: u256.NewUint(1500), + isToken0: false, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token0 := pool.balances.token0 + token1 := pool.balances.token1 + + err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) + if err != nil { + if !tt.expectedError { + t.Errorf("unexpected error: %v", err) + } + } + }) + } + }) +} + +func TestTransferFromAndVerify(t *testing.T) { + tests := []struct { + name string + pool *Pool + from std.Address + to std.Address + tokenPath string + amount *i256.Int + isToken0 bool + expectedBal0 *u256.Uint + expectedBal1 *u256.Uint + }{ + { + name: "normal token0 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(500), + isToken0: true, + expectedBal0: u256.NewUint(1500), // 1000 + 500 + expectedBal1: u256.NewUint(2000), // unchanged + }, + { + name: "normal token1 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(800), + isToken0: false, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2800), // 2000 + 800 + }, + { + name: "zero value transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(0), + isToken0: true, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2000), // unchanged + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + TokenFaucet(t, tt.tokenPath, tt.from) + TokenApprove(t, tt.tokenPath, tt.from, pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) + + tt.pool.safeTransferFrom(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) + + if !tt.pool.balances.token0.Eq(tt.expectedBal0) { + t.Errorf("token0 balance mismatch: expected %s, got %s", + tt.expectedBal0.ToString(), + tt.pool.balances.token0.ToString()) + } + + if !tt.pool.balances.token1.Eq(tt.expectedBal1) { + t.Errorf("token1 balance mismatch: expected %s, got %s", + tt.expectedBal1.ToString(), + tt.pool.balances.token1.ToString()) + } + }) + } + + t.Run("negative value handling", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + negativeAmount := i256.NewInt(-500) + + TokenFaucet(t, fooPath, testutils.TestAddress("from_addr")) + TokenApprove(t, fooPath, testutils.TestAddress("from_addr"), consts.POOL_ADDR, u256.MustFromDecimal(negativeAmount.Abs().ToString()).Uint64()) + pool.safeTransferFrom( + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), + fooPath, + u256.MustFromDecimal(negativeAmount.Abs().ToString()), + true, + ) + + expectedBal := u256.NewUint(1500) // 1000 + 500 (absolute value) + if !pool.balances.token0.Eq(expectedBal) { + t.Errorf("negative amount handling failed: expected %s, got %s", + expectedBal.ToString(), + pool.balances.token0.ToString()) + } + }) + + t.Run("uint64 overflow value", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + hugeAmount := i256.FromUint256(u256.MustFromDecimal("18446744073709551616")) // 2^64 + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for amount exceeding uint64 range") + } + }() + + pool.safeTransferFrom( + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), + fooPath, + u256.MustFromDecimal(hugeAmount.ToString()), + true, + ) + }) +} + +func TestUpdatePoolBalance(t *testing.T) { + tests := []struct { + name string + initialToken0 *u256.Uint + initialToken1 *u256.Uint + amount *u256.Uint + isToken0 bool + expectedBal *u256.Uint + expectErr bool + }{ + { + name: "normal token0 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(300), + isToken0: true, + expectedBal: u256.NewUint(700), + expectErr: false, + }, + { + name: "normal token1 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(500), + isToken0: false, + expectedBal: u256.NewUint(1500), + expectErr: false, + }, + { + name: "insufficient token0 balance", + initialToken0: u256.NewUint(100), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(200), + isToken0: true, + expectedBal: nil, + expectErr: true, + }, + { + name: "insufficient token1 balance", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(100), + amount: u256.NewUint(200), + isToken0: false, + expectedBal: nil, + expectErr: true, + }, + { + name: "zero value handling", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(0), + isToken0: true, + expectedBal: u256.NewUint(1000), + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: tt.initialToken0, + token1: tt.initialToken1, + }, + } + + newBal, err := updatePoolBalance(tt.initialToken0, tt.initialToken1, tt.amount, tt.isToken0) + + if tt.expectErr { + if err == nil { + t.Errorf("%s: expected error but no error", tt.name) + } + return + } + if err != nil { + t.Errorf("%s: unexpected error: %v", tt.name, err) + return + } + + if !newBal.Eq(tt.expectedBal) { + t.Errorf("%s: balance mismatch, expected: %s, actual: %s", + tt.name, + tt.expectedBal.ToString(), + newBal.ToString(), + ) + } + }) + } +} diff --git a/contract/r/gnoswap/pool/position.gno b/contract/r/gnoswap/pool/position.gno new file mode 100644 index 000000000..4be23e633 --- /dev/null +++ b/contract/r/gnoswap/pool/position.gno @@ -0,0 +1,183 @@ +package pool + +import ( + "encoding/base64" + "std" + + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +var ( + Q128 = u256.MustFromDecimal(consts.Q128) +) + +// getPositionKey generates a unique, encoded key for a liquidity position. +// +// This function creates a unique key for identifying a liquidity position in a pool. The key is based +// on the position's owner address, lower tick, and upper tick values. The generated key is then encoded +// as a base64 string to ensure compatibility and uniqueness. +// +// Parameters: +// - owner: std.Address, the address of the position's owner. +// - tickLower: int32, the lower tick boundary for the position. +// - tickUpper: int32, the upper tick boundary for the position. +// +// Returns: +// - string: A base64-encoded string representing the unique position key. +// +// Workflow: +// 1. Validates that the `owner` address is valid using `assertOnlyValidAddress`. +// 2. Ensures `tickLower` is less than `tickUpper` using `assertTickLowerLessThanUpper`. +// 3. Constructs the position key as a formatted string: +// "____" +// 4. Encodes the generated position key into a base64 string for safety and uniqueness. +// 5. Returns the encoded position key. +// +// Example: +// +// owner := std.Address("0x123456789") +// positionKey := getPositionKey(owner, 100, 200) +// fmt.Println("Position Key:", positionKey) +// // Output: base64-encoded string representing "0x123456789__100__200" +// +// Notes: +// - The base64 encoding ensures that the position key can be safely used as an identifier +// across different systems or data stores. +// - The function will panic if: +// - The `owner` address is invalid. +// - `tickLower` is greater than or equal to `tickUpper`. +func getPositionKey( + owner std.Address, + tickLower int32, + tickUpper int32, +) string { + assertOnlyValidAddress(owner) + assertTickLowerLessThanUpper(tickLower, tickUpper) + + positionKey := ufmt.Sprintf("%s__%d__%d", owner.String(), tickLower, tickUpper) + encodedPositionKey := base64.StdEncoding.EncodeToString([]byte(positionKey)) + return encodedPositionKey +} + +// positionUpdate calculates and returns an updated PositionInfo. +func positionUpdate( + position PositionInfo, + liquidityDelta *i256.Int, + feeGrowthInside0X128 *u256.Uint, + feeGrowthInside1X128 *u256.Uint, +) PositionInfo { + position.valueOrZero() + + var liquidityNext *u256.Uint + if liquidityDelta.IsZero() { + if position.liquidity.IsZero() { + panic(addDetailToError( + errZeroLiquidity, + "both liquidityDelta and current position's liquidity are zero", + )) + } + + liquidityNext = position.liquidity + } else { + liquidityNext = liquidityMathAddDelta(position.liquidity, liquidityDelta) + } + + tokensOwed0 := u256.Zero() + diff0 := new(u256.Uint).Sub(feeGrowthInside0X128, position.feeGrowthInside0LastX128) + tokensOwed0 = u256.MulDiv(diff0, position.liquidity, Q128) + + tokensOwed1 := u256.Zero() + diff1 := new(u256.Uint).Sub(feeGrowthInside1X128, position.feeGrowthInside1LastX128) + tokensOwed1 = u256.MulDiv(diff1, position.liquidity, Q128) + + if !(liquidityDelta.IsZero()) { + position.liquidity = liquidityNext + } + + position.feeGrowthInside0LastX128 = feeGrowthInside0X128 + position.feeGrowthInside1LastX128 = feeGrowthInside1X128 + if tokensOwed0.Gt(u256.Zero()) || tokensOwed1.Gt(u256.Zero()) { + position.tokensOwed0 = position.tokensOwed0.Add(position.tokensOwed0, tokensOwed0) + position.tokensOwed1 = position.tokensOwed1.Add(position.tokensOwed1, tokensOwed1) + } + + return position +} + +// positionUpdateWithKey updates a position in the pool and returns the updated position. +func (p *Pool) positionUpdateWithKey( + positionKey string, + liquidityDelta *i256.Int, + feeGrowthInside0X128 *u256.Uint, + feeGrowthInside1X128 *u256.Uint, +) PositionInfo { + // if pointer is nil, set to zero for calculation + liquidityDelta = liquidityDelta.NilToZero() + feeGrowthInside0X128 = feeGrowthInside0X128.NilToZero() + feeGrowthInside1X128 = feeGrowthInside1X128.NilToZero() + + positionToUpdate, _ := p.GetPosition(positionKey) + positionAfterUpdate := positionUpdate(positionToUpdate, liquidityDelta, feeGrowthInside0X128, feeGrowthInside1X128) + + p.setPosition(positionKey, positionAfterUpdate) + + return positionAfterUpdate +} + +// PositionLiquidity returns the liquidity of a position. +func (p *Pool) PositionLiquidity(key string) *u256.Uint { + return p.mustGetPosition(key).liquidity +} + +// PositionFeeGrowthInside0LastX128 returns the fee growth of token0 inside a position. +func (p *Pool) PositionFeeGrowthInside0LastX128(key string) *u256.Uint { + return p.mustGetPosition(key).feeGrowthInside0LastX128 +} + +// PositionFeeGrowthInside1LastX128 returns the fee growth of token1 inside a position. +func (p *Pool) PositionFeeGrowthInside1LastX128(key string) *u256.Uint { + return p.mustGetPosition(key).feeGrowthInside1LastX128 +} + +// PositionTokensOwed0 returns the amount of token0 owed by a position. +func (p *Pool) PositionTokensOwed0(key string) *u256.Uint { + return p.mustGetPosition(key).tokensOwed0 +} + +// PositionTokensOwed1 returns the amount of token1 owed by a position. +func (p *Pool) PositionTokensOwed1(key string) *u256.Uint { + return p.mustGetPosition(key).tokensOwed1 +} + +// GetPosition returns the position info for a given key. +func (p *Pool) GetPosition(key string) (PositionInfo, bool) { + iPositionInfo, exist := p.positions.Get(key) + if !exist { + newPosition := PositionInfo{} + newPosition.valueOrZero() + return newPosition, false + } + + return iPositionInfo.(PositionInfo), true +} + +// setPosition sets the position info for a given key. +func (p *Pool) setPosition(posKey string, positionInfo PositionInfo) { + p.positions.Set(posKey, positionInfo) +} + +// mustGetPosition returns the position info for a given key. +func (p *Pool) mustGetPosition(positionKey string) PositionInfo { + positionInfo, exist := p.GetPosition(positionKey) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("positionKey(%s) does not exist", positionKey), + )) + } + return positionInfo +} diff --git a/contract/r/gnoswap/pool/position_modify.gno b/contract/r/gnoswap/pool/position_modify.gno new file mode 100644 index 000000000..18ff96324 --- /dev/null +++ b/contract/r/gnoswap/pool/position_modify.gno @@ -0,0 +1,119 @@ +package pool + +import ( + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + plp "gno.land/p/gnoswap/gnsmath" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/common" +) + +// modifyPosition updates a position in the pool and calculates the amount of tokens +// needed (for minting) or returned (for burning). The calculation depends on the current +// price (tick) relative to the position's price range. +// +// The function handles three cases: +// 1. Current price below range (tick < tickLower): only token0 is used/returned +// 2. Current price in range (tickLower <= tick < tickUpper): both tokens are used/returned +// 3. Current price above range (tick >= tickUpper): only token1 is used/returned +// +// Parameters: +// - params: ModifyPositionParams containing owner, tickLower, tickUpper, and liquidityDelta +// +// Returns: +// - PositionInfo: updated position information +// - *u256.Uint: amount of token0 needed/returned +// - *u256.Uint: amount of token1 needed/returned +func (p *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { + checkTicks(params.tickLower, params.tickUpper) + + // get current state and price bounds + tick := p.Slot0Tick() + // update position state + position := p.updatePosition(params, tick) + liqDelta := params.liquidityDelta + if liqDelta.IsZero() { + return position, u256.Zero(), u256.Zero() + } + + amount0, amount1 := i256.Zero(), i256.Zero() + + // covert ticks to sqrt price to use in amount calculations + // price = 1.0001^tick, but we use sqrtPriceX96 + sqrtRatioLower := common.TickMathGetSqrtRatioAtTick(params.tickLower) + sqrtRatioUpper := common.TickMathGetSqrtRatioAtTick(params.tickUpper) + sqrtPriceX96 := p.Slot0SqrtPriceX96() + + // calculate token amounts based on current price position relative to range + switch { + case tick < params.tickLower: + // case 1 + // full range between lower and upper tick is used for token0 + // current tick is below the passed range; liquidity can only become in range by crossing from left to + // right, when we'll need _more_ token0 (it's becoming more valuable) so user must provide it + amount0 = calculateToken0Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) + + case tick < params.tickUpper: + // case 2 + liquidityBefore := p.liquidity + // token0 used from current price to upper tick + amount0 = calculateToken0Amount(sqrtPriceX96, sqrtRatioUpper, liqDelta) + // token1 used from lower tick to current price + amount1 = calculateToken1Amount(sqrtRatioLower, sqrtPriceX96, liqDelta) + // update pool's active liquidity since price is in range + p.liquidity = liquidityMathAddDelta(liquidityBefore, liqDelta) + + default: + // case 3 + // full range between lower and upper tick is used for token1 + // current tick is above the passed range; liquidity can only become in range by crossing from right to + // left, when we'll need _more_ token1 (it's becoming more valuable) so user must provide it + amount1 = calculateToken1Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) + } + + return position, amount0.Abs(), amount1.Abs() +} + +func calculateToken0Amount(sqrtPriceLower, sqrtPriceUpper *u256.Uint, liquidityDelta *i256.Int) *i256.Int { + res := plp.SqrtPriceMathGetAmount0DeltaStr(sqrtPriceLower, sqrtPriceUpper, liquidityDelta) + return i256.MustFromDecimal(res) +} + +func calculateToken1Amount(sqrtPriceLower, sqrtPriceUpper *u256.Uint, liquidityDelta *i256.Int) *i256.Int { + res := plp.SqrtPriceMathGetAmount1DeltaStr(sqrtPriceLower, sqrtPriceUpper, liquidityDelta) + return i256.MustFromDecimal(res) +} + +func checkTicks(tickLower, tickUpper int32) { + assertTickLowerLessThanUpper(tickLower, tickUpper) + assertValidTickLower(tickLower) + assertValidTickUpper(tickUpper) +} + +func assertTickLowerLessThanUpper(tickLower, tickUpper int32) { + if tickLower >= tickUpper { + panic(addDetailToError( + errInvalidTickRange, + ufmt.Sprintf("tickLower(%d), tickUpper(%d)", tickLower, tickUpper), + )) + } +} + +func assertValidTickLower(tickLower int32) { + if tickLower < consts.MIN_TICK { + panic(addDetailToError( + errTickLowerInvalid, + ufmt.Sprintf("tickLower(%d) < MIN_TICK(%d)", tickLower, consts.MIN_TICK), + )) + } +} + +func assertValidTickUpper(tickUpper int32) { + if tickUpper > consts.MAX_TICK { + panic(addDetailToError( + errTickUpperInvalid, + ufmt.Sprintf("tickUpper(%d) > MAX_TICK(%d)", tickUpper, consts.MAX_TICK), + )) + } +} diff --git a/contract/r/gnoswap/pool/position_modify_test.gno b/contract/r/gnoswap/pool/position_modify_test.gno new file mode 100644 index 000000000..1debe114e --- /dev/null +++ b/contract/r/gnoswap/pool/position_modify_test.gno @@ -0,0 +1,292 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/common" +) + +func TestModifyPosition(t *testing.T) { + const ( + fee500 = uint32(500) + test_liquidityDelta = int64(100000000) + ) + + tests := []struct { + name string + sqrtPrice string + tickLower int32 + tickUpper int32 + expectedAmt0 string + expectedAmt1 string + }{ + { + name: "current price is lower than range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-12000).ToString(), + tickLower: -11000, + tickUpper: -9000, + expectedAmt0: "16492846", + expectedAmt1: "0", + }, + { + name: "current price is in range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-10000).ToString(), + tickLower: -11000, + tickUpper: -9000, + expectedAmt0: "8040316", + expectedAmt1: "2958015", + }, + { + name: "current price is higher than range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-8000).ToString(), + tickLower: -11000, + tickUpper: -9000, + expectedAmt0: "0", + expectedAmt1: "6067683", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sqrtPrice := u256.MustFromDecimal(tt.sqrtPrice) + poolParams := newPoolParams( + barPath, + fooPath, + fee500, + sqrtPrice.ToString(), + ) + pool := newPool(poolParams) + + params := ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: tt.tickLower, + tickUpper: tt.tickUpper, + liquidityDelta: new(i256.Int).SetInt64(test_liquidityDelta), + } + + _, amount0, amount1 := pool.modifyPosition(params) + + uassert.Equal(t, amount0.ToString(), tt.expectedAmt0) + uassert.Equal(t, amount1.ToString(), tt.expectedAmt1) + }) + } +} + +func TestModifyPositionEdgeCases(t *testing.T) { + const fee500 = uint32(500) + sqrtPrice := common.TickMathGetSqrtRatioAtTick(-10000).ToString() + + t.Run("liquidityDelta is zero", func(t *testing.T) { + poolParams := newPoolParams( + barPath, + fooPath, + fee500, + sqrtPrice, + ) + pool := newPool(poolParams) + params := ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -11000, + tickUpper: -9000, + liquidityDelta: i256.Zero(), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic. not happened") + } else { + expectedError := "[GNOSWAP-POOL-010] zero liquidity" + if err, ok := r.(string); !ok || err[:len(expectedError)] != expectedError { + t.Errorf("expected error message. got: %v, want: %v", r, expectedError) + } + } + }() + + pool.modifyPosition(params) + }) + + t.Run("liquidityDelta is negative", func(t *testing.T) { + poolParams := newPoolParams( + barPath, + fooPath, + fee500, + sqrtPrice, + ) + pool := newPool(poolParams) + params := ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -11000, + tickUpper: -9000, + liquidityDelta: i256.MustFromDecimal("100000000"), + } + pool.modifyPosition(params) + + // remove liquidity + params.liquidityDelta = i256.MustFromDecimal("-100000000") + _, amount0, amount1 := pool.modifyPosition(params) + + // remove amount should be same as added amount + uassert.Equal(t, amount0.ToString(), "8040315") + uassert.Equal(t, amount1.ToString(), "2958014") + }) +} + +func TestAssertTickLowerLessThanUpper(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickLower is less than tickUpper", + tickLower: -100, + tickUpper: 100, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower equals tickUpper", + tickLower: 50, + tickUpper: 50, + shouldPanic: true, + expected: "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(50), tickUpper(50)", + }, + { + name: "tickLower greater than tickUpper", + tickLower: 200, + tickUpper: 100, + shouldPanic: true, + expected: "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(200), tickUpper(100)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertTickLowerLessThanUpper(tt.tickLower, tt.tickUpper) + }) + } +} + +func TestAssertValidTickLower(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickLower equals MIN_TICK", + tickLower: consts.MIN_TICK, + tickUpper: 100, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower greater than MIN_TICK", + tickLower: consts.MIN_TICK + 1, + tickUpper: 50, + shouldPanic: false, + expected: "", + }, + { + name: "tickLower less than MIN_TICK (panic)", + tickLower: consts.MIN_TICK - 1, + tickUpper: 100, + shouldPanic: true, + expected: "[GNOSWAP-POOL-028] tickLower is invalid || tickLower(-887273) < MIN_TICK(-887272)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertValidTickLower(tt.tickLower) + }) + } +} + +func TestAssertValidTickUpper(t *testing.T) { + tests := []struct { + name string + tickLower int32 + tickUpper int32 + shouldPanic bool + expected string + }{ + { + name: "tickUpper equals MAX_TICK", + tickLower: consts.MIN_TICK, + tickUpper: consts.MAX_TICK, + shouldPanic: false, + expected: "", + }, + { + name: "tickUpper less than MAX_TICK", + tickLower: consts.MIN_TICK + 1, + tickUpper: consts.MAX_TICK - 1, + shouldPanic: false, + expected: "", + }, + { + name: "tickUpper greater than MAX_TICK (panic)", + tickLower: consts.MIN_TICK - 1, + tickUpper: consts.MAX_TICK + 1, + shouldPanic: true, + expected: "[GNOSWAP-POOL-029] tickUpper is invalid || tickUpper(887273) > MAX_TICK(887272)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } else { + if tt.shouldPanic { + uassert.Equal(t, tt.expected, r.(string)) + } else { + t.Errorf("expected no panic, but got: %v", r) + } + } + }() + assertValidTickUpper(tt.tickUpper) + }) + } +} diff --git a/contract/r/gnoswap/pool/position_test.gno b/contract/r/gnoswap/pool/position_test.gno new file mode 100644 index 000000000..03ab0987b --- /dev/null +++ b/contract/r/gnoswap/pool/position_test.gno @@ -0,0 +1,176 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" +) + +func TestPositionGetKey(t *testing.T) { + invalidAddr := std.Address("invalidAddr") + validAddr := testutils.TestAddress("validAddr") + + tests := []struct { + owner std.Address + tickLower int32 + tickUpper int32 + shouldPanic bool + panicMsg string + expectedKey string + }{ + {invalidAddr, 100, 200, true, `[GNOSWAP-POOL-023] invalid address || (invalidAddr)`, ""}, // invalid address + {validAddr, 200, 100, true, `[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(200), tickUpper(100)`, ""}, // tickLower > tickUpper + {validAddr, -100, -200, true, `[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(-100), tickUpper(-200)`, ""}, // tickLower > tickUpper + {validAddr, 100, 100, true, "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(100), tickUpper(100)", ""}, // tickLower == tickUpper + {validAddr, 100, 200, false, "", "ZzF3ZXNrYzZ0eWc5anhndWpsdGEwNDdoNmx0YTA0N2g2bGRqbHVkdV9fMTAwX18yMDA="}, // tickLower < tickUpper + } + + for _, tc := range tests { + defer func() { + if r := recover(); r != nil { + uassert.Equal(t, tc.panicMsg, r.(string)) + } + }() + if tc.shouldPanic { + uassert.PanicsWithMessage(t, tc.panicMsg, func() { getPositionKey(tc.owner, tc.tickLower, tc.tickUpper) }) + } else { + key := getPositionKey(tc.owner, tc.tickLower, tc.tickUpper) + uassert.Equal(t, tc.expectedKey, key) + } + } +} + +func TestPositionUpdateWithKey(t *testing.T) { + var dummyPool *Pool + var positionKey string + + t.Run("set up initial data for this test function", func(t *testing.T) { + poolParams := newPoolParams( + "token0", + "token1", + 100, + common.TickMathGetSqrtRatioAtTick(0).ToString(), + ) + dummyPool = newPool(poolParams) + + positionKey = getPositionKey( + testutils.TestAddress("dummyAddr"), + 100, + 200, + ) + }) + + tests := []struct { + liquidity *i256.Int + amount0 *u256.Uint + amount1 *u256.Uint + shouldPanic bool + panicMsg string + expectedLiquidity string + }{ + {i256.MustFromDecimal("0"), u256.Zero(), u256.Zero(), true, `[GNOSWAP-POOL-010] zero liquidity || both liquidityDelta and current position's liquidity are zero`, ""}, + {i256.MustFromDecimal("100000"), u256.Zero(), u256.Zero(), false, "", "100000"}, + } + + for _, tc := range tests { + if tc.shouldPanic { + uassert.PanicsWithMessage(t, tc.panicMsg, func() { dummyPool.positionUpdateWithKey(positionKey, tc.liquidity, tc.amount0, tc.amount1) }) + } else { + newPos := dummyPool.positionUpdateWithKey(positionKey, tc.liquidity, tc.amount0, tc.amount1) + uassert.Equal(t, newPos.liquidity.ToString(), tc.expectedLiquidity) + } + } +} + +func TestPositionUpdate(t *testing.T) { + tests := []struct { + initialLiquidity *u256.Uint + liquidityDelta *i256.Int + feeGrowthInside0X128 *u256.Uint + feeGrowthInside1X128 *u256.Uint + shouldPanic bool + panicMsg string + expectedLiquidity string + expectedFeeGrowthInside0X128 string + expectedFeeGrowthInside1X128 string + expectedToken0Owed string + expectedToken1Owed string + }{ + { + initialLiquidity: u256.Zero(), + liquidityDelta: i256.MustFromDecimal("0"), + feeGrowthInside0X128: u256.Zero(), + feeGrowthInside1X128: u256.Zero(), + shouldPanic: true, + panicMsg: `[GNOSWAP-POOL-010] zero liquidity || both liquidityDelta and current position's liquidity are zero`, + }, + { + initialLiquidity: u256.Zero(), + liquidityDelta: i256.MustFromDecimal("100000"), + feeGrowthInside0X128: u256.Zero(), + feeGrowthInside1X128: u256.Zero(), + expectedLiquidity: "100000", + expectedFeeGrowthInside0X128: "0", + expectedFeeGrowthInside1X128: "0", + expectedToken0Owed: "0", + expectedToken1Owed: "0", + }, + { + initialLiquidity: u256.Zero(), + liquidityDelta: i256.MustFromDecimal("100000"), + feeGrowthInside0X128: u256.MustFromDecimal("100000000"), + feeGrowthInside1X128: u256.MustFromDecimal("100000000"), + expectedLiquidity: "100000", + expectedFeeGrowthInside0X128: "100000000", + expectedFeeGrowthInside1X128: "100000000", + expectedToken0Owed: "0", + expectedToken1Owed: "0", + }, + { + initialLiquidity: u256.NewUint(100000), + liquidityDelta: i256.MustFromDecimal("100000"), + feeGrowthInside0X128: u256.MustFromDecimal("100000000"), + feeGrowthInside1X128: u256.MustFromDecimal("100000000"), + expectedLiquidity: "200000", + expectedFeeGrowthInside0X128: "100000000", + expectedFeeGrowthInside1X128: "100000000", + expectedToken0Owed: "0", + expectedToken1Owed: "0", + }, + { + initialLiquidity: u256.MustFromDecimal("340282366920938463463374607431768211456"), // Q128 value + liquidityDelta: i256.Zero(), + feeGrowthInside0X128: u256.MustFromDecimal("100000000"), + feeGrowthInside1X128: u256.MustFromDecimal("200000000"), + expectedLiquidity: "340282366920938463463374607431768211456", + expectedFeeGrowthInside0X128: "100000000", + expectedFeeGrowthInside1X128: "200000000", + expectedToken0Owed: "100000000", + expectedToken1Owed: "200000000", + }, + } + + for _, tc := range tests { + position := PositionInfo{ + liquidity: tc.initialLiquidity, + } + + if tc.shouldPanic { + uassert.PanicsWithMessage(t, tc.panicMsg, func() { positionUpdate(position, tc.liquidityDelta, tc.feeGrowthInside0X128, tc.feeGrowthInside1X128) }) + } else { + newPos := positionUpdate(position, tc.liquidityDelta, tc.feeGrowthInside0X128, tc.feeGrowthInside1X128) + uassert.Equal(t, newPos.liquidity.ToString(), tc.expectedLiquidity) + uassert.Equal(t, newPos.feeGrowthInside0LastX128.ToString(), tc.expectedFeeGrowthInside0X128) + uassert.Equal(t, newPos.feeGrowthInside1LastX128.ToString(), tc.expectedFeeGrowthInside1X128) + uassert.Equal(t, newPos.tokensOwed0.ToString(), tc.expectedToken0Owed) + uassert.Equal(t, newPos.tokensOwed1.ToString(), tc.expectedToken1Owed) + } + } +} diff --git a/contract/r/gnoswap/pool/position_update.gno b/contract/r/gnoswap/pool/position_update.gno new file mode 100644 index 000000000..9e3fe0304 --- /dev/null +++ b/contract/r/gnoswap/pool/position_update.gno @@ -0,0 +1,99 @@ +package pool + +// updatePosition modifies the position's liquidity and updates the corresponding tick states. +// +// This function updates the position data based on the specified liquidity delta and tick range. +// It also manages the fee growth, tick state flipping, and cleanup of unused tick data. +// +// Parameters: +// - positionParams: ModifyPositionParams, the parameters for the position modification, which include: +// - owner: The address of the position owner. +// - tickLower: The lower tick boundary of the position. +// - tickUpper: The upper tick boundary of the position. +// - liquidityDelta: The change in liquidity (positive or negative). +// - tick: int32, the current tick position. +// +// Returns: +// - PositionInfo: The updated position information. +// +// Workflow: +// 1. Clone the global fee growth values (token 0 and token 1). +// 2. If the liquidity delta is non-zero: +// - Update the lower and upper ticks using `tickUpdate`, flipping their states if necessary. +// - If a tick's state was flipped, update the tick bitmap to reflect the new state. +// 3. Calculate the fee growth inside the tick range using `getFeeGrowthInside`. +// 4. Generate a unique position key and update the position data using `positionUpdateWithKey`. +// 5. If liquidity is being removed (negative delta), clean up unused tick data by deleting the tick entries. +// 6. Return the updated position. +// +// Notes: +// - The function flips the tick states and cleans up unused tick data when liquidity is removed. +// - It ensures fee growth and position data remain accurate after the update. +// +// Example Usage: +// +// updatedPosition := pool.updatePosition(positionParams, currentTick) +// fmt.Println("Updated Position Info:", updatedPosition) +func (p *Pool) updatePosition(positionParams ModifyPositionParams, tick int32) PositionInfo { + feeGrowthGlobal0X128 := p.FeeGrowthGlobal0X128().Clone() + feeGrowthGlobal1X128 := p.FeeGrowthGlobal1X128().Clone() + + var flippedLower, flippedUpper bool + if !(positionParams.liquidityDelta.IsZero()) { + flippedLower = p.tickUpdate( + positionParams.tickLower, + tick, + positionParams.liquidityDelta, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, + false, + p.maxLiquidityPerTick, + ) + + flippedUpper = p.tickUpdate( + positionParams.tickUpper, + tick, + positionParams.liquidityDelta, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, + true, + p.maxLiquidityPerTick, + ) + + if flippedLower { + p.tickBitmapFlipTick(positionParams.tickLower, p.tickSpacing) + } + + if flippedUpper { + p.tickBitmapFlipTick(positionParams.tickUpper, p.tickSpacing) + } + } + + feeGrowthInside0X128, feeGrowthInside1X128 := p.getFeeGrowthInside( + positionParams.tickLower, + positionParams.tickUpper, + tick, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, + ) + + positionKey := getPositionKey(positionParams.owner, positionParams.tickLower, positionParams.tickUpper) + position := p.positionUpdateWithKey( + positionKey, + positionParams.liquidityDelta, + feeGrowthInside0X128.Clone(), + feeGrowthInside1X128.Clone(), + ) + + // clear any tick data that is no longer needed + if positionParams.liquidityDelta.IsNeg() { + if flippedLower { + p.deleteTick(positionParams.tickLower) + } + if flippedUpper { + p.deleteTick(positionParams.tickUpper) + } + } + + return position +} diff --git a/contract/r/gnoswap/pool/position_update_test.gno b/contract/r/gnoswap/pool/position_update_test.gno new file mode 100644 index 000000000..22ced7640 --- /dev/null +++ b/contract/r/gnoswap/pool/position_update_test.gno @@ -0,0 +1,82 @@ +package pool + +import ( + "testing" + + "gno.land/p/gnoswap/consts" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestUpdatePosition(t *testing.T) { + poolParams := &createPoolParams{ + token0Path: "token0", + token1Path: "token1", + fee: 500, + tickSpacing: 10, + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + } + p := newPool(poolParams) + + tests := []struct { + name string + positionParams ModifyPositionParams + expectLiquidity *u256.Uint + }{ + { + name: "add new position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("1000000"), + }, + expectLiquidity: u256.MustFromDecimal("1000000"), + }, + { + name: "add liquidity to existing position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("500000"), + }, + expectLiquidity: u256.MustFromDecimal("1500000"), + }, + { + name: "remove liquidity from position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("-500000"), + }, + expectLiquidity: u256.MustFromDecimal("1000000"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tick := p.Slot0Tick() + position := p.updatePosition(tt.positionParams, tick) + + if !position.liquidity.Eq(tt.expectLiquidity) { + t.Errorf("liquidity mismatch: expected %s, got %s", + tt.expectLiquidity.ToString(), + position.liquidity.ToString()) + } + + if !tt.positionParams.liquidityDelta.IsZero() { + lowerTick := p.mustGetTick(tt.positionParams.tickLower) + upperTick := p.mustGetTick(tt.positionParams.tickUpper) + + if !lowerTick.initialized { + t.Error("lower tick not initialized") + } + if !upperTick.initialized { + t.Error("upper tick not initialized") + } + } + }) + } +} diff --git a/contract/r/gnoswap/pool/protocol_fee_pool_creation.gno b/contract/r/gnoswap/pool/protocol_fee_pool_creation.gno new file mode 100644 index 000000000..0f4106d73 --- /dev/null +++ b/contract/r/gnoswap/pool/protocol_fee_pool_creation.gno @@ -0,0 +1,69 @@ +package pool + +import ( + "std" + + "gno.land/r/gnoswap/v1/common" +) + +// poolCreationFee is the fee that is charged when a user creates a pool. +// The fee is denominated in GNS tokens. +var ( + poolCreationFee = uint64(100_000_000) // 100_GNS +) + +// GetPoolCreationFee returns the poolCreationFee +// ref: https://docs.gnoswap.io/contracts/pool/protocol_fee_pool_creation.gno#getpoolcreationfee +func GetPoolCreationFee() uint64 { + return poolCreationFee +} + +// SetPoolCreationFee sets the poolCreationFee. +// Only governance contract can execute this function via proposal +// ref: https://docs.gnoswap.io/contracts/pool/protocol_fee_pool_creation.gno#setpoolcreationfee +func SetPoolCreationFee(fee uint64) { + assertOnlyNotHalted() + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err.Error()) + } + prevPoolCreationFee := poolCreationFee + setPoolCreationFee(fee) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "SetPoolCreationFee", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevFee", formatUint(prevPoolCreationFee), + "newFee", formatUint(fee), + ) +} + +// SetPoolCreationFeeByAdmin sets the poolCreationFee by Admin. +// Only admin can execute this function. +func SetPoolCreationFeeByAdmin(fee uint64) { + assertOnlyNotHalted() + + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err.Error()) + } + prevPoolCreationFee := poolCreationFee + setPoolCreationFee(fee) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "SetPoolCreationFeeByAdmin", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevFee", formatUint(prevPoolCreationFee), + "newFee", formatUint(fee), + ) +} + +// setPoolCreationFee this function is internal function called by SetPoolCreationFee +// And SetPoolCreationFeeByAdmin +func setPoolCreationFee(fee uint64) { + poolCreationFee = fee +} diff --git a/contract/r/gnoswap/pool/protocol_fee_pool_creation_test.gno b/contract/r/gnoswap/pool/protocol_fee_pool_creation_test.gno new file mode 100644 index 000000000..44d4dc0e7 --- /dev/null +++ b/contract/r/gnoswap/pool/protocol_fee_pool_creation_test.gno @@ -0,0 +1,180 @@ +package pool + +import ( + "std" + "strconv" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" +) + +func TestSetPoolCreationFee(t *testing.T) { + var ( + admin = std.Address("g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8c") + ) + + tests := []struct { + name string + action func() + verify func() string + expected string + shouldPanic bool + }{ + { + name: "Panic call by non-governance", + action: func() { + const newFee = 2_000_000_000 + std.TestSetOrigCaller(admin) + SetPoolCreationFee(newFee) + }, + verify: nil, + expected: "caller(g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8c) has no permission", + shouldPanic: true, + }, + { + name: "Success call by governance", + action: func() { + const newFee = 2_000_000_000 + govRealm := std.NewUserRealm(consts.GOV_GOVERNANCE_ADDR) + std.TestSetRealm(govRealm) + SetPoolCreationFee(newFee) + }, + verify: func() string { + return strconv.FormatUint(GetPoolCreationFee(), 10) + }, + expected: strconv.FormatUint(2_000_000_000, 10), + shouldPanic: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + tc.action() + if tc.verify != nil { + got := tc.verify() + uassert.Equal(t, got, tc.expected) + } + } else { + tc.action() + } + }) + } +} + +func TestSetPoolCreationFeeByAdmin(t *testing.T) { + var ( + admin = consts.ADMIN + alice = testutils.TestAddress("alice") + ) + + tests := []struct { + name string + action func() + verify func() string + expected string + shouldPanic bool + }{ + { + name: "Panic call by non-admin (gov contract)", + action: func() { + const newFee = 2_000_000_000 + govRealm := std.NewUserRealm(consts.GOV_GOVERNANCE_ADDR) + std.TestSetRealm(govRealm) + { + SetPoolCreationFeeByAdmin(newFee) + } + }, + verify: nil, + expected: "caller(g17s8w2ve7k85fwfnrk59lmlhthkjdted8whvqxd) has no permission", + shouldPanic: true, + }, + { + name: "Panic call by non-admin (user)", + action: func() { + const newFee = 2_000_000_000 + std.TestSetOrigCaller(alice) + { + SetPoolCreationFeeByAdmin(newFee) + } + }, + verify: nil, + expected: "caller(g1v9kxjcm9ta047h6lta047h6lta047h6lzd40gh) has no permission", + shouldPanic: true, + }, + { + name: "Success call by admin", + action: func() { + const newFee = 2_000_000_000 + std.TestSetOrigCaller(admin) + { + SetPoolCreationFeeByAdmin(newFee) + } + }, + verify: func() string { + return strconv.FormatUint(GetPoolCreationFee(), 10) + }, + expected: strconv.FormatUint(2_000_000_000, 10), + shouldPanic: false, + }, + { + name: "Success call by admin (rollback)", + action: func() { + const newFee = 1_000_000_000 + std.TestSetOrigCaller(admin) + { + SetPoolCreationFeeByAdmin(newFee) + } + }, + verify: func() string { + return strconv.FormatUint(GetPoolCreationFee(), 10) + }, + expected: strconv.FormatUint(1_000_000_000, 10), + shouldPanic: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + tc.action() + if tc.verify != nil { + got := tc.verify() + uassert.Equal(t, got, tc.expected) + } + } else { + tc.action() + } + }) + } +} diff --git a/contract/r/gnoswap/pool/protocol_fee_withdrawal.gno b/contract/r/gnoswap/pool/protocol_fee_withdrawal.gno new file mode 100644 index 000000000..48cf76e81 --- /dev/null +++ b/contract/r/gnoswap/pool/protocol_fee_withdrawal.gno @@ -0,0 +1,173 @@ +package pool + +import ( + "std" + + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + pf "gno.land/r/gnoswap/v1/protocol_fee" +) + +// withdrawalFeeBPS is the fee that is charged when a user withdraws their collected fees +// The fee is denominated in BPS (Basis Points) +// Example: 100 BPS = 1% +var ( + withdrawalFeeBPS = uint64(100) // 1% +) + +const ( + MaxBpsValue = uint64(10000) + ZeroBps = uint64(0) +) + +// HandleWithdrawalFee withdraws the fee from the user and returns the amount after the fee +// Only position contract can call this function +// Input: +// - tokenId: the id of the LP token +// - token0Path: the path of the token0 +// - _amount0: the amount of token0 +// - token1Path: the path of the token1 +// - _amount1: the amount of token1 +// - poolPath: the path of the pool +// - positionCaller: the original caller of the position contract +// Output: +// - the amount of token0 after the fee +// - the amount of token1 after the fee +// +// ref: https://docs.gnoswap.io/contracts/pool/protocol_fee_withdrawal.gno#handlewithdrawalfee +func HandleWithdrawalFee( + tokenId uint64, + token0Path string, + amount0 string, // uint256 + token1Path string, + amount1 string, // uint256 + poolPath string, + positionCaller std.Address, +) (string, string) { // uint256 x2 + assertOnlyNotHalted() + common.MustRegistered(token0Path) + common.MustRegistered(token1Path) + + // only position contract can call this function + caller := getPrevAddr() + if err := common.PositionOnly(caller); err != nil { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("withdrawal_fee.gno__HandleWithdrawalFee() || only position(%s) can call this function, called from %s", consts.POSITION_ADDR, caller.String()), + )) + } + + fee := GetWithdrawalFee() + if fee == ZeroBps { + return amount0, amount1 + } + + feeAmount0, afterAmount0 := calculateAmountWithFee(u256.MustFromDecimal(amount0), u256.NewUint(fee)) + feeAmount1, afterAmount1 := calculateAmountWithFee(u256.MustFromDecimal(amount1), u256.NewUint(fee)) + + token0Teller := common.GetTokenTeller(token0Path) + checkTransferError(token0Teller.TransferFrom(positionCaller, consts.PROTOCOL_FEE_ADDR, feeAmount0.Uint64())) + pf.AddToProtocolFee(token0Path, feeAmount0.Uint64()) + + token1Teller := common.GetTokenTeller(token1Path) + checkTransferError(token1Teller.TransferFrom(positionCaller, consts.PROTOCOL_FEE_ADDR, feeAmount1.Uint64())) + pf.AddToProtocolFee(token1Path, feeAmount1.Uint64()) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "WithdrawalFee", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "lpTokenId", formatUint(tokenId), + "poolPath", poolPath, + "feeAmount0", feeAmount0.ToString(), + "feeAmount1", feeAmount1.ToString(), + ) + + return afterAmount0.ToString(), afterAmount1.ToString() +} + +// GetWithdrawalFee returns the withdrawal fee +// ref: https://docs.gnoswap.io/contracts/pool/protocol_fee_withdrawal.gno#getwithdrawalfee +func GetWithdrawalFee() uint64 { + return withdrawalFeeBPS +} + +// SetWithdrawalFee sets the withdrawal fee. +// Only governance contract can execute this function via proposal +// ref: https://docs.gnoswap.io/contracts/pool/protocol_fee_withdrawal.gno#setwithdrawalfee +func SetWithdrawalFee(fee uint64) { + assertOnlyNotHalted() + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err.Error()) + } + + prevWithdrawalFee := withdrawalFeeBPS + + setWithdrawalFee(fee) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "SetWithdrawalFee", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevFee", formatUint(prevWithdrawalFee), + "newFee", formatUint(fee), + ) +} + +// SetWithdrawalFeeByAdmin sets the withdrawal fee by Admin. +// Only admin can execute this function. +func SetWithdrawalFeeByAdmin(fee uint64) { + assertOnlyNotHalted() + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err.Error()) + } + + prevWithdrawalFee := withdrawalFeeBPS + + setWithdrawalFee(fee) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "SetWithdrawalFeeByAdmin", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevFee", formatUint(prevWithdrawalFee), + "newFee", formatUint(fee), + ) +} + +// setWithdrawalFee this function is internal function called by SetWithdrawalFee +// function and SetWithdrawalFeeByAdmin function +func setWithdrawalFee(fee uint64) { + // 10000 (bps) = 100% + if fee > MaxBpsValue { + panic(addDetailToError( + errInvalidWithdrawalFeePct, + ufmt.Sprintf("withdrawal_fee.gno__setWithdrawalFee() || fee(%d) must be in range 0 ~ 10000", fee), + )) + } + withdrawalFeeBPS = fee +} + +// calculateAmountWithFee calculates the fee amount and the amount after the fee +// +// Inputs: +// - amount: the amount before the fee +// - fee: the fee in BPS +// +// Outputs: +// - the fee amount +// - the amount after the fee applied +func calculateAmountWithFee(amount *u256.Uint, fee *u256.Uint) (*u256.Uint, *u256.Uint) { + feeAmount := new(u256.Uint).Mul(amount, fee) + feeAmount = new(u256.Uint).Div(feeAmount, u256.NewUint(MaxBpsValue)) + afterAmount := new(u256.Uint).Sub(amount, feeAmount) + return feeAmount, afterAmount +} diff --git a/contract/r/gnoswap/pool/protocol_fee_withdrawal_test.gno b/contract/r/gnoswap/pool/protocol_fee_withdrawal_test.gno new file mode 100644 index 000000000..633629946 --- /dev/null +++ b/contract/r/gnoswap/pool/protocol_fee_withdrawal_test.gno @@ -0,0 +1,333 @@ +package pool + +import ( + "std" + "strconv" + "strings" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" + pn "gno.land/r/gnoswap/v1/position" +) + +func TestHandleWithdrawalFee(t *testing.T) { + var ( + admin = consts.ADMIN + position = consts.POSITION_ADDR + pool = consts.POOL_ADDR + protocolFee = consts.PROTOCOL_FEE_ADDR + alice = testutils.TestAddress("alice") + ) + tests := []struct { + name string + action func(t *testing.T) + verify func(t *testing.T) (string, string) + expected string + shouldPanic bool + }{ + { + name: "Panic if caller is not position contract", + action: func(t *testing.T) { + std.TestSetOrigCaller(admin) + HandleWithdrawalFee(0, "gno.land/r/onbloc/foo", "0", "gno.land/r/onbloc/foo", "0", "", admin) + }, + verify: nil, + expected: "[GNOSWAP-POOL-001] caller has no permission || withdrawal_fee.gno__HandleWithdrawalFee() || only position(g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5) can call this function, called from g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + shouldPanic: true, + }, + { + name: "Panic if pkgPath is not registered", + action: func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(position)) + HandleWithdrawalFee(0, "pkgPath", "1000", "pkgPath", "1000", "poolPath", admin) + }, + verify: nil, + expected: "[GNOSWAP-COMMON-004] token is not registered || token(pkgPath)", + shouldPanic: true, + }, + { + name: "Panic if spender has no approved balance", + action: func(t *testing.T) { + InitialisePoolTest(t) + std.TestSetRealm(std.NewUserRealm(position)) + poolPath := GetPoolPath(wugnotPath, gnsPath, fee3000) + if !pools.Has(poolPath) { + panic("pool not found") + } + TokenApprove(t, wugnotPath, alice, protocolFee, uint64(0)) + TokenApprove(t, gnsPath, alice, protocolFee, uint64(0)) + HandleWithdrawalFee(1, wugnotPath, "1000", gnsPath, "1000", poolPath, alice) + }, + verify: nil, + expected: "[GNOSWAP-POOL-021] token transfer failed || insufficient allowance", + shouldPanic: true, + }, + { + name: "Success call by position contract", + action: func(t *testing.T) { + InitialisePoolTest(t) + std.TestSetRealm(std.NewUserRealm(alice)) + TokenApprove(t, wugnotPath, alice, pool, uint64(1000)) + TokenApprove(t, gnsPath, alice, pool, uint64(1000)) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(1020), + int32(5040), + "1000", + "1000", + "0", + "0", + max_timeout, + alice, + alice, + ) + }, + verify: func(t *testing.T) (string, string) { + std.TestSetRealm(std.NewUserRealm(alice)) + TokenApprove(t, wugnotPath, alice, pool, uint64(10000000000)) + TokenApprove(t, gnsPath, alice, pool, uint64(1000000000)) + std.TestSetRealm(std.NewUserRealm(position)) + poolPath := GetPoolPath(wugnotPath, gnsPath, fee3000) + return HandleWithdrawalFee(2, wugnotPath, "1000", gnsPath, "1000", poolPath, alice) + }, + expected: "990,990", + shouldPanic: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + tc.action(t) + if tc.verify != nil { + gotAfterAmount0, gotAfterAmount1 := tc.verify(t) + expected := strings.Split(tc.expected, ",") + uassert.Equal(t, gotAfterAmount0, expected[0]) + uassert.Equal(t, gotAfterAmount1, expected[1]) + } + } else { + tc.action(t) + } + }) + } +} + +func TestSetWithdrawalFee(t *testing.T) { + var ( + admin = consts.ADMIN + alice = testutils.TestAddress("alice") + ) + tests := []struct { + name string + action func(t *testing.T) + verify func(t *testing.T) string + expected string + shouldPanic bool + }{ + { + name: "Panic call to set withdrawal fee setFee by non-admin (user)", + action: func(t *testing.T) { + const newFee = uint64(200) + userRealm := std.NewUserRealm(alice) + std.TestSetRealm(userRealm) + { + SetWithdrawalFee(newFee) + } + }, + verify: nil, + expected: "caller(g1v9kxjcm9ta047h6lta047h6lta047h6lzd40gh) has no permission", + shouldPanic: true, + }, + { + name: "Panic call to set withdrawal fee by admin", + action: func(t *testing.T) { + const newFee = uint64(200) + adminRealm := std.NewUserRealm(admin) + std.TestSetRealm(adminRealm) + { + SetWithdrawalFee(newFee) + } + }, + verify: nil, + expected: "caller(g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d) has no permission", + shouldPanic: true, + }, + { + name: "Success call to set withdrawal fee by governance", + action: func(t *testing.T) { + const newFee = uint64(200) + govRealm := std.NewUserRealm(consts.GOV_GOVERNANCE_ADDR) + std.TestSetRealm(govRealm) + { + SetWithdrawalFee(newFee) + } + }, + verify: func(t *testing.T) string { + return strconv.FormatUint(GetWithdrawalFee(), 10) + }, + expected: strconv.FormatUint(200, 10), + shouldPanic: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + tc.action(t) + if tc.verify != nil { + gotWithdrawalFee := tc.verify(t) + uassert.Equal(t, gotWithdrawalFee, tc.expected) + } + } else { + tc.action(t) + } + }) + } +} + +func TestSetWithdrawalFeeByAdmin(t *testing.T) { + var ( + admin = consts.ADMIN + alice = testutils.TestAddress("alice") + governance = consts.GOV_GOVERNANCE_ADDR + ) + tests := []struct { + name string + action func(t *testing.T) + verify func(t *testing.T) string + expected string + shouldPanic bool + }{ + { + name: "Panic call to set withdrawal fee by non-admin (user)", + action: func(t *testing.T) { + const newFee = uint64(100) + userRealm := std.NewUserRealm(alice) + std.TestSetRealm(userRealm) + { + SetWithdrawalFeeByAdmin(newFee) + } + }, + verify: nil, + expected: "caller(g1v9kxjcm9ta047h6lta047h6lta047h6lzd40gh) has no permission", + shouldPanic: true, + }, + { + name: "Panic call to set withdrawal fee by non-admin (gov contract)", + action: func(t *testing.T) { + const newFee = uint64(100) + govRealm := std.NewUserRealm(governance) + std.TestSetRealm(govRealm) + { + SetWithdrawalFeeByAdmin(newFee) + } + }, + verify: nil, + expected: "caller(g17s8w2ve7k85fwfnrk59lmlhthkjdted8whvqxd) has no permission", + shouldPanic: true, + }, + { + name: "Success call to set withdrawal fee by admin", + action: func(t *testing.T) { + const newFee = uint64(100) + adminRealm := std.NewUserRealm(admin) + std.TestSetRealm(adminRealm) + { + SetWithdrawalFeeByAdmin(newFee) + } + }, + verify: func(t *testing.T) string { + return strconv.FormatUint(GetWithdrawalFee(), 10) + }, + expected: strconv.FormatUint(100, 10), + shouldPanic: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + tc.action(t) + if tc.verify != nil { + gotWithdrawalFee := tc.verify(t) + uassert.Equal(t, gotWithdrawalFee, tc.expected) + } + } else { + tc.action(t) + } + }) + } +} diff --git a/contract/r/gnoswap/pool/swap.gno b/contract/r/gnoswap/pool/swap.gno new file mode 100644 index 000000000..f0ca0f8db --- /dev/null +++ b/contract/r/gnoswap/pool/swap.gno @@ -0,0 +1,535 @@ +package pool + +import ( + "std" + + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + plp "gno.land/p/gnoswap/gnsmath" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +// SwapResult encapsulates all state changes that occur as a result of a swap +// This type ensure all state transitions are atomic and can be applied at once. +type SwapResult struct { + Amount0 *i256.Int + Amount1 *i256.Int + NewSqrtPrice *u256.Uint + NewTick int32 + NewLiquidity *u256.Uint + NewProtocolFees ProtocolFees + FeeGrowthGlobal0X128 *u256.Uint + FeeGrowthGlobal1X128 *u256.Uint +} + +// SwapComputation encapsulates pure computation logic for swap +type SwapComputation struct { + AmountSpecified *i256.Int + SqrtPriceLimitX96 *u256.Uint + ZeroForOne bool + ExactInput bool + InitialState SwapState + Cache SwapCache +} + +var ( + fixedPointQ128 = u256.MustFromDecimal(consts.Q128) +) + +// Swap swaps token0 for token1, or token1 for token0 +// Returns swapped amount0, amount1 in string +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap +func Swap( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + zeroForOne bool, + amountSpecified string, + sqrtPriceLimitX96 string, + payer std.Address, // router +) (string, string) { + assertOnlyNotHalted() + if common.GetLimitCaller() { + caller := getPrevAddr() + if err := common.RouterOnly(caller); err != nil { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), + )) + } + } + + if amountSpecified == "0" { + panic(addDetailToError( + errInvalidSwapAmount, + ufmt.Sprintf("amountSpecified == 0"), + )) + } + + pool := GetPool(token0Path, token1Path, fee) + + slot0Start := pool.slot0 + if !slot0Start.unlocked { + panic(errLockedPool) + } + pool.slot0.unlocked = false + defer func() { pool.slot0.unlocked = true }() + + sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) + validatePriceLimits(slot0Start, zeroForOne, sqrtPriceLimit) + + amounts := i256.MustFromDecimal(amountSpecified) + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) + cache := newSwapCache(feeProtocol, pool.liquidity.Clone()) + state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, slot0Start) + + comp := SwapComputation{ + AmountSpecified: amounts, + SqrtPriceLimitX96: sqrtPriceLimit, + ZeroForOne: zeroForOne, + ExactInput: amounts.Gt(i256.Zero()), + InitialState: state, + Cache: cache, + } + + result, err := computeSwap(pool, comp) + if err != nil { + panic(err) + } + + applySwapResult(pool, result) + + // actual swap + pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) + + prevAddr, prevPkgPath := getPrevAsString() + + std.Emit( + "Swap", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "poolPath", GetPoolPath(token0Path, token1Path, fee), + "zeroForOne", formatBool(zeroForOne), + "requestAmount", amountSpecified, + "sqrtPriceLimitX96", sqrtPriceLimitX96, + "payer", payer.String(), + "recipient", recipient.String(), + "token0Amount", result.Amount0.ToString(), + "token1Amount", result.Amount1.ToString(), + "protocolFee0", pool.protocolFees.token0.ToString(), + "protocolFee1", pool.protocolFees.token1.ToString(), + "sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), + "exactIn", strconv.FormatBool(comp.ExactInput), + "currentTick", strconv.FormatInt(int64(pool.Slot0Tick()), 10), + "liquidity", pool.Liquidity().ToString(), + "feeGrowthGlobal0X128", pool.FeeGrowthGlobal0X128().ToString(), + "feeGrowthGlobal1X128", pool.FeeGrowthGlobal1X128().ToString(), + "balanceToken0", pool.BalanceToken0().ToString(), + "balanceToken1", pool.BalanceToken1().ToString(), + ) + + return result.Amount0.ToString(), result.Amount1.ToString() +} + +// DrySwap simulates a swap and returns the amount0, amount1 that would be received and a boolean indicating if the swap is possible +func DrySwap( + token0Path string, + token1Path string, + fee uint32, + zeroForOne bool, + amountSpecified string, + sqrtPriceLimitX96 string, +) (string, string, bool) { + if amountSpecified == "0" { + return "0", "0", false + } + + pool := GetPool(token0Path, token1Path, fee) + + slot0Start := pool.slot0 + sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) + validatePriceLimits(slot0Start, zeroForOne, sqrtPriceLimit) + + amounts := i256.MustFromDecimal(amountSpecified) + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) + cache := newSwapCache(feeProtocol, pool.liquidity.Clone()) + state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, slot0Start) + + comp := SwapComputation{ + AmountSpecified: amounts, + SqrtPriceLimitX96: sqrtPriceLimit, + ZeroForOne: zeroForOne, + ExactInput: amounts.Gt(i256.Zero()), + InitialState: state, + Cache: cache, + } + + result, err := computeSwap(pool, comp) + if err != nil { + return "0", "0", false + } + + if zeroForOne { + if pool.balances.token1.Lt(result.Amount1.Abs()) { + return "0", "0", false + } + } else { + if pool.balances.token0.Lt(result.Amount0.Abs()) { + return "0", "0", false + } + } + + // Validate non-zero amounts + if result.Amount0.IsZero() || result.Amount1.IsZero() { + return "0", "0", false + } + + return result.Amount0.ToString(), result.Amount1.ToString(), true +} + +// computeSwap performs the core swap computation without modifying pool state +// The function follows these state transitions: +// 1. Initial State: Provided by `SwapComputation.InitialState` +// 2. Stepping State: For each step: +// - Compute next tick and price target +// - Calculate amounts and fees +// - Update state (remaining amount, fees, liquidity) +// - Handle tick transitions if necessary +// +// 3. Final State: Aggregated in SwapResult +// +// The computation continues until either: +// - The entire amount is consumed (`amountSpecifiedRemaining` = 0) +// - The price limit is reached (`sqrtPriceX96` = `sqrtPriceLimitX96`) +// +// Returns an error if the computation fails at any step +func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { + state := comp.InitialState + var err error + + // Compute swap steps until completion + for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { + state, err = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache) + if err != nil { + return nil, err + } + } + + // Calculate final amounts + amount0 := state.amountCalculated + amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) + if comp.ZeroForOne == comp.ExactInput { + amount0, amount1 = amount1, amount0 + } + + // Prepare result + result := &SwapResult{ + Amount0: amount0, + Amount1: amount1, + NewSqrtPrice: state.sqrtPriceX96, + NewTick: state.tick, + NewLiquidity: state.liquidity, + NewProtocolFees: ProtocolFees{ + token0: pool.protocolFees.token0, + token1: pool.protocolFees.token1, + }, + FeeGrowthGlobal0X128: pool.feeGrowthGlobal0X128, + FeeGrowthGlobal1X128: pool.feeGrowthGlobal1X128, + } + + // Update protocol fees if necessary + if comp.ZeroForOne { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) + } + result.FeeGrowthGlobal0X128 = state.feeGrowthGlobalX128 + } else { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) + } + result.FeeGrowthGlobal1X128 = state.feeGrowthGlobalX128 + } + + return result, nil +} + +// applySwapResult updates pool state with computed results. +// All state changes are applied at once to maintain consistency +func applySwapResult(pool *Pool, result *SwapResult) { + pool.slot0.sqrtPriceX96 = result.NewSqrtPrice + pool.slot0.tick = result.NewTick + pool.liquidity = result.NewLiquidity + pool.protocolFees = result.NewProtocolFees + pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0X128 + pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1X128 +} + +// validatePriceLimits ensures the provided price limit is valid for the swap direction +// The function enforces that: +// For zeroForOne (selling token0): +// - Price limit must be below current price +// - Price limit must be above MIN_SQRT_RATIO +// +// For !zeroForOne (selling token1): +// - Price limit must be above current price +// - Price limit must be below MAX_SQRT_RATIO +func validatePriceLimits(slot0 Slot0, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { + if zeroForOne { + minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Lt(slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MIN_SQRT_RATIO), + )) + } + } else { + maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Gt(slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MAX_SQRT_RATIO), + )) + } + } +} + +// getFeeProtocol returns the appropriate fee protocol based on zero for one +func getFeeProtocol(slot0 Slot0, zeroForOne bool) uint8 { + if zeroForOne { + return slot0.feeProtocol % 16 + } + return slot0.feeProtocol / 16 +} + +// getFeeGrowthGlobal returns the appropriate fee growth global based on zero for one +func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { + if zeroForOne { + return pool.feeGrowthGlobal0X128.Clone() + } + return pool.feeGrowthGlobal1X128.Clone() +} + +func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { + return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) +} + +// computeSwapStep executes a single step of swap and returns new state +func computeSwapStep( + state SwapState, + pool *Pool, + zeroForOne bool, + sqrtPriceLimitX96 *u256.Uint, + exactInput bool, + cache SwapCache, +) (SwapState, error) { + step := computeSwapStepInit(state, pool, zeroForOne) + + // determining the price target for this step + sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) + + // computing the amounts to be swapped at this step + var newState SwapState + var err error + + newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) + newState = updateAmounts(step, newState, exactInput) + + // if the protocol fee is on, calculate how much is owed, + // decrement fee amount, and increment protocol fee + if cache.feeProtocol > 0 { + newState, step, err = updateFeeProtocol(step, cache.feeProtocol, newState) + if err != nil { + return state, err + } + } + + // update global fee tracker + if newState.liquidity.Gt(u256.Zero()) { + update := u256.MulDiv(step.feeAmount, fixedPointQ128, newState.liquidity) + newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) + } + + // handling tick transitions + if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { + newState = tickTransition(step, zeroForOne, newState, pool) + } else if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { + newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) + } + + return newState, nil +} + +// updateFeeProtocol calculates and updates protocol fees for the current step. +func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) (SwapState, StepComputations, error) { + delta := step.feeAmount + delta.Div(delta, u256.NewUint(uint64(feeProtocol))) + + newFeeAmount, overflow := new(u256.Uint).SubOverflow(step.feeAmount, delta) + if overflow { + return state, step, errUnderflow + } + step.feeAmount = newFeeAmount + state.protocolFee.Add(state.protocolFee, delta) + + return state, step, nil +} + +// computeSwapStepInit initializes the computation for a single swap step. +func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepComputations { + var step StepComputations + step.sqrtPriceStartX96 = state.sqrtPriceX96 + tickNext, initialized := pool.tickBitmapNextInitializedTickWithInOneWord( + state.tick, + pool.tickSpacing, + zeroForOne, + ) + + step.tickNext = tickNext + step.initialized = initialized + + // prevent overshoot the min/max tick + step.clampTickNext() + + // get the price for the next tick + step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) + return step +} + +// computeTargetSqrtRatio determines the target sqrt price for the current swap step. +func computeTargetSqrtRatio(step StepComputations, sqrtPriceLimitX96 *u256.Uint, zeroForOne bool) *u256.Uint { + if shouldUsePriceLimit(step.sqrtPriceNextX96, sqrtPriceLimitX96, zeroForOne) { + return sqrtPriceLimitX96 + } + return step.sqrtPriceNextX96 +} + +// shouldUsePriceLimit returns true if the price limit should be used instead of the next tick price +func shouldUsePriceLimit(sqrtPriceNext, sqrtPriceLimit *u256.Uint, zeroForOne bool) bool { + isLower := sqrtPriceNext.Lt(sqrtPriceLimit) + isHigher := sqrtPriceNext.Gt(sqrtPriceLimit) + if zeroForOne { + return isLower + } + return isHigher +} + +// computeAmounts calculates the input and output amounts for the current swap step. +func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, step StepComputations) (SwapState, StepComputations) { + sqrtPriceX96Str, amountInStr, amountOutStr, feeAmountStr := plp.SwapMathComputeSwapStepStr( + state.sqrtPriceX96, + sqrtRatioTargetX96, + state.liquidity, + state.amountSpecifiedRemaining, + uint64(pool.fee), + ) + + step.amountIn = u256.MustFromDecimal(amountInStr) + step.amountOut = u256.MustFromDecimal(amountOutStr) + step.feeAmount = u256.MustFromDecimal(feeAmountStr) + + state.SetSqrtPriceX96(sqrtPriceX96Str) + + return state, step +} + +// updateAmounts calculates new remaining and calculated amounts based on the swap step +// For exact input swaps: +// - Decrements remaining input amount by (amountIn + feeAmount) +// - Decrements calculated amount by amountOut +// +// For exact output swaps: +// - Increments remaining output amount by amountOut +// - Increments calculated amount by (amountIn + feeAmount) +func updateAmounts(step StepComputations, state SwapState, exactInput bool) SwapState { + amountInWithFeeU256 := new(u256.Uint).Add(step.amountIn, step.feeAmount) + if amountInWithFeeU256.Gt(u256.MustFromDecimal(consts.MAX_INT256)) { + panic("amountIn + feeAmount overflows int256") + } + amountInWithFee := i256.FromUint256(amountInWithFeeU256) + if step.amountOut.Gt(u256.MustFromDecimal(consts.MAX_INT256)) { + panic("amountOut overflows int256") + } + + if exactInput { + state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) + state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) + return state + } else { + state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) + state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) + } + + return state +} + +// tickTransition handles the transition between price ticks during a swap +func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { + // ensure existing state to keep immutability + newState := state + + if step.initialized { + fee0, fee1 := u256.Zero(), u256.Zero() + + if zeroForOne { + fee0 = state.feeGrowthGlobalX128 + fee1 = pool.feeGrowthGlobal1X128 + } else { + fee0 = pool.feeGrowthGlobal0X128 + fee1 = state.feeGrowthGlobalX128 + } + + liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) + + if zeroForOne { + liquidityNet = i256.Zero().Neg(liquidityNet) + } + + newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) + + if tickCrossHook != nil { + tickCrossHook(pool.PoolPath(), step.tickNext, zeroForOne) + } + } + + if zeroForOne { + newState.tick = step.tickNext - 1 + } else { + newState.tick = step.tickNext + } + + return newState +} + +func (p *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { + if zeroForOne { + // payer > POOL + p.safeTransferFrom(payer, consts.POOL_ADDR, p.token0Path, amount0.Abs(), true) + // POOL > recipient + p.safeTransfer(recipient, p.token1Path, amount1, false) + } else { + // payer > POOL + p.safeTransferFrom(payer, consts.POOL_ADDR, p.token1Path, amount1.Abs(), false) + // POOL > recipient + p.safeTransfer(recipient, p.token0Path, amount0, true) + } +} diff --git a/contract/r/gnoswap/pool/swap_test.gno b/contract/r/gnoswap/pool/swap_test.gno new file mode 100644 index 000000000..13f9c8e62 --- /dev/null +++ b/contract/r/gnoswap/pool/swap_test.gno @@ -0,0 +1,951 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/avl" + "gno.land/p/demo/uassert" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" +) + +func TestSaveProtocolFees(t *testing.T) { + tests := []struct { + name string + pool *Pool + amount0 *u256.Uint + amount1 *u256.Uint + want0 *u256.Uint + want1 *u256.Uint + wantFee0 *u256.Uint + wantFee1 *u256.Uint + }{ + { + name: "normal fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(500), + amount1: u256.NewUint(1000), + want0: u256.NewUint(500), + want1: u256.NewUint(1000), + wantFee0: u256.NewUint(500), + wantFee1: u256.NewUint(1000), + }, + { + name: "exact fee deduction (1 deduction)", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(1000), + amount1: u256.NewUint(2000), + want0: u256.NewUint(999), + want1: u256.NewUint(1999), + wantFee0: u256.NewUint(1), + wantFee1: u256.NewUint(1), + }, + { + name: "0 fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(0), + amount1: u256.NewUint(0), + want0: u256.NewUint(0), + want1: u256.NewUint(0), + wantFee0: u256.NewUint(1000), + wantFee1: u256.NewUint(2000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got0, got1 := tt.pool.saveProtocolFees(tt.amount0, tt.amount1) + + uassert.Equal(t, got0.ToString(), tt.want0.ToString()) + uassert.Equal(t, got1.ToString(), tt.want1.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token0.ToString(), tt.wantFee0.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token1.ToString(), tt.wantFee1.ToString()) + }) + } +} + +func TestShouldContinueSwap(t *testing.T) { + tests := []struct { + name string + state SwapState + sqrtPriceLimitX96 *u256.Uint + expected bool + }{ + { + name: "Should continue - amount remaining and price not at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: true, + }, + { + name: "Should stop - no amount remaining", + state: SwapState{ + amountSpecifiedRemaining: i256.Zero(), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + { + name: "Should stop - price at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("900000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) + uassert.Equal(t, tt.expected, result) + }) + } +} + +func TestUpdateAmounts(t *testing.T) { + tests := []struct { + name string + step StepComputations + state SwapState + exactInput bool + expectedState SwapState + }{ + { + name: "Exact input update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + amountCalculated: i256.Zero(), + }, + exactInput: true, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) + amountCalculated: i256.MustFromDecimal("-97"), + }, + }, + { + name: "Exact output update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), + amountCalculated: i256.Zero(), + }, + exactInput: false, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 + amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := updateAmounts(tt.step, tt.state, tt.exactInput) + + uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) + uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) + }) + } +} + +func TestComputeSwap(t *testing.T) { + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, // 0.3% + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: avl.NewTree(), + ticks: avl.NewTree(), + positions: avl.NewTree(), + } + + wordPos, _ := tickBitmapPosition(0) + mockPool.setTickBitmap(wordPos, u256.NewUint(1)) + + t.Run("basic swap", func(t *testing.T) { + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPool.slot0.sqrtPriceX96, + tick: mockPool.slot0.tick, + feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPool.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPool.liquidity, + }, + } + + result, err := computeSwap(mockPool, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if result.Amount0.IsZero() { + t.Error("expected non-zero amount0") + } + if result.Amount1.IsZero() { + t.Error("expected non-zero amount1") + } + }) + + t.Run("swap with zero liquidity", func(t *testing.T) { + mockPoolZeroLiq := *mockPool + mockPoolZeroLiq.liquidity = u256.Zero() + + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, + tick: mockPoolZeroLiq.slot0.tick, + feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPoolZeroLiq.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPoolZeroLiq.liquidity, + }, + } + + result, err := computeSwap(&mockPoolZeroLiq, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if !result.Amount0.IsZero() || !result.Amount1.IsZero() { + t.Error("expected zero amounts for zero liquidity") + } + }) +} + +func TestSwap_Failures(t *testing.T) { + const addr = consts.ROUTER_ADDR + + tests := []struct { + name string + setupFn func(t *testing.T) + token0Path string + token1Path string + fee uint32 + recipient std.Address + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + payer std.Address + expectedAmount0 string + expectedAmount1 string + expectError bool + }{ + { + name: "locked pool", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + pool.slot0.unlocked = false + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: addr, + zeroForOne: true, + amountSpecified: "100", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: addr, + expectError: true, + }, + { + name: "zero amount", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: alice, + zeroForOne: true, + amountSpecified: "0", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: alice, + expectError: true, + }, + { + name: "zero liquidity", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + pool.liquidity = u256.Zero() + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: alice, + zeroForOne: true, + amountSpecified: "100", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: alice, + expectedAmount0: "0", + expectedAmount1: "0", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + std.TestSetOrigCaller(tt.payer) + + if tt.expectError { + defer func() { + if r := recover(); r == nil { + t.Errorf("error should be occurred but not occurred") + } + }() + } + + amount0, amount1 := Swap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.recipient, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + tt.payer, + ) + + if !tt.expectError { + uassert.Equal(t, amount0, tt.expectedAmount0) + uassert.Equal(t, amount1, tt.expectedAmount1) + } + }) + } +} + +func TestDrySwap_Failures(t *testing.T) { + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + balances: Balances{ + token0: u256.MustFromDecimal("1000000000"), + token1: u256.MustFromDecimal("1000000000"), + }, + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: avl.NewTree(), + ticks: avl.NewTree(), + positions: avl.NewTree(), + } + + originalGetPool := GetPool + defer func() { + GetPool = originalGetPool + }() + GetPool = func(token0Path, token1Path string, fee uint32) *Pool { + return mockPool + } + + tests := []struct { + name string + token0Path string + token1Path string + fee uint32 + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + expectAmount0 string + expectAmount1 string + expectSuccess bool + }{ + { + name: "zero amount token0 to token1", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: true, + amountSpecified: "0", + sqrtPriceLimitX96: "900000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + { + name: "insufficient balance", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: false, + amountSpecified: "2000000000", + sqrtPriceLimitX96: "1100000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + { + name: "insufficient balance token1 to token0", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: false, + amountSpecified: "3000000000", + sqrtPriceLimitX96: "1100000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + amount0, amount1, success := DrySwap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + ) + + uassert.Equal(t, success, tt.expectSuccess) + uassert.Equal(t, amount0, tt.expectAmount0) + uassert.Equal(t, amount1, tt.expectAmount1) + }) + } +} + +func TestSwapAndDrySwapComparison(t *testing.T) { + const addr = consts.ROUTER_ADDR + + tests := []struct { + name string + setupFn func(t *testing.T) + token0Path string + token1Path string + fee uint32 + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + }{ + { + name: "normal swap", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + zeroForOne: false, + amountSpecified: "100", + sqrtPriceLimitX96: maxSqrtPriceLimitX96, + }, + { + name: "swap - request to swap amount over total liquidty", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + zeroForOne: false, + amountSpecified: "2000000000", + sqrtPriceLimitX96: maxSqrtPriceLimitX96, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + dryAmount0, dryAmount1, drySuccess := DrySwap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + ) + + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + tt.token0Path, + tt.token1Path, + tt.fee, + addr, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + addr, + ) + + if !drySuccess { + t.Error("DrySwap failed but actual Swap succeeded") + } + + uassert.NotEqual(t, dryAmount0, "0", "amount0 should not be zero") + uassert.NotEqual(t, dryAmount1, "0", "amount1 should not be zero") + uassert.NotEqual(t, actualAmount0, "0", "amount0 should not be zero") + uassert.NotEqual(t, actualAmount1, "0", "amount1 should not be zero") + + uassert.Equal(t, dryAmount0, actualAmount0, + "Amount0 mismatch between DrySwap and actual Swap") + uassert.Equal(t, dryAmount1, actualAmount1, + "Amount1 mismatch between DrySwap and actual Swap") + }) + } +} + +func TestSwapAndDrySwapComparison_amount_zero(t *testing.T) { + const addr = consts.ROUTER_ADDR + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) + shouldPanic bool + expected string + }{ + { + name: "zero amount swap - zeroForOne = false", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + action: func(t *testing.T) { + dryAmount0, dryAmount1, drySuccess := DrySwap( + wugnotPath, + gnsPath, + fee3000, + false, + "0", + maxSqrtPriceLimitX96, + ) + uassert.Equal(t, "0", dryAmount0) + uassert.Equal(t, "0", dryAmount1) + uassert.Equal(t, false, drySuccess) + }, + shouldPanic: false, + expected: "[GNOSWAP-POOL-014] invalid swap amount || amountSpecified == 0", + }, + { + name: "zero amount swap - zeroForOne = true", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + action: func(t *testing.T) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + addr, + true, + "0", + maxSqrtPriceLimitX96, + addr, + ) + }, + shouldPanic: true, + expected: "[GNOSWAP-POOL-014] invalid swap amount || amountSpecified == 0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tt.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tt.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + case error: + if r.(error).Error() != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r.(error).Error(), tt.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + } + }() + + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + tt.action(t) + } + }) + } +} + +func TestSwap_amount_over_liquidity(t *testing.T) { + const addr = consts.ROUTER_ADDR + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) (string, string) + shouldPanic bool + expected []string + }{ + { + name: "amount over liquidity - zeroForOne = false, token0:20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + alice, + false, + "20000", + consts.MAX_PRICE, + addr, + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-1989", "2404"}, + }, + { + name: "amount over liquidity - zeroForOne = false, token1:-20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + alice, + false, + "-20000", + consts.MAX_PRICE, + addr, + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-1989", "2404"}, + }, + { + name: "amount over liquidity - zeroForOne = true, token0:20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + alice, + true, + "20000", + consts.MIN_PRICE, + addr, + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"1045", "-990"}, + }, + { + name: "amount over liquidity - zeroForOne = true, token1:-20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenFaucet(t, wugnotPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + alice, + true, + "-20000", + consts.MIN_PRICE, + addr, + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"1045", "-990"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + amount0, amount1 := tt.action(t) + uassert.Equal(t, tt.expected[0], amount0) + uassert.Equal(t, tt.expected[1], amount1) + } + }) + } +} + +func TestSwap_EXACTIN_OUT(t *testing.T) { + const addr = consts.ROUTER_ADDR + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) (string, string) + shouldPanic bool + expected []string + }{ + { + name: "EXACT IN - zeroForOne = false, token1:200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + alice, + false, + "200", + consts.MAX_PRICE, + addr, + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-195", "200"}, + }, + { + name: "EXACT OUT - zeroForOne = false, token0:-200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + alice, + false, + "-200", + consts.MAX_PRICE, + addr, + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-200", "208"}, + }, + { + name: "EXACT IN - zeroForOne = true, token0:200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + alice, + true, + "200", + consts.MIN_PRICE, + addr, + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"200", "-195"}, + }, + { + name: "EXACT OUT - zeroForOne = true, token1:-200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, admin) + TokenFaucet(t, gnsPath, addr) + TokenFaucet(t, wugnotPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + alice, + true, + "-200", + consts.MIN_PRICE, + addr, + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"208", "-200"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + amount0, amount1 := tt.action(t) + uassert.Equal(t, tt.expected[0], amount0) + uassert.Equal(t, tt.expected[1], amount1) + } + }) + } +} diff --git a/contract/r/gnoswap/pool/tick.gno b/contract/r/gnoswap/pool/tick.gno new file mode 100644 index 000000000..627255839 --- /dev/null +++ b/contract/r/gnoswap/pool/tick.gno @@ -0,0 +1,407 @@ +package pool + +import ( + "strconv" + + "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" +) + +// calculateMaxLiquidityPerTick calculates the maximum liquidity +// per tick for a given tick spacing. +func calculateMaxLiquidityPerTick(tickSpacing int32) *u256.Uint { + minTick := (consts.MIN_TICK / tickSpacing) * tickSpacing + maxTick := (consts.MAX_TICK / tickSpacing) * tickSpacing + numTicks := uint64((maxTick-minTick)/tickSpacing) + 1 + + return new(u256.Uint).Div(u256.MustFromDecimal(consts.MAX_UINT128), u256.NewUint(numTicks)) +} + +// getFeeGrowthBelowX128 calculates the fee growth below a specified tick. +// +// This function computes the fee growth for token 0 and token 1 below a given tick (`tickLower`) +// relative to the current tick (`tickCurrent`). The fee growth values are adjusted based on whether +// the `tickCurrent` is above or below the `tickLower`. +// +// Parameters: +// - tickLower: int32, the lower tick boundary for fee calculation. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// - lowerTick: TickInfo, the fee growth and liquidity details for the lower tick. +// +// Returns: +// - *u256.Uint: Fee growth below `tickLower` for token 0. +// - *u256.Uint: Fee growth below `tickLower` for token 1. +// +// Workflow: +// 1. If `tickCurrent` is greater than or equal to `tickLower`: +// - Return the `feeGrowthOutside0X128` and `feeGrowthOutside1X128` values of the `lowerTick`. +// 2. If `tickCurrent` is below `tickLower`: +// - Compute the fee growth below the lower tick by subtracting `feeGrowthOutside` values +// from the global fee growth values (`feeGrowthGlobal0X128` and `feeGrowthGlobal1X128`). +// 3. Return the calculated fee growth values for both tokens. +// +// Behavior: +// - If `tickCurrent >= tickLower`, the fee growth outside the lower tick is returned as-is. +// - If `tickCurrent < tickLower`, the fee growth is calculated as: +// feeGrowthBelow = feeGrowthGlobal - feeGrowthOutside +// +// Example: +// +// feeGrowth0, feeGrowth1 := getFeeGrowthBelowX128( +// 100, 150, globalFeeGrowth0, globalFeeGrowth1, lowerTickInfo, +// ) +// fmt.Println("Fee Growth Below:", feeGrowth0, feeGrowth1) +func getFeeGrowthBelowX128( + tickLower, tickCurrent int32, + feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, + lowerTick TickInfo, +) (*u256.Uint, *u256.Uint) { + if tickCurrent >= tickLower { + return lowerTick.feeGrowthOutside0X128, lowerTick.feeGrowthOutside1X128 + } + + feeGrowthBelow0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) + feeGrowthBelow1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) + + return feeGrowthBelow0X128, feeGrowthBelow1X128 +} + +// getFeeGrowthAboveX128 calculates the fee growth above a specified tick. +// +// This function computes the fee growth for token 0 and token 1 above a given tick (`tickUpper`) +// relative to the current tick (`tickCurrent`). The fee growth values are adjusted based on whether +// the `tickCurrent` is above or below the `tickUpper`. +// +// Parameters: +// - tickUpper: int32, the upper tick boundary for fee calculation. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// - upperTick: TickInfo, the fee growth and liquidity details for the upper tick. +// +// Returns: +// - *u256.Uint: Fee growth above `tickUpper` for token 0. +// - *u256.Uint: Fee growth above `tickUpper` for token 1. +// +// Workflow: +// 1. If `tickCurrent` is less than `tickUpper`: +// - Return the `feeGrowthOutside0X128` and `feeGrowthOutside1X128` values of the `upperTick`. +// 2. If `tickCurrent` is greater than or equal to `tickUpper`: +// - Compute the fee growth above the upper tick by subtracting `feeGrowthOutside` values +// from the global fee growth values (`feeGrowthGlobal0X128` and `feeGrowthGlobal1X128`). +// 3. Return the calculated fee growth values for both tokens. +// +// Behavior: +// - If `tickCurrent < tickUpper`, the fee growth outside the upper tick is returned as-is. +// - If `tickCurrent >= tickUpper`, the fee growth is calculated as: +// feeGrowthAbove = feeGrowthGlobal - feeGrowthOutside +// +// Example: +// +// feeGrowth0, feeGrowth1 := getFeeGrowthAboveX128( +// 200, 150, globalFeeGrowth0, globalFeeGrowth1, upperTickInfo, +// ) +// fmt.Println("Fee Growth Above:", feeGrowth0, feeGrowth1) +func getFeeGrowthAboveX128( + tickUpper, tickCurrent int32, + feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, + upperTick TickInfo, +) (*u256.Uint, *u256.Uint) { + if tickCurrent < tickUpper { + return upperTick.feeGrowthOutside0X128, upperTick.feeGrowthOutside1X128 + } + + feeGrowthAbove0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) + feeGrowthAbove1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) + + return feeGrowthAbove0X128, feeGrowthAbove1X128 +} + +// getFeeGrowthInside calculates the fee growth within a specified tick range. +// +// This function computes the accumulated fee growth for token 0 and token 1 inside a given tick range +// (`tickLower` to `tickUpper`) relative to the current tick position (`tickCurrent`). It isolates the fee +// growth within the range by subtracting the fee growth below the lower tick and above the upper tick +// from the global fee growth. +// +// Parameters: +// - tickLower: int32, the lower tick boundary of the range. +// - tickUpper: int32, the upper tick boundary of the range. +// - tickCurrent: int32, the current tick index. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth for token 0 in X128 precision. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth for token 1 in X128 precision. +// +// Returns: +// - *u256.Uint: Fee growth inside the tick range for token 0. +// - *u256.Uint: Fee growth inside the tick range for token 1. +// +// Workflow: +// 1. Retrieve the tick information (`lower` and `upper`) for the lower and upper tick boundaries +// using `p.getTick`. +// 2. Calculate the fee growth below the lower tick using `getFeeGrowthBelowX128`. +// 3. Calculate the fee growth above the upper tick using `getFeeGrowthAboveX128`. +// 4. Subtract the fee growth below and above the range from the global fee growth values: +// feeGrowthInside = feeGrowthGlobal - feeGrowthBelow - feeGrowthAbove +// 5. Return the computed fee growth values for token 0 and token 1 within the range. +// +// Behavior: +// - The fee growth is isolated within the range `[tickLower, tickUpper]`. +// - The function ensures the calculations accurately consider the tick boundaries and the current tick position. +// +// Example: +// +// feeGrowth0, feeGrowth1 := pool.getFeeGrowthInside( +// 100, 200, 150, globalFeeGrowth0, globalFeeGrowth1, +// ) +// fmt.Println("Fee Growth Inside (Token 0):", feeGrowth0) +// fmt.Println("Fee Growth Inside (Token 1):", feeGrowth1) +func (p *Pool) getFeeGrowthInside( + tickLower int32, + tickUpper int32, + tickCurrent int32, + feeGrowthGlobal0X128 *u256.Uint, + feeGrowthGlobal1X128 *u256.Uint, +) (*u256.Uint, *u256.Uint) { + lower := p.getTick(tickLower) + upper := p.getTick(tickUpper) + + feeGrowthBelow0X128, feeGrowthBelow1X128 := getFeeGrowthBelowX128(tickLower, tickCurrent, feeGrowthGlobal0X128, feeGrowthGlobal1X128, lower) + feeGrowthAbove0X128, feeGrowthAbove1X128 := getFeeGrowthAboveX128(tickUpper, tickCurrent, feeGrowthGlobal0X128, feeGrowthGlobal1X128, upper) + + feeGrowthInside0X128 := new(u256.Uint).Sub(new(u256.Uint).Sub(feeGrowthGlobal0X128, feeGrowthBelow0X128), feeGrowthAbove0X128) + feeGrowthInside1X128 := new(u256.Uint).Sub(new(u256.Uint).Sub(feeGrowthGlobal1X128, feeGrowthBelow1X128), feeGrowthAbove1X128) + + return feeGrowthInside0X128, feeGrowthInside1X128 +} + +// tickUpdate updates the state of a specific tick. +// +// This function applies a given liquidity change (liquidityDelta) to the specified tick, updates +// the fee growth values if necessary, and adjusts the net liquidity based on whether the tick +// is an upper or lower boundary. It also verifies that the total liquidity does not exceed the +// maximum allowed value and ensures the net liquidity stays within the valid int128 range. +// +// Parameters: +// - tick: int32, the index of the tick to update. +// - tickCurrent: int32, the current active tick index. +// - liquidityDelta: *i256.Int, the amount of liquidity to add or remove. +// - feeGrowthGlobal0X128: *u256.Uint, the global fee growth value for token 0. +// - feeGrowthGlobal1X128: *u256.Uint, the global fee growth value for token 1. +// - upper: bool, indicates if this is the upper boundary (true for upper, false for lower). +// - maxLiquidity: *u256.Uint, the maximum allowed liquidity. +// +// Returns: +// - flipped: bool, indicates if the tick's initialization state has changed. +// (e.g., liquidity transitioning from zero to non-zero, or vice versa) +// +// Workflow: +// 1. Nil input values are replaced with zero. +// 2. The function retrieves the tick information for the specified tick index. +// 3. Applies the liquidityDelta to compute the new total liquidity (liquidityGross). +// - If the total liquidity exceeds the maximum allowed value, the function panics. +// 4. Checks whether the tick's initialized state has changed and sets the `flipped` flag. +// 5. If the tick was previously uninitialized and its index is less than or equal to the current tick, +// the fee growth values are initialized to the current global values. +// 6. Updates the tick's net liquidity: +// - For an upper boundary, it subtracts liquidityDelta. +// - For a lower boundary, it adds liquidityDelta. +// - Ensures the net liquidity remains within the int128 range using `checkOverFlowInt128`. +// 7. Updates the tick's state with the new values. +// 8. Returns whether the tick's initialized state has flipped. +// +// Panic Conditions: +// - The total liquidity (liquidityGross) exceeds the maximum allowed liquidity (maxLiquidity). +// - The net liquidity (liquidityNet) exceeds the int128 range. +// +// Example: +// +// flipped := pool.tickUpdate(10, 5, liquidityDelta, feeGrowth0, feeGrowth1, true, maxLiquidity) +// fmt.Println("Tick flipped:", flipped) +func (p *Pool) tickUpdate( + tick int32, + tickCurrent int32, + liquidityDelta *i256.Int, + feeGrowthGlobal0X128 *u256.Uint, + feeGrowthGlobal1X128 *u256.Uint, + upper bool, + maxLiquidity *u256.Uint, +) (flipped bool) { + liquidityDelta = liquidityDelta.NilToZero() + feeGrowthGlobal0X128 = feeGrowthGlobal0X128.NilToZero() + feeGrowthGlobal1X128 = feeGrowthGlobal1X128.NilToZero() + + tickInfo := p.getTick(tick) + + liquidityGrossBefore := tickInfo.liquidityGross.Clone() + liquidityGrossAfter := liquidityMathAddDelta(liquidityGrossBefore, liquidityDelta) + + if !(liquidityGrossAfter.Lte(maxLiquidity)) { + panic(addDetailToError( + errLiquidityCalculation, + ufmt.Sprintf("liquidityGrossAfter(%s) overflows maxLiquidity(%s)", liquidityGrossAfter.ToString(), maxLiquidity.ToString()), + )) + } + + flipped = (liquidityGrossAfter.IsZero()) != (liquidityGrossBefore.IsZero()) + + if liquidityGrossBefore.IsZero() { + if tick <= tickCurrent { + tickInfo.feeGrowthOutside0X128 = feeGrowthGlobal0X128.Clone() + tickInfo.feeGrowthOutside1X128 = feeGrowthGlobal1X128.Clone() + } + tickInfo.initialized = true + } + + tickInfo.liquidityGross = liquidityGrossAfter.Clone() + + if upper { + tickInfo.liquidityNet = i256.Zero().Sub(tickInfo.liquidityNet, liquidityDelta) + checkOverFlowInt128(tickInfo.liquidityNet) + } else { + tickInfo.liquidityNet = i256.Zero().Add(tickInfo.liquidityNet, liquidityDelta) + checkOverFlowInt128(tickInfo.liquidityNet) + } + + p.setTick(tick, tickInfo) + + return flipped +} + +// tickCross updates a tick's state when it is crossed and returns the liquidity net. +func (p *Pool) tickCross( + tick int32, + feeGrowthGlobal0X128 *u256.Uint, + feeGrowthGlobal1X128 *u256.Uint, +) *i256.Int { + thisTick := p.getTick(tick) + + thisTick.feeGrowthOutside0X128 = new(u256.Uint).Sub(feeGrowthGlobal0X128, thisTick.feeGrowthOutside0X128) + thisTick.feeGrowthOutside1X128 = new(u256.Uint).Sub(feeGrowthGlobal1X128, thisTick.feeGrowthOutside1X128) + + p.setTick(tick, thisTick) + + return thisTick.liquidityNet.Clone() +} + +// setTick updates the tick data for the specified tick index in the pool. +func (p *Pool) setTick(tick int32, newTickInfo TickInfo) { + tickStr := strconv.Itoa(int(tick)) + p.ticks.Set(tickStr, newTickInfo) +} + +// deleteTick deletes the tick data for the specified tick index in the pool. +func (p *Pool) deleteTick(tick int32) { + tickStr := strconv.Itoa(int(tick)) + p.ticks.Remove(tickStr) +} + +// getTick retrieves the TickInfo associated with the specified tick index from the pool. +// If the TickInfo contains any nil fields, they are replaced with zero values using valueOrZero. +// +// Parameters: +// - tick: The tick index (int32) for which the TickInfo is to be retrieved. +// +// Behavior: +// - Retrieves the TickInfo for the given tick from the pool's tick map. +// - Ensures that all fields of TickInfo are non-nil by calling valueOrZero, which replaces nil values with zero. +// - Returns the updated TickInfo. +// +// Returns: +// - TickInfo: The tick data with all fields guaranteed to have valid values (nil fields are set to zero). +// +// Use Case: +// This function ensures the retrieved tick data is always valid and safe for further operations, +// such as calculations or updates, by sanitizing nil fields in the TickInfo structure. +func (p *Pool) getTick(tick int32) TickInfo { + tickStr := strconv.Itoa(int(tick)) + iTickInfo, exist := p.ticks.Get(tickStr) + if !exist { + tickInfo := TickInfo{} + tickInfo.valueOrZero() + return tickInfo + } + + return iTickInfo.(TickInfo) +} + +// GetTickLiquidityGross returns the gross liquidity for the specified tick. +func (p *Pool) GetTickLiquidityGross(tick int32) *u256.Uint { + return p.mustGetTick(tick).liquidityGross +} + +// GetTickLiquidityNet returns the net liquidity for the specified tick. +func (p *Pool) GetTickLiquidityNet(tick int32) *i256.Int { + return p.mustGetTick(tick).liquidityNet +} + +// GetTickFeeGrowthOutside0X128 returns the fee growth outside the tick for token 0. +func (p *Pool) GetTickFeeGrowthOutside0X128(tick int32) *u256.Uint { + return p.mustGetTick(tick).feeGrowthOutside0X128 +} + +// GetTickFeeGrowthOutside1X128 returns the fee growth outside the tick for token 1. +func (p *Pool) GetTickFeeGrowthOutside1X128(tick int32) *u256.Uint { + return p.mustGetTick(tick).feeGrowthOutside1X128 +} + +// GetTickCumulativeOutside returns the cumulative liquidity outside the tick. +func (p *Pool) GetTickCumulativeOutside(tick int32) int64 { + return p.mustGetTick(tick).tickCumulativeOutside +} + +// GetTickSecondsPerLiquidityOutsideX128 returns the seconds per liquidity outside the tick. +func (p *Pool) GetTickSecondsPerLiquidityOutsideX128(tick int32) *u256.Uint { + return p.mustGetTick(tick).secondsPerLiquidityOutsideX128 +} + +// GetTickSecondsOutside returns the seconds outside the tick. +func (p *Pool) GetTickSecondsOutside(tick int32) uint32 { + return p.mustGetTick(tick).secondsOutside +} + +// GetTickInitialized returns whether the tick is initialized. +func (p *Pool) GetTickInitialized(tick int32) bool { + return p.mustGetTick(tick).initialized +} + +// mustGetTick retrieves the TickInfo for a specific tick, panicking if the tick does not exist. +// +// This function ensures that the requested tick data exists in the pool's tick mapping. +// If the tick does not exist, it panics with an appropriate error message. +// +// Parameters: +// - tick: int32, the index of the tick to retrieve. +// +// Returns: +// - TickInfo: The information associated with the specified tick. +// +// Behavior: +// - Checks if the tick exists in the pool's tick mapping (`p.ticks`). +// - If the tick exists, it returns the corresponding `TickInfo`. +// - If the tick does not exist, the function panics with a descriptive error. +// +// Panic Conditions: +// - The specified tick does not exist in the pool's mapping. +// +// Example: +// +// tickInfo := pool.mustGetTick(10) +// fmt.Println("Tick Info:", tickInfo) +func (p *Pool) mustGetTick(tick int32) TickInfo { + tickStr := strconv.Itoa(int(tick)) + iTickInfo, exist := p.ticks.Get(tickStr) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("tick(%d) does not exist", tick), + )) + } + + return iTickInfo.(TickInfo) +} diff --git a/contract/r/gnoswap/pool/tick_bitmap.gno b/contract/r/gnoswap/pool/tick_bitmap.gno new file mode 100644 index 000000000..46d4c1ec6 --- /dev/null +++ b/contract/r/gnoswap/pool/tick_bitmap.gno @@ -0,0 +1,174 @@ +package pool + +import ( + "strconv" + + "gno.land/p/demo/ufmt" + plp "gno.land/p/gnoswap/gnsmath" + + u256 "gno.land/p/gnoswap/uint256" +) + +// tickBitmapPosition calculates the word and bit position for a given tick +func tickBitmapPosition(tick int32) (int16, uint8) { + wordPos := int16(tick >> 8) // tick / 256 + bitPos := uint8(tick % 256) + return wordPos, bitPos +} + +// tickBitmapFlipTick flips the state of a tick in the tick bitmap. +// +// This function toggles the "initialized" state of a tick in the tick bitmap. +// It ensures that the tick aligns with the specified tick spacing and then +// flips the corresponding bit in the bitmap representation. +// +// Parameters: +// - tick: int32, the tick index to toggle. +// - tickSpacing: int32, the spacing between valid ticks. +// The tick must align with this spacing. +// +// Workflow: +// 1. Validates that the `tick` aligns with `tickSpacing` using `checkTickSpacing`. +// 2. Computes the position of the bit in the tick bitmap: +// - `wordPos`: Determines which word in the bitmap contains the bit. +// - `bitPos`: Identifies the position of the bit within the word. +// 3. Creates a bitmask using `Lsh` (Left Shift) to target the bit at `bitPos`. +// 4. Toggles (flips) the bit using XOR with the current value of the tick bitmap. +// 5. Updates the tick bitmap with the modified word. +// +// Behavior: +// - If the bit is `0` (uninitialized), it will be flipped to `1` (initialized). +// - If the bit is `1` (initialized), it will be flipped to `0` (uninitialized). +// +// Example: +// +// pool.tickBitmapFlipTick(120, 60) +// // This flips the bit for tick 120 with a tick spacing of 60. +// +// Notes: +// - The `tick` must be divisible by `tickSpacing`. If not, the function will panic. +func (p *Pool) tickBitmapFlipTick( + tick int32, + tickSpacing int32, +) { + checkTickSpacing(tick, tickSpacing) + wordPos, bitPos := tickBitmapPosition(tick / tickSpacing) + + mask := new(u256.Uint).Lsh(u256.One(), uint(bitPos)) + p.setTickBitmap(wordPos, new(u256.Uint).Xor(p.getTickBitmap(wordPos), mask)) +} + +// tickBitmapNextInitializedTickWithInOneWord finds the next initialized tick within +// one word of the bitmap. +func (p *Pool) tickBitmapNextInitializedTickWithInOneWord( + tick int32, + tickSpacing int32, + lte bool, +) (int32, bool) { + compress := tick / tickSpacing + if tick < 0 && tick%tickSpacing != 0 { + compress-- + } + + wordPos, bitPos := getWordAndBitPos(compress, lte) + mask := getMaskBit(uint(bitPos), lte) + masked := new(u256.Uint).And(p.getTickBitmap(wordPos), mask) + initialized := !(masked.IsZero()) + + nextTick := getNextTick(lte, initialized, compress, bitPos, tickSpacing, masked) + return nextTick, initialized +} + +// getTickBitmap gets the tick bitmap for the given word position +// if the tick bitmap is not initialized, initialize it to zero +func (p *Pool) getTickBitmap(wordPos int16) *u256.Uint { + wordPosStr := strconv.Itoa(int(wordPos)) + + if !p.tickBitmaps.Has(wordPosStr) { + p.initTickBitmap(wordPos) + } + + iU256, exist := p.tickBitmaps.Get(wordPosStr) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("tickBitmap(%d) does not exist", wordPos), + )) + } + + return iU256.(*u256.Uint) +} + +// setTickBitmap sets the tick bitmap for the given word position +func (p *Pool) setTickBitmap(wordPos int16, tickBitmap *u256.Uint) { + wordPosStr := strconv.Itoa(int(wordPos)) + p.tickBitmaps.Set(wordPosStr, tickBitmap) +} + +// initTickBitmap initializes the tick bitmap for the given word position +func (p *Pool) initTickBitmap(wordPos int16) { + p.setTickBitmap(wordPos, u256.Zero()) +} + +// getWordAndBitPos gets tick's wordPos and bitPos depending on the swap direction +func getWordAndBitPos(tick int32, lte bool) (int16, uint8) { + if lte { + return tickBitmapPosition(tick) + } + + tick++ + return tickBitmapPosition(tick) +} + +// bMap is a map that maps boolean values to uint values. +// true maps to 1, and false maps to 0. +var bMap = map[bool]uint{ + true: 1, + false: 0, +} + +// getMaskBit generates a mask based on the provided bit position (bitPos) and a boolean flag (lte). +// The function constructs a bitmask with a shift depending on the bit position and the boolean value. +// It either returns the mask or its negation, based on the value of 'lte' (swap direction). +func getMaskBit(bitPos uint, lte bool) *u256.Uint { + // Shift the number 1 to the left by (bitPos + bMap[lte]) positions. + // If lte is true, the shift will be bitPos + 1; if false, it will be just bitPos. + shifted := new(u256.Uint).Lsh(u256.One(), bitPos+bMap[lte]) + + // Subtract 1 from the shifted value to create a mask. + mask := new(u256.Uint).Sub(shifted, u256.One()) + + // If lte is false, return the negation of the mask. + if !lte { + return new(u256.Uint).Not(mask) + } + // Otherwise, return the mask itself. + return mask +} + +// getNextTick gets the next tick depending on the initialized state and the swap direction +func getNextTick(lte, initialized bool, compress int32, bitPos uint8, tickSpacing int32, masked *u256.Uint) int32 { + if initialized { + return getTickIfInitialized(compress, tickSpacing, bitPos, masked, lte) + } + + return getTickIfNotInitialized(compress, tickSpacing, bitPos, lte) +} + +// getTickIfInitialized gets the next tick if the tick bitmap is initialized +func getTickIfInitialized(compress, tickSpacing int32, bitPos uint8, masked *u256.Uint, lte bool) int32 { + if lte { + return (compress - int32(bitPos-plp.BitMathMostSignificantBit(masked))) * tickSpacing + } + + return (compress + 1 + int32(plp.BitMathLeastSignificantBit(masked)-bitPos)) * tickSpacing +} + +// getTickIfNotInitialized gets the next tick if the tick bitmap is not initialized +func getTickIfNotInitialized(compress, tickSpacing int32, bitPos uint8, lte bool) int32 { + if lte { + return (compress - int32(bitPos)) * tickSpacing + } + + return (compress + 1 + int32(255-bitPos)) * tickSpacing +} diff --git a/contract/r/gnoswap/pool/tick_bitmap_test.gno b/contract/r/gnoswap/pool/tick_bitmap_test.gno new file mode 100644 index 000000000..05d52a001 --- /dev/null +++ b/contract/r/gnoswap/pool/tick_bitmap_test.gno @@ -0,0 +1,161 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/avl" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestTickBitmapPosition(t *testing.T) { + tests := []struct { + name string + tick int32 + wantWordPos int16 + wantBitPos uint8 + }{ + { + name: "positive tick", + tick: 300, + wantWordPos: 1, // 300 >> 8 = 1 + wantBitPos: 44, // 300 % 256 = 44 + }, + { + name: "negative tick", + tick: -300, + wantWordPos: -2, // -300 >> 8 = -2 + wantBitPos: 212, // -300 % 256 = 212 + }, + { + name: "zero tick", + tick: 0, + wantWordPos: 0, + wantBitPos: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wordPos, bitPos := tickBitmapPosition(tt.tick) + if wordPos != tt.wantWordPos { + t.Errorf("wordPos = %v, want %v", wordPos, tt.wantWordPos) + } + if bitPos != tt.wantBitPos { + t.Errorf("bitPos = %v, want %v", bitPos, tt.wantBitPos) + } + }) + } +} + +func TestTickBitmapFlipTick(t *testing.T) { + tests := []struct { + name string + tick int32 + tickSpacing int32 + shouldPanic bool + }{ + { + name: "valid positive tick and spacing", + tick: 100, + tickSpacing: 20, + shouldPanic: false, + }, + { + name: "valid negative tick and spacing", + tick: -300, + tickSpacing: 20, + shouldPanic: false, + }, + { + name: "invalid tick and spacing", + tick: 101, + tickSpacing: 20, + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool := &Pool{ + tickBitmaps: avl.NewTree(), + } + + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but got none") + } + }() + } + + pool.tickBitmapFlipTick(tt.tick, tt.tickSpacing) + + if !tt.shouldPanic { + wordPos, bitPos := tickBitmapPosition(tt.tick / tt.tickSpacing) + expected := new(u256.Uint).Lsh(u256.NewUint(1), uint(bitPos)) + if pool.getTickBitmap(wordPos).Cmp(expected) != 0 { + t.Errorf("bitmap not set correctly") + } + } + }) + } +} + +func TestTickBitmapNextInitializedTickWithInOneWord(t *testing.T) { + tests := []struct { + name string + tick int32 + tickSpacing int32 + lte bool + setupBitmap func(*Pool) + wantTick int32 + wantInit bool + }{ + { + name: "search lte with initialized tick", + tick: 200, + tickSpacing: 20, + lte: true, + setupBitmap: func(p *Pool) { + p.tickBitmapFlipTick(180, 20) // Initialize a tick at 180 + }, + wantTick: 180, + wantInit: true, + }, + { + name: "search gt with initialized tick", + tick: 100, + tickSpacing: 20, + lte: false, + setupBitmap: func(p *Pool) { + p.tickBitmapFlipTick(140, 20) // Initialize a tick at 140 + }, + wantTick: 140, + wantInit: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool := &Pool{ + tickBitmaps: avl.NewTree(), + } + if tt.setupBitmap != nil { + tt.setupBitmap(pool) + } + + gotTick, gotInit := pool.tickBitmapNextInitializedTickWithInOneWord( + tt.tick, + tt.tickSpacing, + tt.lte, + ) + + if gotTick != tt.wantTick { + t.Errorf("tick = %v, want %v", gotTick, tt.wantTick) + } + if gotInit != tt.wantInit { + t.Errorf("initialized = %v, want %v", gotInit, tt.wantInit) + } + }) + } +} diff --git a/contract/r/gnoswap/pool/tick_test.gno b/contract/r/gnoswap/pool/tick_test.gno new file mode 100644 index 000000000..d7a2aee05 --- /dev/null +++ b/contract/r/gnoswap/pool/tick_test.gno @@ -0,0 +1,730 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/avl" + "gno.land/p/demo/uassert" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestcalculateMaxLiquidityPerTick(t *testing.T) { + tests := []struct { + name string + tickSpacing int32 + want string // expected result in string format + }{ + { + name: "tick spacing 1 (for 0.01% pool)", + tickSpacing: 1, + want: "191757530477355301479181766273477", + }, + { + name: "tick spacing 10 (for 0.05% pool)", + tickSpacing: 10, + want: "1917569901783203986719870431555990", + }, + { + name: "tick spacing 60 (for 0.3% pool)", + tickSpacing: 60, + want: "11505743598341114571880798222544994", + }, + { + name: "tick spacing 200 (for 1% pool)", + tickSpacing: 200, + want: "38350317471085141830651933667504588", + }, + { + name: "entire range", + tickSpacing: 887272, + want: "113427455640312821154458202477256070485", + }, + { + name: "custom tick spacing", + tickSpacing: 2302, + want: "441351967472034323558203122479595605", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := calculateMaxLiquidityPerTick(tt.tickSpacing) + if got.ToString() != tt.want { + t.Errorf("calculateMaxLiquidityPerTick() = %v, want %v", got.ToString(), tt.want) + } + }) + } +} + +func TestCalculateFeeGrowthInside(t *testing.T) { + // Create a mock pool + pool := &Pool{ + ticks: avl.NewTree(), + } + + // Setup test ticks + pool.ticks.Set("0", TickInfo{ + liquidityGross: u256.NewUint(1000), + liquidityNet: i256.NewInt(100), + feeGrowthOutside0X128: u256.NewUint(5), + feeGrowthOutside1X128: u256.NewUint(7), + initialized: true, + }) + pool.ticks.Set("100", TickInfo{ + liquidityGross: u256.NewUint(2000), + liquidityNet: i256.NewInt(-100), + feeGrowthOutside0X128: u256.NewUint(10), + feeGrowthOutside1X128: u256.NewUint(15), + initialized: true, + }) + + tests := []struct { + name string + tickLower int32 + tickUpper int32 + tickCurrent int32 + feeGrowthGlobal0X128 *u256.Uint + feeGrowthGlobal1X128 *u256.Uint + want0 string + want1 string + preconditions func() + }{ + { + name: "returns all for two uninitialized ticks if tick is inside", + tickLower: -2, + tickUpper: 2, + tickCurrent: 0, + feeGrowthGlobal0X128: u256.NewUint(15), + feeGrowthGlobal1X128: u256.NewUint(15), + want0: "15", + want1: "15", + }, + { + name: "returns 0 for two uninitialized ticks if tick is above", + tickLower: -2, + tickUpper: 2, + tickCurrent: 4, + feeGrowthGlobal0X128: u256.NewUint(15), + feeGrowthGlobal1X128: u256.NewUint(15), + want0: "0", + want1: "0", + }, + { + name: "returns 0 for two uninitialized ticks if tick is below", + tickLower: -2, + tickUpper: 2, + tickCurrent: 4, + feeGrowthGlobal0X128: u256.NewUint(15), + feeGrowthGlobal1X128: u256.NewUint(15), + want0: "0", + want1: "0", + }, + { + name: "subtracts upper tick if below", + tickLower: -2, + tickUpper: 2, + tickCurrent: 0, + feeGrowthGlobal0X128: u256.NewUint(15), + feeGrowthGlobal1X128: u256.NewUint(15), + want0: "13", + want1: "12", + preconditions: func() { + setTick( + t, + pool, + 2, + u256.NewUint(2), + u256.NewUint(3), + u256.NewUint(0), + i256.NewInt(0), + u256.NewUint(0), + 0, + 0, + true, + ) + }, + }, + { + name: "subtracts lower tick if below", + tickLower: -2, + tickUpper: 2, + tickCurrent: 0, + feeGrowthGlobal0X128: u256.NewUint(15), + feeGrowthGlobal1X128: u256.NewUint(15), + want0: "13", + want1: "12", + preconditions: func() { + deleteTick(t, pool, 2) // delete tick from previous test + setTick( + t, + pool, + -2, + u256.NewUint(2), + u256.NewUint(3), + u256.NewUint(0), + i256.NewInt(0), + u256.NewUint(0), + 0, + 0, + true, + ) + }, + }, + { + name: "subtracts upper and lower tick if inside", + tickLower: -2, + tickUpper: 2, + tickCurrent: 0, + feeGrowthGlobal0X128: u256.NewUint(15), + feeGrowthGlobal1X128: u256.NewUint(15), + want0: "9", + want1: "11", + preconditions: func() { + // we already have tick -2 + setTick( + t, + pool, + 2, + u256.NewUint(4), + u256.NewUint(1), + u256.NewUint(0), + i256.NewInt(0), + u256.NewUint(0), + 0, + 0, + true, + ) + }, + }, + { + name: "works correctly with overflow on inside tick", + tickLower: -2, + tickUpper: 2, + tickCurrent: 0, + feeGrowthGlobal0X128: u256.NewUint(15), + feeGrowthGlobal1X128: u256.NewUint(15), + want0: "16", + want1: "13", + preconditions: func() { + deleteTick(t, pool, 2) + deleteTick(t, pool, -2) + setTick( + t, + pool, + -2, + u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639932"), // max uint256 - 3 + u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639933"), // max uint256 - 2 + u256.NewUint(0), + i256.NewInt(0), + u256.NewUint(0), + 0, + 0, + true, + ) + setTick( + t, + pool, + 2, + u256.NewUint(3), + u256.NewUint(5), + u256.NewUint(0), + i256.NewInt(0), + u256.NewUint(0), + 0, + 0, + true, + ) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.preconditions != nil { + tt.preconditions() + } + got0, got1 := pool.getFeeGrowthInside( + tt.tickLower, + tt.tickUpper, + tt.tickCurrent, + tt.feeGrowthGlobal0X128, + tt.feeGrowthGlobal1X128, + ) + if got0.ToString() != tt.want0 || got1.ToString() != tt.want1 { + t.Errorf("getFeeGrowthInside() = (%v, %v), want (%v, %v)", + got0.ToString(), got1.ToString(), tt.want0, tt.want1) + } + }) + } +} + +func TestTickUpdate(t *testing.T) { + pool := &Pool{ + ticks: avl.NewTree(), + } + + tests := []struct { + name string + preconditions func() + tick int32 + tickCurrent int32 + liquidityDelta *i256.Int + feeGrowthGlobal0X128 *u256.Uint + feeGrowthGlobal1X128 *u256.Uint + upper bool + maxLiquidity *u256.Uint + wantFlipped bool + shouldPanic bool + verify func() + }{ + { + name: "flips from zero to non zero", + tick: 0, + tickCurrent: 0, + liquidityDelta: i256.One(), + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + upper: false, + maxLiquidity: u256.NewUint(3), + wantFlipped: true, + }, + { + name: "does not flip from nonzero to greater nonzero", + preconditions: func() { + deleteTick(t, pool, 0) + pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) + }, + tick: 0, + tickCurrent: 0, + liquidityDelta: i256.One(), + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + upper: false, + maxLiquidity: u256.NewUint(3), + wantFlipped: false, + }, + { + name: "flips from nonzero to zero", + preconditions: func() { + deleteTick(t, pool, 0) + pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) + }, + tick: 0, + tickCurrent: 0, + liquidityDelta: i256.NewInt(-1), + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + upper: false, + maxLiquidity: u256.NewUint(3), + wantFlipped: true, + }, + { + name: "does not flip from nonzero to lesser nonzero", + preconditions: func() { + deleteTick(t, pool, 0) + pool.tickUpdate(0, 0, i256.NewInt(2), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) + }, + tick: 0, + tickCurrent: 0, + liquidityDelta: i256.NewInt(-1), + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + upper: false, + maxLiquidity: u256.NewUint(3), + wantFlipped: false, + }, + { + name: "reverts if total liquidity gross is greater than max", + preconditions: func() { + deleteTick(t, pool, 0) + pool.tickUpdate(0, 0, i256.NewInt(2), u256.Zero(), u256.Zero(), false, u256.NewUint(3)) + pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), true, u256.NewUint(3)) + + }, + tick: 0, + tickCurrent: 0, + liquidityDelta: i256.One(), + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + upper: false, + maxLiquidity: u256.NewUint(3), + wantFlipped: false, + shouldPanic: true, + }, + { + name: "nets the liquidity based on upper flag", + preconditions: func() { + pool.tickUpdate(0, 0, i256.NewInt(2), u256.Zero(), u256.Zero(), false, u256.NewUint(10)) + pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), true, u256.NewUint(10)) + pool.tickUpdate(0, 0, i256.NewInt(3), u256.Zero(), u256.Zero(), true, u256.NewUint(10)) + pool.tickUpdate(0, 0, i256.One(), u256.Zero(), u256.Zero(), false, u256.NewUint(10)) + + }, + tick: 0, + tickCurrent: 0, + liquidityDelta: i256.One(), + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + upper: false, + maxLiquidity: u256.NewUint(3), + wantFlipped: false, + shouldPanic: true, + }, + { + name: "reverts on overflow liquidity gross", + preconditions: func() { + pool.tickUpdate(0, 0, i256.MustFromDecimal("170141183460469231731687303715884105726"), u256.Zero(), u256.Zero(), false, u256.MustFromDecimal("340282366920938463463374607431768211455")) + }, + tick: 0, + tickCurrent: 0, + liquidityDelta: i256.MustFromDecimal("170141183460469231731687303715884105726"), // (maxUint128 / 2) + 1 + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + upper: false, + maxLiquidity: u256.MustFromDecimal("340282366920938463463374607431768211455"), // maxUint128 + wantFlipped: false, + shouldPanic: true, + }, + { + name: "assumes all growth happens below ticks lte current tick", + preconditions: func() { + deleteTick(t, pool, 0) + deleteTick(t, pool, 1) + pool.tickUpdate(1, 1, i256.One(), u256.One(), u256.NewUint(2), false, u256.MustFromDecimal("340282366920938463463374607431768211455")) + }, + tick: 0, + tickCurrent: 0, + liquidityDelta: i256.MustFromDecimal("170141183460469231731687303715884105726"), + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + upper: false, + maxLiquidity: u256.MustFromDecimal("340282366920938463463374607431768211455"), + wantFlipped: false, + shouldPanic: false, + verify: func() { + info := pool.mustGetTick(1) + uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "1") + uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "2") + }, + }, + { + name: "does not set any growth fields if tick is already initialized", + preconditions: func() { + pool.tickUpdate(1, 1, i256.One(), u256.One(), u256.NewUint(2), false, u256.MustFromDecimal("340282366920938463463374607431768211455")) + pool.tickUpdate(1, 1, i256.One(), u256.NewUint(6), u256.NewUint(7), false, u256.MustFromDecimal("340282366920938463463374607431768211455")) + }, + tick: 1, + tickCurrent: 1, + liquidityDelta: i256.One(), + feeGrowthGlobal0X128: u256.NewUint(6), + feeGrowthGlobal1X128: u256.NewUint(7), + upper: false, + maxLiquidity: u256.MustFromDecimal("340282366920938463463374607431768211455"), + wantFlipped: false, + shouldPanic: false, + verify: func() { + info := pool.mustGetTick(1) + uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "1") + uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "2") + }, + }, + { + name: "does not set any growth fields for ticks gt current tick", + tick: 2, + tickCurrent: 1, + liquidityDelta: i256.One(), + feeGrowthGlobal0X128: u256.NewUint(1), + feeGrowthGlobal1X128: u256.NewUint(2), + upper: false, + maxLiquidity: u256.MustFromDecimal("340282366920938463463374607431768211455"), + wantFlipped: false, + shouldPanic: false, + verify: func() { + info := pool.getTick(2) + uassert.Equal(t, info.feeGrowthOutside0X128.ToString(), "0") + uassert.Equal(t, info.feeGrowthOutside1X128.ToString(), "0") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.preconditions != nil { + tt.preconditions() + } + + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic, but got nil") + } + }() + + pool.tickUpdate( + tt.tick, + tt.tickCurrent, + tt.liquidityDelta, + tt.feeGrowthGlobal0X128, + tt.feeGrowthGlobal1X128, + tt.upper, + tt.maxLiquidity, + ) + } else { + if tt.verify != nil { + tt.verify() + } else { + gotFlipped := pool.tickUpdate( + tt.tick, + tt.tickCurrent, + tt.liquidityDelta, + tt.feeGrowthGlobal0X128, + tt.feeGrowthGlobal1X128, + tt.upper, + tt.maxLiquidity, + ) + if gotFlipped != tt.wantFlipped { + t.Errorf("tickUpdate() flipped = %v, want %v", gotFlipped, tt.wantFlipped) + } + } + } + }) + } +} + +func TestTickCross(t *testing.T) { + pool := &Pool{ + ticks: avl.NewTree(), + } + + // Setup initial tick state + pool.ticks.Set("100", TickInfo{ + liquidityGross: u256.NewUint(1000), + liquidityNet: i256.NewInt(500), + feeGrowthOutside0X128: u256.NewUint(10), + feeGrowthOutside1X128: u256.NewUint(15), + initialized: true, + }) + + tests := []struct { + name string + tick int32 + feeGrowthGlobal0X128 *u256.Uint + feeGrowthGlobal1X128 *u256.Uint + wantLiquidityNet string + }{ + { + name: "cross tick upwards", + tick: 100, + feeGrowthGlobal0X128: u256.NewUint(20), + feeGrowthGlobal1X128: u256.NewUint(25), + wantLiquidityNet: "500", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotLiquidityNet := pool.tickCross( + tt.tick, + tt.feeGrowthGlobal0X128, + tt.feeGrowthGlobal1X128, + ) + if gotLiquidityNet.ToString() != tt.wantLiquidityNet { + t.Errorf("tickCross() liquidityNet = %v, want %v", + gotLiquidityNet.ToString(), tt.wantLiquidityNet) + } + }) + } +} + +func TestGetTick(t *testing.T) { + pool := &Pool{ + ticks: avl.NewTree(), + } + + // Setup a tick + expectedTick := TickInfo{ + liquidityGross: u256.NewUint(1000), + liquidityNet: i256.NewInt(500), + feeGrowthOutside0X128: u256.NewUint(10), + feeGrowthOutside1X128: u256.NewUint(15), + initialized: true, + } + pool.setTick(50, expectedTick) + + tests := []struct { + name string + tick int32 + wantInit bool + }{ + { + name: "get existing tick", + tick: 50, + wantInit: true, + }, + { + name: "get non-existing tick", + tick: 100, + wantInit: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := pool.getTick(tt.tick) + if got.initialized != tt.wantInit { + t.Errorf("getTick() initialized = %v, want %v", got.initialized, tt.wantInit) + } + }) + } +} + +func TestGetFeeGrowthBelowX128(t *testing.T) { + // Setup test data + globalFeeGrowth0 := u256.NewUint(1000) // Global fee growth for token 0 + globalFeeGrowth1 := u256.NewUint(2000) // Global fee growth for token 1 + + lowerTick := TickInfo{ + feeGrowthOutside0X128: u256.NewUint(300), // fee growth outside for token 0 + feeGrowthOutside1X128: u256.NewUint(500), // fee growth outside for token 1 + } + + tests := []struct { + name string + tickLower int32 + tickCurrent int32 + expectedFeeGrowth0 *u256.Uint + expectedFeeGrowth1 *u256.Uint + }{ + { + name: "tickCurrent >= tickLower - Return feeGrowthOutside directly", + tickLower: 100, + tickCurrent: 100, + expectedFeeGrowth0: lowerTick.feeGrowthOutside0X128, + expectedFeeGrowth1: lowerTick.feeGrowthOutside1X128, + }, + { + name: "tickCurrent > tickLower - Return feeGrowthOutside directly", + tickLower: 50, + tickCurrent: 100, + expectedFeeGrowth0: lowerTick.feeGrowthOutside0X128, + expectedFeeGrowth1: lowerTick.feeGrowthOutside1X128, + }, + { + name: "tickCurrent < tickLower - Subtract feeGrowthOutside from global", + tickLower: 100, + tickCurrent: 50, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the function + feeGrowth0, feeGrowth1 := getFeeGrowthBelowX128( + tt.tickLower, tt.tickCurrent, + globalFeeGrowth0, globalFeeGrowth1, + lowerTick, + ) + + // Assertions + uassert.True(t, feeGrowth0.Eq(tt.expectedFeeGrowth0), + "Expected feeGrowth0: %s, got: %s", tt.expectedFeeGrowth0.ToString(), feeGrowth0.ToString()) + uassert.True(t, feeGrowth1.Eq(tt.expectedFeeGrowth1), + "Expected feeGrowth1: %s, got: %s", tt.expectedFeeGrowth1.ToString(), feeGrowth1.ToString()) + }) + } +} + +func TestGetFeeGrowthAboveX128(t *testing.T) { + // Setup test data + globalFeeGrowth0 := u256.NewUint(1000) // Global fee growth for token 0 + globalFeeGrowth1 := u256.NewUint(2000) // Global fee growth for token 1 + + upperTick := TickInfo{ + feeGrowthOutside0X128: u256.NewUint(300), // Fee growth outside for token 0 + feeGrowthOutside1X128: u256.NewUint(500), // Fee growth outside for token 1 + } + + tests := []struct { + name string + tickUpper int32 + tickCurrent int32 + expectedFeeGrowth0 *u256.Uint + expectedFeeGrowth1 *u256.Uint + }{ + { + name: "tickCurrent < tickUpper - Return feeGrowthOutside directly", + tickUpper: 100, + tickCurrent: 50, + expectedFeeGrowth0: upperTick.feeGrowthOutside0X128, // 300 + expectedFeeGrowth1: upperTick.feeGrowthOutside1X128, // 500 + }, + { + name: "tickCurrent >= tickUpper - Subtract feeGrowthOutside from global", + tickUpper: 100, + tickCurrent: 150, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + { + name: "tickCurrent == tickUpper - Subtract feeGrowthOutside from global", + tickUpper: 100, + tickCurrent: 100, + expectedFeeGrowth0: u256.NewUint(700), // 1000 - 300 + expectedFeeGrowth1: u256.NewUint(1500), // 2000 - 500 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the function + feeGrowth0, feeGrowth1 := getFeeGrowthAboveX128( + tt.tickUpper, tt.tickCurrent, + globalFeeGrowth0, globalFeeGrowth1, + upperTick, + ) + + // Assertions + uassert.True(t, feeGrowth0.Eq(tt.expectedFeeGrowth0), + "Expected feeGrowth0: %s, got: %s", tt.expectedFeeGrowth0.ToString(), feeGrowth0.ToString()) + uassert.True(t, feeGrowth1.Eq(tt.expectedFeeGrowth1), + "Expected feeGrowth1: %s, got: %s", tt.expectedFeeGrowth1.ToString(), feeGrowth1.ToString()) + }) + } +} + +func setTick( + t *testing.T, + pool *Pool, + tick int32, + feeGrowthOutside0X128 *u256.Uint, + feeGrowthOutside1X128 *u256.Uint, + liquidityGross *u256.Uint, + liquidityNet *i256.Int, + secondsPerLiquidityOutsideX128 *u256.Uint, + tickCumulativeOutside int64, + secondsOutside uint32, + initialized bool, +) { + t.Helper() + + info := pool.getTick(tick) + info.feeGrowthOutside0X128 = feeGrowthOutside0X128 + info.feeGrowthOutside1X128 = feeGrowthOutside1X128 + info.liquidityGross = liquidityGross + info.liquidityNet = liquidityNet + info.secondsPerLiquidityOutsideX128 = secondsPerLiquidityOutsideX128 + info.tickCumulativeOutside = tickCumulativeOutside + info.secondsOutside = secondsOutside + info.initialized = initialized + + pool.setTick(tick, info) +} + +func deleteTick(t *testing.T, pool *Pool, tick int32) { + t.Helper() + pool.deleteTick(tick) +} diff --git a/contract/r/gnoswap/pool/type.gno b/contract/r/gnoswap/pool/type.gno new file mode 100644 index 000000000..439277b9d --- /dev/null +++ b/contract/r/gnoswap/pool/type.gno @@ -0,0 +1,463 @@ +package pool + +import ( + "std" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +type Slot0 struct { + sqrtPriceX96 *u256.Uint // current price of the pool as a sqrt(token1/token0) Q96 value + tick int32 // current tick of the pool, i.e according to the last tick transition that was run + feeProtocol uint8 // protocol fee for both tokens of the pool + unlocked bool // whether the pool is currently locked to reentrancy +} + +func newSlot0( + sqrtPriceX96 *u256.Uint, + tick int32, + feeProtocol uint8, + unlocked bool, +) Slot0 { + return Slot0{ + sqrtPriceX96: sqrtPriceX96, + tick: tick, + feeProtocol: feeProtocol, + unlocked: unlocked, + } +} + +type Balances struct { + // current balance of the pool in token0/token1 + token0 *u256.Uint + token1 *u256.Uint +} + +func newBalances() Balances { + return Balances{ + token0: u256.Zero(), + token1: u256.Zero(), + } +} + +type ProtocolFees struct { + // current protocol fees of the pool in token0/token1 + token0 *u256.Uint + token1 *u256.Uint +} + +func newProtocolFees() ProtocolFees { + return ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + } +} + +// ModifyPositionParams repersents the parameters for modifying a liquidity position. +// This structure is used internally both `Mint` and `Burn` operation to manage +// the liquidity positions. +type ModifyPositionParams struct { + // owner is the address that owns the position + owner std.Address + + // tickLower and atickUpper define the price range + // The actual price range is calculated as 1.0001^tick + // This allows for precision in price range while using integer math. + + tickLower int32 // lower tick of the position + tickUpper int32 // upper tick of the position + + // liquidityDelta represents the change in liquidity + // Positive for minting, negative for burning + liquidityDelta *i256.Int +} + +// newModifyPositionParams creates a new `ModifyPositionParams` instance. +// This is used to preare parameters for the `modifyPosition` function, +// which handles both minting and burning of liquidity positions. +// +// Parameters: +// - owner: address that will own (or owns) the position +// - tickLower: lower tick bound of the position +// - tickUpper: upper tick bound of the position +// - liquidityDelta: amount of liquidity to add (positive) or remove (negative) +// +// The tick parameters represent prices as powers of 1.0001: +// - actual_price = 1.0001^tick +// - For example, tick = 100 means price = 1.0001^100 +// +// Returns: +// - ModifyPositionParams: a new instance of ModifyPositionParams +func newModifyPositionParams( + owner std.Address, + tickLower int32, + tickUpper int32, + liquidityDelta *i256.Int, +) ModifyPositionParams { + return ModifyPositionParams{ + owner: owner, + tickLower: tickLower, + tickUpper: tickUpper, + liquidityDelta: liquidityDelta, + } +} + +// SwapCache holds data that remains constant throughout a swap. +type SwapCache struct { + feeProtocol uint8 // protocol fee for the input token + liquidityStart *u256.Uint // liquidity at the beginning of the swap +} + +func newSwapCache( + feeProtocol uint8, + liquidityStart *u256.Uint, +) SwapCache { + return SwapCache{ + feeProtocol: feeProtocol, + liquidityStart: liquidityStart, + } +} + +// SwapState tracks the changing values during a swap. +// This type helps manage the state transiktions that occur as the swap progresses +// accross different price ranges. +type SwapState struct { + amountSpecifiedRemaining *i256.Int // amount remaining to be swapped in/out of the input/output token + amountCalculated *i256.Int // amount already swapped out/in of the output/input token + sqrtPriceX96 *u256.Uint // current sqrt(price) + tick int32 // tick associated with the current sqrt(price) + feeGrowthGlobalX128 *u256.Uint // global fee growth of the input token + protocolFee *u256.Uint // amount of input token paid as protocol fee + liquidity *u256.Uint // current liquidity in range +} + +func newSwapState( + amountSpecifiedRemaining *i256.Int, + feeGrowthGlobalX128 *u256.Uint, + liquidity *u256.Uint, + slot0 Slot0, +) SwapState { + return SwapState{ + amountSpecifiedRemaining: amountSpecifiedRemaining, + amountCalculated: i256.Zero(), + sqrtPriceX96: slot0.sqrtPriceX96, + tick: slot0.tick, + feeGrowthGlobalX128: feeGrowthGlobalX128, + protocolFee: u256.Zero(), + liquidity: liquidity, + } +} + +func (s *SwapState) SetSqrtPriceX96(sqrtPriceX96 string) { + s.sqrtPriceX96 = u256.MustFromDecimal(sqrtPriceX96) +} + +func (s *SwapState) SetTick(tick int32) { + s.tick = tick +} + +func (s *SwapState) SetFeeGrowthGlobalX128(feeGrowthGlobalX128 *u256.Uint) { + s.feeGrowthGlobalX128 = feeGrowthGlobalX128 +} + +func (s *SwapState) SetProtocolFee(fee *u256.Uint) { + s.protocolFee = fee +} + +// StepComputations holds intermediate values used during a single step of a swap. +// Each step represents movement from the current tick to the next initialized tick +// or the target price, whichever comes first. +type StepComputations struct { + sqrtPriceStartX96 *u256.Uint // price at the beginning of the step + tickNext int32 // next tick to swap to from the current tick in the swap direction + initialized bool // whether tickNext is initialized + sqrtPriceNextX96 *u256.Uint // sqrt(price) for the next tick (token1/token0) Q96 + amountIn *u256.Uint // how much being swapped in this step + amountOut *u256.Uint // how much is being swapped out in this step + feeAmount *u256.Uint // how much fee is being paid in this step +} + +// init initializes the computation for a single swap step +func (step *StepComputations) initSwapStep(state SwapState, pool *Pool, zeroForOne bool) { + step.sqrtPriceStartX96 = state.sqrtPriceX96 + step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( + state.tick, + pool.tickSpacing, + zeroForOne, + ) + + // prevent overshoot the min/max tick + step.clampTickNext() + + // get the price for the next tick + step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) +} + +// clampTickNext ensures that `tickNext` stays within the min, max tick boundaries +// as the tick bitmap is not aware of these bounds +func (step *StepComputations) clampTickNext() { + if step.tickNext < consts.MIN_TICK { + step.tickNext = consts.MIN_TICK + } else if step.tickNext > consts.MAX_TICK { + step.tickNext = consts.MAX_TICK + } +} + +type PositionInfo struct { + liquidity *u256.Uint // amount of liquidity owned by this position + + // Fee growth per unit of liquidity as of the last update + // Used to calculate uncollected fees for token0 + feeGrowthInside0LastX128 *u256.Uint + + // Fee growth per unit of liquidity as of the last update + // Used to calculate uncollected fees for token1 + feeGrowthInside1LastX128 *u256.Uint + + // accumulated fees in token0 waiting to be collected + tokensOwed0 *u256.Uint + + // accumulated fees in token1 waiting to be collected + tokensOwed1 *u256.Uint +} + +// valueOrZero initializes nil fields in PositionInfo to zero. +// +// This function ensures that all numeric fields in the PositionInfo struct are not nil. +// If a field is nil, it is replaced with a zero value, maintaining consistency and preventing +// potential null pointer issues during calculations. +// +// Fields affected: +// - liquidity: The liquidity amount associated with the position. +// - feeGrowthInside0LastX128: Fee growth for token 0 inside the tick range, last recorded value. +// - feeGrowthInside1LastX128: Fee growth for token 1 inside the tick range, last recorded value. +// - tokensOwed0: The amount of token 0 owed to the position owner. +// - tokensOwed1: The amount of token 1 owed to the position owner. +// +// Behavior: +// - If a field is nil, it is set to its equivalent zero value. +// - If a field already has a value, it remains unchanged. +// +// Example: +// +// position := &PositionInfo{} +// position.valueOrZero() +// fmt.Println(position.liquidity) // Output: 0 +// +// Notes: +// - This function is useful for ensuring numeric fields are properly initialized +// before performing operations or calculations. +// - Prevents runtime errors caused by nil values. +func (p *PositionInfo) valueOrZero() { + p.liquidity = p.liquidity.NilToZero() + p.feeGrowthInside0LastX128 = p.feeGrowthInside0LastX128.NilToZero() + p.feeGrowthInside1LastX128 = p.feeGrowthInside1LastX128.NilToZero() + p.tokensOwed0 = p.tokensOwed0.NilToZero() + p.tokensOwed1 = p.tokensOwed1.NilToZero() +} + +// TickInfo stores information about a specific tick in the pool. +// TIcks represent discrete price points that can be used as boundaries for positions. +type TickInfo struct { + liquidityGross *u256.Uint // total position liquidity that references this tick + liquidityNet *i256.Int // amount of net liquidity added (subtracted) when tick is crossed from left to right (right to left) + + // fee growth per unit of liquidity on the _other_ side of this tick (relative to the current tick) + // only has relative meaning, not absolute — the value depends on when the tick is initialized + feeGrowthOutside0X128 *u256.Uint + feeGrowthOutside1X128 *u256.Uint + + tickCumulativeOutside int64 // cumulative tick value on the other side of the tick + + // the seconds per unit of liquidity on the _other_ side of this tick (relative to the current tick) + // only has relative meaning, not absolute — the value depends on when the tick is initialized + secondsPerLiquidityOutsideX128 *u256.Uint + + // the seconds spent on the other side of the tick (relative to the current tick) + // only has relative meaning, not absolute — the value depends on when the tick is initialized + secondsOutside uint32 + + initialized bool // whether the tick is initialized +} + +// valueOrZero ensures that all fields of TickInfo are valid by setting nil fields to zero, +// while retaining existing values if they are not nil. +// This function updates the TickInfo struct to replace any nil values in its fields +// with their respective zero values, ensuring data consistency. +// +// Behavior: +// - If a field is nil, it is replaced with its zero value. +// - If a field already has a valid value, the value remains unchanged. +// +// Fields: +// - liquidityGross: Gross liquidity for the tick, set to zero if nil, otherwise retains its value. +// - liquidityNet: Net liquidity for the tick, set to zero if nil, otherwise retains its value. +// - feeGrowthOutside0X128: Accumulated fee growth for token0 outside the tick, set to zero if nil, otherwise retains its value. +// - feeGrowthOutside1X128: Accumulated fee growth for token1 outside the tick, set to zero if nil, otherwise retains its value. +// - secondsPerLiquidityOutsideX128: Time per liquidity outside the tick, set to zero if nil, otherwise retains its value. +// +// Use Case: +// This function ensures all numeric fields in TickInfo are non-nil and have valid values, +// preventing potential runtime errors caused by nil values during operations like arithmetic or comparisons. +func (t *TickInfo) valueOrZero() { + t.liquidityGross = t.liquidityGross.NilToZero() + t.liquidityNet = t.liquidityNet.NilToZero() + t.feeGrowthOutside0X128 = t.feeGrowthOutside0X128.NilToZero() + t.feeGrowthOutside1X128 = t.feeGrowthOutside1X128.NilToZero() + t.secondsPerLiquidityOutsideX128 = t.secondsPerLiquidityOutsideX128.NilToZero() +} + +// type Pool describes a single Pool's state +// A pool is identificed with a unique key (token0, token1, fee), where token0 < token1 +type Pool struct { + // token0/token1 path of the pool + token0Path string + token1Path string + + balances Balances // balances of the pool + + fee uint32 // fee tier of the pool + + tickSpacing int32 // spacing between ticks + + maxLiquidityPerTick *u256.Uint // the maximum amount of liquidity that can be added per tick + + slot0 Slot0 + + feeGrowthGlobal0X128 *u256.Uint // uint256 + feeGrowthGlobal1X128 *u256.Uint // uint256 + + protocolFees ProtocolFees + + liquidity *u256.Uint // total amount of liquidity in the pool + + ticks *avl.Tree // tick(int32) -> TickInfo + + tickBitmaps *avl.Tree // tick(wordPos)(int16) -> bitMap(tickWord ^ mask)(*u256.Uint) + + positions *avl.Tree // maps the key (caller, lower tick, upper tick) to a unique position + + tickCrossHook func(poolPath string, tickId int32, zeroForOne bool) +} + +var tickCrossHook func(poolPath string, tickId int32, zeroForOne bool) + +func SetTickCrossHook(hook func(poolPath string, tickId int32, zeroForOne bool)) { + if tickCrossHook != nil { + panic("tickCrossHook already set") + } + tickCrossHook = hook +} + +func newPool(poolInfo *createPoolParams) *Pool { + maxLiquidityPerTick := calculateMaxLiquidityPerTick(poolInfo.tickSpacing) + tick := common.TickMathGetTickAtSqrtRatio(poolInfo.SqrtPriceX96()) + slot0 := newSlot0(poolInfo.SqrtPriceX96(), tick, slot0FeeProtocol, true) + + return &Pool{ + token0Path: poolInfo.Token0Path(), + token1Path: poolInfo.Token1Path(), + balances: newBalances(), + fee: poolInfo.Fee(), + tickSpacing: poolInfo.TickSpacing(), + maxLiquidityPerTick: maxLiquidityPerTick, + slot0: slot0, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + protocolFees: newProtocolFees(), + liquidity: u256.Zero(), + ticks: avl.NewTree(), + tickBitmaps: avl.NewTree(), + positions: avl.NewTree(), + } +} + +func (p *Pool) PoolPath() string { + return GetPoolPath(p.token0Path, p.token1Path, p.fee) +} + +func (p *Pool) Token0Path() string { + return p.token0Path +} + +func (p *Pool) Token1Path() string { + return p.token1Path +} + +func (p *Pool) Fee() uint32 { + return p.fee +} + +func (p *Pool) BalanceToken0() *u256.Uint { + return p.balances.token0 +} + +func (p *Pool) BalanceToken1() *u256.Uint { + return p.balances.token1 +} + +func (p *Pool) TickSpacing() int32 { + return p.tickSpacing +} + +func (p *Pool) MaxLiquidityPerTick() *u256.Uint { + return p.maxLiquidityPerTick +} + +func (p *Pool) Slot0() Slot0 { + return p.slot0 +} + +func (p *Pool) Slot0SqrtPriceX96() *u256.Uint { + return p.slot0.sqrtPriceX96 +} + +func (p *Pool) Slot0Tick() int32 { + return p.slot0.tick +} + +func (p *Pool) Slot0FeeProtocol() uint8 { + return p.slot0.feeProtocol +} + +func (p *Pool) Slot0Unlocked() bool { + return p.slot0.unlocked +} + +func (p *Pool) FeeGrowthGlobal0X128() *u256.Uint { + return p.feeGrowthGlobal0X128 +} + +func (p *Pool) FeeGrowthGlobal1X128() *u256.Uint { + return p.feeGrowthGlobal1X128 +} + +func (p *Pool) ProtocolFeesToken0() *u256.Uint { + return p.protocolFees.token0 +} + +func (p *Pool) ProtocolFeesToken1() *u256.Uint { + return p.protocolFees.token1 +} + +func (p *Pool) Liquidity() *u256.Uint { + return p.liquidity +} + +func mustGetPool(poolPath string) *Pool { + pool := GetPoolFromPoolPath(poolPath) + if pool == nil { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), + )) + } + return pool +} diff --git a/contract/r/gnoswap/pool/utils.gno b/contract/r/gnoswap/pool/utils.gno new file mode 100644 index 000000000..724657134 --- /dev/null +++ b/contract/r/gnoswap/pool/utils.gno @@ -0,0 +1,273 @@ +package pool + +import ( + "std" + "strconv" + + "gno.land/p/demo/ufmt" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +// safeConvertToUint64 safely converts a *u256.Uint value to a uint64, ensuring no overflow. +// +// This function attempts to convert the given *u256.Uint value to a uint64. If the value exceeds +// the maximum allowable range for uint64 (`2^64 - 1`), it triggers a panic with a descriptive error message. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be converted. +// +// Returns: +// - uint64: The converted value if it falls within the uint64 range. +// +// Panics: +// - If the `value` exceeds the range of uint64, the function will panic with an error indicating +// the overflow and the original value. +// +// Notes: +// - This function uses the `Uint64WithOverflow` method to detect overflow during the conversion. +// - It is essential to validate large values before calling this function to avoid unexpected panics. +// +// Example: +// safeValue := safeConvertToUint64(u256.MustFromDecimal("18446744073709551615")) // Valid conversion +// safeConvertToUint64(u256.MustFromDecimal("18446744073709551616")) // Panics due to overflow +func safeConvertToUint64(value *u256.Uint) uint64 { + res, overflow := value.Uint64WithOverflow() + if overflow { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows uint64 range", + errOutOfRange, value.ToString())) + } + return res +} + +// safeConvertToInt128 safely converts a *u256.Uint value to an *i256.Int, ensuring it does not exceed the int128 range. +// +// This function converts an unsigned 256-bit integer (*u256.Uint) into a signed 256-bit integer (*i256.Int). +// It checks whether the resulting value falls within the valid range of int128 (`-2^127` to `2^127 - 1`). +// If the value exceeds the maximum allowable int128 range, it triggers a panic with a descriptive error message. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be converted. +// +// Returns: +// - *i256.Int: The converted value if it falls within the int128 range. +// +// Panics: +// - If the converted value exceeds the maximum int128 value (`2^127 - 1`), the function will panic with an +// error message indicating the overflow and the original value. +// +// Notes: +// - The function uses `i256.FromUint256` to perform the conversion. +// - The constant `MAX_INT128` is used to define the upper bound of the int128 range (`170141183460469231731687303715884105727`). +// +// Example: +// validInt128 := safeConvertToInt128(u256.MustFromDecimal("170141183460469231731687303715884105727")) // Valid conversion +// safeConvertToInt128(u256.MustFromDecimal("170141183460469231731687303715884105728")) // Panics due to overflow +func safeConvertToInt128(value *u256.Uint) *i256.Int { + liquidityDelta := i256.FromUint256(value) + if liquidityDelta.Gt(i256.MustFromDecimal(consts.MAX_INT128)) { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows int128 range", + errOverFlow, value.ToString())) + } + return liquidityDelta +} + +// toUint128 ensures a *u256.Uint value fits within the uint128 range. +// +// This function validates that the given `value` is properly initialized and checks whether +// it exceeds the maximum value of uint128. If the value exceeds the uint128 range, +// it applies a masking operation to truncate the value to fit within the uint128 limit. +// +// Parameters: +// - value: *u256.Uint, the value to be checked and possibly truncated. +// +// Returns: +// - *u256.Uint: A value guaranteed to fit within the uint128 range. +// +// Notes: +// - The mask ensures that only the lower 128 bits of the value are retained. +// - If the input value is already within the uint128 range, it remains unchanged. +// - MAX_UINT128 is a constant representing `2^128 - 1`. +func toUint128(value *u256.Uint) *u256.Uint { + assertOnlyInitializedUint256(value) + if value.Gt(u256.MustFromDecimal(consts.MAX_UINT128)) { + mask := new(u256.Uint).Lsh(u256.One(), consts.Q128_RESOLUTION) + mask = mask.Sub(mask, u256.One()) + value = value.And(value, mask) + } + return value +} + +// u256Min returns the smaller of two *u256.Uint values. +// +// This function compares two unsigned 256-bit integers and returns the smaller of the two. +// If `num1` is less than `num2`, it returns `num1`; otherwise, it returns `num2`. +// +// Parameters: +// - num1 (*u256.Uint): The first unsigned 256-bit integer. +// - num2 (*u256.Uint): The second unsigned 256-bit integer. +// +// Returns: +// - *u256.Uint: The smaller of `num1` and `num2`. +// +// Notes: +// - This function uses the `Lt` (less than) method of `*u256.Uint` to perform the comparison. +// - The function assumes both input values are non-nil. If nil inputs are possible in the usage context, +// additional validation may be needed. +// +// Example: +// smaller := u256Min(u256.MustFromDecimal("10"), u256.MustFromDecimal("20")) // Returns 10 +// smaller := u256Min(u256.MustFromDecimal("30"), u256.MustFromDecimal("20")) // Returns 20 +func u256Min(num1, num2 *u256.Uint) *u256.Uint { + if num1.Lt(num2) { + return num1 + } + return num2 +} + +// derivePkgAddr derives the Realm address from it's pkgPath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) +} + +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrevAsString returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { + prev := getPrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// checkTransferError checks transfer error. +func checkTransferError(err error) { + if err != nil { + panic(addDetailToError( + errTransferFailed, + err.Error(), + )) + } +} + +// checkOverFlowInt128 checks if the value overflows the int128 range. +func checkOverFlowInt128(value *i256.Int) { + if value.Gt(i256.MustFromDecimal(consts.MAX_INT128)) { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows int128 range", + errOverFlow, value.ToString())) + } +} + +// checkTickSpacing checks if the tick is divisible by the tickSpacing. +func checkTickSpacing(tick, tickSpacing int32) { + if tick%tickSpacing != 0 { + panic(addDetailToError( + errInvalidTickAndTickSpacing, + ufmt.Sprintf("tick(%d) MOD tickSpacing(%d) != 0(%d)", tick, tickSpacing, tick%tickSpacing), + )) + } +} + +// assertOnlyValidAddress panics if the address is invalid. +func assertOnlyValidAddress(addr std.Address) { + if !addr.IsValid() { + panic(addDetailToError( + errInvalidAddress, + ufmt.Sprintf("(%s)", addr), + )) + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertOnlyPositionContract panics if the caller is not the position contract. +func assertOnlyPositionContract() { + caller := getPrevAddr() + if err := common.PositionOnly(caller); err != nil { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only position(%s) can call, called from %s", consts.POSITION_ADDR, caller.String()), + )) + } +} + +// assertOnlyInitializedUint256 panics if the value is nil. +func assertOnlyInitializedUint256(value *u256.Uint) { + if value == nil { + panic(addDetailToError( + errInvalidInput, + "value is nil", + )) + } +} + +// assertOnlyAdmin panics if the caller is not the admin. +func assertOnlyAdmin() { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } +} + +// assertOnlyGovernance panics if the caller is not the governance. +func assertOnlyGovernance() { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err) + } +} + +// assertOnlyRegistered panics if the token is not registered. +func assertOnlyRegistered(tokenPath string) { + common.MustRegistered(tokenPath) +} + +func formatUint(v interface{}) string { + switch v := v.(type) { + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatInt(v interface{}) string { + switch v := v.(type) { + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case int: + return strconv.Itoa(v) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatBool(v bool) string { + return strconv.FormatBool(v) +} diff --git a/contract/r/gnoswap/pool/utils_test.gno b/contract/r/gnoswap/pool/utils_test.gno new file mode 100644 index 000000000..6ef7142a6 --- /dev/null +++ b/contract/r/gnoswap/pool/utils_test.gno @@ -0,0 +1,389 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" +) + +func TestU256Min(t *testing.T) { + tests := []struct { + name string + num1 string + num2 string + expected string + }{ + { + name: "num1 is less than num2", + num1: "1", + num2: "2", + expected: "1", + }, + { + name: "num1 is greater than num2", + num1: "2", + num2: "1", + expected: "1", + }, + { + name: "num1 is equal to num2", + num1: "1", + num2: "1", + expected: "1", + }, + { + name: "compare max u256 with zero", + num1: "115792089237316195423570985008687907853269984665640564039457584007913129639935", + num2: "0", + expected: "0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + num1 := u256.MustFromDecimal(tt.num1) + num2 := u256.MustFromDecimal(tt.num2) + expected := u256.MustFromDecimal(tt.expected) + + uassert.Equal(t, expected.ToString(), u256Min(num1, num2).ToString()) + }) + } +} + +func TestIsUserCall(t *testing.T) { + tests := []struct { + name string + action func() bool + expected bool + }{ + { + name: "called from user", + action: func() bool { + userRealm := std.NewUserRealm(std.Address("user")) + std.TestSetRealm(userRealm) + return isUserCall() + }, + expected: true, + }, + { + name: "called from realm", + action: func() bool { + fromRealm := std.NewCodeRealm("gno.land/r/realm") + std.TestSetRealm(fromRealm) + return isUserCall() + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uassert.Equal(t, tt.expected, tt.action()) + }) + } +} + +func TestGetPrevAsString(t *testing.T) { + tests := []struct { + name string + action func() (string, string) + expectedAddr string + expectedPkgPath string + }{ + { + name: "user call", + action: func() (string, string) { + userRealm := std.NewUserRealm(std.Address("user")) + std.TestSetRealm(userRealm) + return getPrevAsString() + }, + expectedAddr: "user", + expectedPkgPath: "", + }, + { + name: "code call", + action: func() (string, string) { + codeRealm := std.NewCodeRealm("gno.land/r/demo/realm") + std.TestSetRealm(codeRealm) + return getPrevAsString() + }, + expectedAddr: std.DerivePkgAddr("gno.land/r/demo/realm").String(), + expectedPkgPath: "gno.land/r/demo/realm", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, pkgPath := tt.action() + uassert.Equal(t, tt.expectedAddr, addr) + uassert.Equal(t, tt.expectedPkgPath, pkgPath) + }) + } +} + +func TestSafeConvertToUint64(t *testing.T) { + tests := []struct { + name string + value *u256.Uint + wantRes uint64 + wantPanic bool + }{ + {"normal conversion", u256.NewUint(123), 123, false}, + {"overflow", u256.MustFromDecimal(consts.MAX_UINT128), 0, true}, + {"max uint64", u256.NewUint(1<<64 - 1), 1<<64 - 1, false}, + {"zero", u256.NewUint(0), 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } + if tt.wantPanic { + t.Errorf("expected panic, but none occurred") + } + }() + + res := safeConvertToUint64(tt.value) + if res != tt.wantRes { + t.Errorf("safeConvertToUint64() = %v, want %v", res, tt.wantRes) + } + }) + } +} + +func TestSafeConvertToInt128(t *testing.T) { + tests := []struct { + name string + value string + wantRes string + wantPanic bool + }{ + {"normal conversion", "170141183460469231731687303715884105727", "170141183460469231731687303715884105727", false}, + {"overflow", "170141183460469231731687303715884105728", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } + if tt.wantPanic { + t.Errorf("expected panic, but none occurred") + } + }() + + res := safeConvertToInt128(u256.MustFromDecimal(tt.value)) + if res.ToString() != tt.wantRes { + t.Errorf("safeConvertToUint64() = %v, want %v", res, tt.wantRes) + } + }) + } +} + +func TestDerivePkgAddr(t *testing.T) { + var ( + pkgPath = "gno.land/r/gnoswap/v1/position" + ) + tests := []struct { + name string + input string + expected string + }{ + { + name: "Success - derivePkgAddr", + input: pkgPath, + expected: "g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := derivePkgAddr(tc.input) + uassert.Equal(t, got.String(), tc.expected) + }) + } +} + +func TestGetPrevRealm(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prevRealm is User", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevRealm() + uassert.Equal(t, got.Addr().String(), tc.expected[0]) + uassert.Equal(t, got.PkgPath(), tc.expected[1]) + }) + } +} + +func TestGetPrevAddr(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected std.Address + }{ + { + name: "Success - prev Address is User", + originCaller: consts.ADMIN, + expected: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevAddr() + uassert.Equal(t, got.String(), tc.expected.String()) + }) + } +} + +func TestCheckOverFlowInt128(t *testing.T) { + tests := []struct { + name string + input *i256.Int + shouldPanic bool + expected string + }{ + { + name: "Valid value within int128 range", + input: i256.MustFromDecimal("1"), + shouldPanic: false, + }, + { + name: "Edge case - MAX_INT128", + input: i256.MustFromDecimal(consts.MAX_INT128), + shouldPanic: false, + }, + { + name: "Overflow case - exceeds MAX_INT128", + input: i256.MustFromDecimal(consts.MAX_INT256), + shouldPanic: true, + expected: "[GNOSWAP-POOL-026] overflow: amount(170141183460469231731687303715884105728) overflows int128 range", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } else if tt.shouldPanic { + uassert.Equal(t, tt.expected, r) + t.Errorf("Expected panic but none occurred") + } + }() + checkOverFlowInt128(tt.input) + }) + } +} + +func TestCheckTickSpacing(t *testing.T) { + tests := []struct { + name string + tick int32 + tickSpacing int32 + shouldPanic bool + expected string + }{ + { + name: "Valid tick - divisible by tickSpacing", + tick: 120, + tickSpacing: 60, + shouldPanic: false, + }, + { + name: "Valid tick - zero tick", + tick: 0, + tickSpacing: 10, + shouldPanic: false, + }, + { + name: "Invalid tick - not divisible", + tick: 15, + tickSpacing: 10, + shouldPanic: true, + expected: "[GNOSWAP-POOL-022] invalid tick and tick spacing requested || tick(15) MOD tickSpacing(10) != 0(5)", + }, + { + name: "Invalid tick - negative tick", + tick: -35, + tickSpacing: 20, + shouldPanic: true, + expected: "[GNOSWAP-POOL-022] invalid tick and tick spacing requested || tick(-35) MOD tickSpacing(20) != 0(-15)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } else if tt.shouldPanic { + uassert.Equal(t, tt.expected, r) + } + }() + checkTickSpacing(tt.tick, tt.tickSpacing) + }) + } +} + +func TestAssertOnlyValidAddress(t *testing.T) { + tests := []struct { + name string + addr std.Address + expected bool + errorMsg string + }{ + { + name: "Success - valid address", + addr: consts.ADMIN, + expected: true, + }, + { + name: "Failure - invalid address", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", // invalid length + expected: false, + errorMsg: "[GNOSWAP-POOL-023] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddress(tc.addr) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddress(tc.addr) + }) + } + }) + } +} diff --git a/contract/r/gnoswap/position/_helper_test.gno b/contract/r/gnoswap/position/_helper_test.gno new file mode 100644 index 000000000..77db01ae0 --- /dev/null +++ b/contract/r/gnoswap/position/_helper_test.gno @@ -0,0 +1,703 @@ +package position + +import ( + "std" + "testing" + + "gno.land/p/demo/avl" + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + "gno.land/r/onbloc/foo" + "gno.land/r/onbloc/obl" + "gno.land/r/onbloc/qux" + "gno.land/r/onbloc/usdc" + + sr "gno.land/r/gnoswap/v1/staker" +) + +const ( + ugnotDenom string = "ugnot" + ugnotPath string = "ugnot" + wugnotPath string = "gno.land/r/demo/wugnot" + gnsPath string = "gno.land/r/gnoswap/v1/gns" + barPath string = "gno.land/r/onbloc/bar" + bazPath string = "gno.land/r/onbloc/baz" + fooPath string = "gno.land/r/onbloc/foo" + oblPath string = "gno.land/r/onbloc/obl" + quxPath string = "gno.land/r/onbloc/qux" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + fee10000 uint32 = 10000 + maxApprove uint64 = 18446744073709551615 + max_timeout int64 = 9999999999 + maxSqrtPriceLimitX96 string = "1461446703485210103287273052203988822378723970341" + + TIER_1 uint64 = 1 + TIER_2 uint64 = 2 + TIER_3 uint64 = 3 +) + +const ( + // define addresses to use in tests + addr01 = testutils.TestAddress("addr01") + addr02 = testutils.TestAddress("addr02") +) + +var ( + adminAddr = std.Address(consts.ADMIN) + admin = adminAddr + alice = testutils.TestAddress("alice") + bob = testutils.TestAddress("bob") + pool = consts.POOL_ADDR + protocolFee = consts.PROTOCOL_FEE_ADDR + router = consts.ROUTER_ADDR + adminRealm = std.NewUserRealm(admin) + posRealm = std.NewCodeRealm(consts.POSITION_PATH) + rouRealm = std.NewCodeRealm(consts.ROUTER_PATH) + + // addresses used in tests + addrUsedInTest = []std.Address{addr01, addr02} +) + +func InitialisePoolTest(t *testing.T) { + t.Helper() + + ugnotFaucet(t, admin, 100_000_000_000_000) + ugnotDeposit(t, admin, 100_000_000_000_000) + + std.TestSetOrigCaller(admin) + TokenApprove(t, gnsPath, admin, pool, maxApprove) + poolPath := pl.GetPoolPath(wugnotPath, gnsPath, fee3000) + if !pl.DoesPoolPathExist(poolPath) { + pl.CreatePool(wugnotPath, gnsPath, fee3000, "79228162514264337593543950336") + } + + //2. create position + std.TestSetOrigCaller(alice) + TokenFaucet(t, wugnotPath, alice) + TokenFaucet(t, gnsPath, alice) + TokenApprove(t, wugnotPath, alice, pool, uint64(1000)) + TokenApprove(t, gnsPath, alice, pool, uint64(1000)) + MintPosition(t, + wugnotPath, + gnsPath, + fee3000, + int32(1020), + int32(5040), + "1000", + "1000", + "0", + "0", + max_timeout, + alice, + alice, + ) +} + +func TokenFaucet(t *testing.T, tokenPath string, to std.Address) { + t.Helper() + std.TestSetOrigCaller(admin) + defaultAmount := uint64(5_000_000_000) + + switch tokenPath { + case wugnotPath: + wugnotTransfer(t, to, defaultAmount) + case gnsPath: + gnsTransfer(t, to, defaultAmount) + case barPath: + barTransfer(t, to, defaultAmount) + case bazPath: + bazTransfer(t, to, defaultAmount) + case fooPath: + fooTransfer(t, to, defaultAmount) + case oblPath: + oblTransfer(t, to, defaultAmount) + case quxPath: + quxTransfer(t, to, defaultAmount) + default: + panic("token not found") + } +} + +func TokenBalance(t *testing.T, tokenPath string, owner std.Address) uint64 { + t.Helper() + switch tokenPath { + case wugnotPath: + return wugnot.BalanceOf(common.AddrToUser(owner)) + case gnsPath: + return gns.BalanceOf(owner) + case barPath: + return bar.BalanceOf(owner) + case bazPath: + return baz.BalanceOf(owner) + case fooPath: + return foo.BalanceOf(owner) + case oblPath: + return obl.BalanceOf(owner) + case quxPath: + return qux.BalanceOf(owner) + default: + panic("token not found") + } +} + +func TokenAllowance(t *testing.T, tokenPath string, owner, spender std.Address) uint64 { + t.Helper() + switch tokenPath { + case wugnotPath: + return wugnot.Allowance(common.AddrToUser(owner), common.AddrToUser(spender)) + case gnsPath: + return gns.Allowance(owner, spender) + case barPath: + return bar.Allowance(owner, spender) + case bazPath: + return baz.Allowance(owner, spender) + case fooPath: + return foo.Allowance(owner, spender) + case oblPath: + return obl.Allowance(owner, spender) + case quxPath: + return qux.Allowance(owner, spender) + default: + panic("token not found") + } +} + +func TokenApprove(t *testing.T, tokenPath string, owner, spender std.Address, amount uint64) { + t.Helper() + switch tokenPath { + case wugnotPath: + wugnotApprove(t, owner, spender, amount) + case gnsPath: + gnsApprove(t, owner, spender, amount) + case barPath: + barApprove(t, owner, spender, amount) + case bazPath: + bazApprove(t, owner, spender, amount) + case fooPath: + fooApprove(t, owner, spender, amount) + case oblPath: + oblApprove(t, owner, spender, amount) + case quxPath: + quxApprove(t, owner, spender, amount) + default: + panic("token not found") + } +} + +func CreatePool(t *testing.T, + token0 string, + token1 string, + fee uint32, + sqrtPriceX96 string, + caller std.Address) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(caller)) + poolPath := pl.GetPoolPath(token0, token1, fee) + if !pl.DoesPoolPathExist(poolPath) { + pl.CreatePool(token0, token1, fee, sqrtPriceX96) + sr.SetPoolTierByAdmin(poolPath, TIER_1) + } +} + +func MintPosition(t *testing.T, + token0 string, + token1 string, + fee uint32, + tickLower int32, + tickUpper int32, + amount0Desired string, // *u256.Uint + amount1Desired string, // *u256.Uint + amount0Min string, // *u256.Uint + amount1Min string, // *u256.Uint + deadline int64, + mintTo std.Address, + caller std.Address, +) (uint64, string, string, string) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(caller)) + + return Mint( + token0, + token1, + fee, + tickLower, + tickUpper, + amount0Desired, + amount1Desired, + amount0Min, + amount1Min, + deadline, + mintTo, + caller) +} + +func MintPositionAll(t *testing.T, caller std.Address) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(caller)) + TokenApprove(t, gnsPath, caller, pool, maxApprove) + TokenApprove(t, gnsPath, caller, router, maxApprove) + TokenApprove(t, wugnotPath, caller, pool, maxApprove) + TokenApprove(t, wugnotPath, caller, router, maxApprove) + + params := []struct { + tickLower int32 + tickUpper int32 + liquidity uint64 + zeroToOne bool + }{ + { + tickLower: -300, + tickUpper: -240, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -240, + tickUpper: -180, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -180, + tickUpper: -120, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -120, + tickUpper: -60, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -60, + tickUpper: 0, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: 0, + tickUpper: 60, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 60, + tickUpper: 120, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 120, + tickUpper: 180, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 180, + tickUpper: 240, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 240, + tickUpper: 300, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: -360, + tickUpper: -300, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -420, + tickUpper: -360, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -480, + tickUpper: -420, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -540, + tickUpper: -480, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -600, + tickUpper: -540, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: 300, + tickUpper: 360, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 360, + tickUpper: 420, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 420, + tickUpper: 480, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 480, + tickUpper: 540, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 540, + tickUpper: 600, + liquidity: 10, + zeroToOne: false, + }, + } + + for _, p := range params { + MintPosition(t, + wugnotPath, + gnsPath, + fee3000, + p.tickLower, + p.tickUpper, + "100", + "100", + "0", + "0", + max_timeout, + caller, + caller) + } + +} + +func CreatePoolWithoutFee(t *testing.T) { + std.TestSetRealm(adminRealm) + // set pool create fee to 0 for testing + pl.SetPoolCreationFeeByAdmin(0) + CreatePool(t, barPath, fooPath, fee500, common.TickMathGetSqrtRatioAtTick(0).ToString(), admin) +} + +func MakeMintPositionWithoutFee(t *testing.T) (uint64, string, string, string) { + t.Helper() + + // make actual data to test resetting not only position's state but also pool's state + std.TestSetRealm(adminRealm) + + // set pool create fee to 0 for testing + pl.SetPoolCreationFeeByAdmin(0) + CreatePool(t, barPath, fooPath, fee500, common.TickMathGetSqrtRatioAtTick(0).ToString(), admin) + + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, fooPath, admin, pool, consts.UINT64_MAX) + + // mint position + return Mint( + barPath, + fooPath, + fee500, + -887270, + 887270, + "50000", + "50000", + "0", + "0", + max_timeout, + admin, + admin, + ) +} + +func LPTokenApprove(t *testing.T, owner, operator std.Address, tokenId uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + gnft.Approve(operator, tokenIdFrom(tokenId)) +} + +func LPTokenStake(t *testing.T, owner std.Address, tokenId uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + sr.StakeToken(tokenId) +} + +func LPTokenUnStake(t *testing.T, owner std.Address, tokenId uint64, unwrap bool) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + sr.UnstakeToken(tokenId, unwrap) +} + +func getPoolFromLpTokenId(t *testing.T, lpTokenId uint64) *pl.Pool { + t.Helper() + + position := MustGetPosition(lpTokenId) + return pl.GetPoolFromPoolPath(position.poolKey) +} + +func wugnotApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + wugnot.Approve(common.AddrToUser(spender), amount) +} + +func gnsApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + gns.Approve(spender, amount) +} + +func barApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + bar.Approve(spender, amount) +} + +func bazApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + baz.Approve(spender, amount) +} + +func fooApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + foo.Approve(spender, amount) +} + +func oblApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + obl.Approve(spender, amount) +} + +func quxApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + qux.Approve(spender, amount) +} + +func wugnotTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + wugnot.Transfer(common.AddrToUser(to), amount) +} + +func gnsTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + gns.Transfer(to, amount) +} + +func barTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + bar.Transfer(to, amount) +} + +func bazTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + baz.Transfer(to, amount) +} + +func fooTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + foo.Transfer(to, amount) +} + +func oblTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + obl.Transfer(to, amount) +} + +func quxTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + qux.Transfer(to, amount) +} + +// ---------------------------------------------------------------------------- +// ugnot + +func ugnotTransfer(t *testing.T, from, to std.Address, amount uint64) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(from)) + std.TestSetOrigSend(std.Coins{{ugnotDenom, int64(amount)}}, nil) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(from, to, std.Coins{{ugnotDenom, int64(amount)}}) +} + +func ugnotBalanceOf(t *testing.T, addr std.Address) uint64 { + t.Helper() + + banker := std.GetBanker(std.BankerTypeRealmIssue) + coins := banker.GetCoins(addr) + if len(coins) == 0 { + return 0 + } + + return uint64(coins.AmountOf(ugnotDenom)) +} + +func ugnotMint(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.IssueCoin(addr, denom, amount) + std.TestIssueCoins(addr, std.Coins{{denom, int64(amount)}}) +} + +func ugnotBurn(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.RemoveCoin(addr, denom, amount) +} + +func ugnotFaucet(t *testing.T, to std.Address, amount uint64) { + t.Helper() + faucetAddress := admin + std.TestSetOrigCaller(faucetAddress) + + if ugnotBalanceOf(t, faucetAddress) < amount { + newCoins := std.Coins{{ugnotDenom, int64(amount)}} + ugnotMint(t, faucetAddress, newCoins[0].Denom, newCoins[0].Amount) + std.TestSetOrigSend(newCoins, nil) + } + ugnotTransfer(t, faucetAddress, to, amount) +} + +func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(addr)) + wugnotAddr := consts.WUGNOT_ADDR + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(addr, wugnotAddr, std.Coins{{ugnotDenom, int64(amount)}}) + wugnot.Deposit() +} + +// resetObject resets the object state(clear or make it default values) +func resetObject(t *testing.T) { + positions = avl.NewTree() + nextId = 1 +} + +func burnTokens(t *testing.T) { + t.Helper() + + // burn tokens + for _, addr := range addrUsedInTest { + burnFoo(addr) + burnBar(addr) + burnBaz(addr) + burnQux(addr) + burnObl(addr) + burnUsdc(addr) + } +} + +func burnFoo(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + foo.Burn(addr, foo.BalanceOf(addr)) +} + +func burnBar(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + bar.Burn(addr, bar.BalanceOf(addr)) +} + +func burnBaz(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + baz.Burn(addr, baz.BalanceOf(addr)) +} + +func burnQux(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + qux.Burn(addr, qux.BalanceOf(addr)) +} + +func burnObl(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + obl.Burn(addr, obl.BalanceOf(addr)) +} + +func burnUsdc(addr std.Address) { + std.TestSetRealm(std.NewUserRealm(admin)) + usdc.Burn(addr, usdc.BalanceOf(addr)) +} + +// burnAllNFT burns all NFTs +func burnAllNFT(t *testing.T) { + t.Helper() + + std.TestSetRealm(std.NewCodeRealm(consts.POSITION_PATH)) + for i := uint64(1); i <= gnft.TotalSupply(); i++ { + gnft.Burn(tokenIdFrom(i)) + } +} + +func TestBeforeResetPositionObject(t *testing.T) { + t.Skip("only available for single file(current file) test") + // make actual data to test resetting not only position's state but also pool's state + std.TestSetRealm(adminRealm) + + tokenId, liquidity, amount0, amount1 := MakeMintPositionWithoutFee(t) + uassert.Equal(t, tokenId, uint64(1), "tokenId should be 1") + uassert.Equal(t, liquidity, "50000", "liquidity should be 50000") + uassert.Equal(t, amount0, "50000", "amount0 should be 50000") + uassert.Equal(t, amount1, "50000", "amount1 should be 50000") + uassert.Equal(t, positions.Size(), 1, "positions should have 1 position") + uassert.Equal(t, nextId, uint64(2), "nextId should be 2") + uassert.Equal(t, gnft.TotalSupply(), uint64(1), "gnft total supply should be 1") + uassert.Equal(t, pl.PoolGetLiquidity("gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500"), "50000", "pool liquidity should be 50000") +} + +func TestResetObject(t *testing.T) { + t.Skip("only available for single file(current file) test") + resetObject(t) + + uassert.Equal(t, positions.Size(), 0, "positions should be empty") + uassert.Equal(t, nextId, uint64(1), "nextId should be 1") +} + +func TestBurnTokens(t *testing.T) { + t.Skip("only available for single file(current file) test") + burnTokens(t) + + uassert.Equal(t, foo.BalanceOf(addr01), uint64(0)) // 100_000_000 -> 0 + uassert.Equal(t, bar.BalanceOf(addr01), uint64(0)) // 100_000_000 -> 0 +} + +func TestBurnAllNFT(t *testing.T) { + t.Skip("only available for single file(current file) test") + burnAllNFT(t) + uassert.Equal(t, gnft.TotalSupply(), uint64(0), "gnft total supply should be 0") +} diff --git a/contract/r/gnoswap/position/api.gno b/contract/r/gnoswap/position/api.gno new file mode 100644 index 000000000..7e87cf257 --- /dev/null +++ b/contract/r/gnoswap/position/api.gno @@ -0,0 +1,529 @@ +package position + +import ( + "std" + "time" + + "gno.land/p/demo/json" + ufmt "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + i256 "gno.land/p/gnoswap/int256" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/gnft" + pl "gno.land/r/gnoswap/v1/pool" +) + +type RpcPosition struct { + LpTokenId uint64 `json:"lpTokenId"` + Burned bool `json:"burned"` + Owner string `json:"owner"` + Operator string `json:"operator"` + PoolKey string `json:"poolKey"` + TickLower int32 `json:"tickLower"` + TickUpper int32 `json:"tickUpper"` + Liquidity string `json:"liquidity"` + FeeGrowthInside0LastX128 string `json:"feeGrowthInside0LastX128"` + FeeGrowthInside1LastX128 string `json:"feeGrowthInside1LastX128"` + TokensOwed0 string `json:"token0Owed"` + TokensOwed1 string `json:"token1Owed"` + + Token0Balance string `json:"token0Balance"` + Token1Balance string `json:"token1Balance"` + FeeUnclaimed0 string `json:"fee0Unclaimed"` + FeeUnclaimed1 string `json:"fee1Unclaimed"` +} + +type RpcUnclaimedFee struct { + LpTokenId uint64 `json:"lpTokenId"` + Fee0 string `json:"fee0"` + Fee1 string `json:"fee1"` +} + +type ResponseQueryBase struct { + Height int64 `json:"height"` + Timestamp int64 `json:"timestamp"` +} + +type ResponseApiGetPositions struct { + Stat ResponseQueryBase `json:"stat"` + Response []RpcPosition `json:"response"` +} + +func ApiGetPositions() string { + rpcPositions := []RpcPosition{} + + for tokenId := uint64(1); tokenId < nextId; tokenId++ { + rpcPosition := rpcMakePosition(tokenId) + rpcPositions = append(rpcPositions, rpcPosition) + } + + r := ResponseApiGetPositions{ + Stat: ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + }, + Response: rpcPositions, + } + + // STAT NODE + stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, position := range r.Response { + owner, err := gnft.OwnerOf(tokenIdFrom(position.LpTokenId)) + if err != nil { + panic(ufmt.Sprintf("owner not found for tokenId: %d", position.LpTokenId)) + } + + positionNode := json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(position.LpTokenId)), + "burned": json.BoolNode("burned", position.Burned), + "owner": json.StringNode("owner", owner.String()), + "operator": json.StringNode("operator", position.Operator), + "poolKey": json.StringNode("poolKey", position.PoolKey), + "tickLower": json.NumberNode("tickLower", float64(position.TickLower)), + "tickUpper": json.NumberNode("tickUpper", float64(position.TickUpper)), + "liquidity": json.StringNode("liquidity", position.Liquidity), + "feeGrowthInside0LastX128": json.StringNode("feeGrowthInside0LastX128", position.FeeGrowthInside0LastX128), + "feeGrowthInside1LastX128": json.StringNode("feeGrowthInside1LastX128", position.FeeGrowthInside1LastX128), + "token0Owed": json.StringNode("token0Owed", position.TokensOwed0), + "token1Owed": json.StringNode("token1Owed", position.TokensOwed1), + "token0Balance": json.StringNode("token0Balance", position.Token0Balance), + "token1Balance": json.StringNode("token1Balance", position.Token1Balance), + "fee0Unclaimed": json.StringNode("fee0Unclaimed", position.FeeUnclaimed0), + "fee1Unclaimed": json.StringNode("fee1Unclaimed", position.FeeUnclaimed1), + }) + responses.AppendArray(positionNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetPosition(lpTokenId uint64) string { + rpcPositions := []RpcPosition{} + + _, exist := GetPosition(lpTokenId) + if !exist { + return "" + } + + rpcPosition := rpcMakePosition(lpTokenId) + rpcPositions = append(rpcPositions, rpcPosition) + + r := ResponseApiGetPositions{ + Stat: ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + }, + Response: rpcPositions, + } + + // STAT NODE + stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, position := range r.Response { + owner, err := gnft.OwnerOf(tokenIdFrom(position.LpTokenId)) + if err != nil { + owner = consts.ZERO_ADDRESS + } + positionNode := json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(position.LpTokenId)), + "burned": json.BoolNode("burned", position.Burned), + "owner": json.StringNode("owner", owner.String()), + "operator": json.StringNode("operator", position.Operator), + "poolKey": json.StringNode("poolKey", position.PoolKey), + "tickLower": json.NumberNode("tickLower", float64(position.TickLower)), + "tickUpper": json.NumberNode("tickUpper", float64(position.TickUpper)), + "liquidity": json.StringNode("liquidity", position.Liquidity), + "feeGrowthInside0LastX128": json.StringNode("feeGrowthInside0LastX128", position.FeeGrowthInside0LastX128), + "feeGrowthInside1LastX128": json.StringNode("feeGrowthInside1LastX128", position.FeeGrowthInside1LastX128), + "token0Owed": json.StringNode("token0Owed", position.TokensOwed0), + "token1Owed": json.StringNode("token1Owed", position.TokensOwed1), + "token0Balance": json.StringNode("token0Balance", position.Token0Balance), + "token1Balance": json.StringNode("token1Balance", position.Token1Balance), + "fee0Unclaimed": json.StringNode("fee0Unclaimed", position.FeeUnclaimed0), + "fee1Unclaimed": json.StringNode("fee1Unclaimed", position.FeeUnclaimed1), + }) + responses.AppendArray(positionNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetPositionsByPoolPath(poolPath string) string { + rpcPositions := []RpcPosition{} + + for lpTokenId := uint64(1); lpTokenId < nextId; lpTokenId++ { + position := MustGetPosition(lpTokenId) + if position.poolKey != poolPath { + continue + } + + rpcPosition := rpcMakePosition(lpTokenId) + rpcPositions = append(rpcPositions, rpcPosition) + } + + r := ResponseApiGetPositions{ + Stat: ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + }, + Response: rpcPositions, + } + + // STAT NODE + stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, position := range r.Response { + owner, err := gnft.OwnerOf(tokenIdFrom(position.LpTokenId)) + if err != nil { + owner = consts.ZERO_ADDRESS + } + positionNode := json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(position.LpTokenId)), + "burned": json.BoolNode("burned", position.Burned), + "owner": json.StringNode("owner", owner.String()), + "operator": json.StringNode("operator", position.Operator), + "poolKey": json.StringNode("poolKey", position.PoolKey), + "tickLower": json.NumberNode("tickLower", float64(position.TickLower)), + "tickUpper": json.NumberNode("tickUpper", float64(position.TickUpper)), + "liquidity": json.StringNode("liquidity", position.Liquidity), + "feeGrowthInside0LastX128": json.StringNode("feeGrowthInside0LastX128", position.FeeGrowthInside0LastX128), + "feeGrowthInside1LastX128": json.StringNode("feeGrowthInside1LastX128", position.FeeGrowthInside1LastX128), + "token0Owed": json.StringNode("token0Owed", position.TokensOwed0), + "token1Owed": json.StringNode("token1Owed", position.TokensOwed1), + "token0Balance": json.StringNode("token0Balance", position.Token0Balance), + "token1Balance": json.StringNode("token1Balance", position.Token1Balance), + "fee0Unclaimed": json.StringNode("fee0Unclaimed", position.FeeUnclaimed0), + "fee1Unclaimed": json.StringNode("fee1Unclaimed", position.FeeUnclaimed1), + }) + responses.AppendArray(positionNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetPositionsByAddress(address std.Address) string { + rpcPositions := []RpcPosition{} + for lpTokenId := uint64(1); lpTokenId < nextId; lpTokenId++ { + position := MustGetPosition(lpTokenId) + owner, err := gnft.OwnerOf(tokenIdFrom(lpTokenId)) + if err != nil { + panic(ufmt.Sprintf("owner not found for tokenId: %d", lpTokenId)) + } + + if !(position.operator == address || owner == address) { + continue + } + + rpcPosition := rpcMakePosition(lpTokenId) + rpcPositions = append(rpcPositions, rpcPosition) + } + + r := ResponseApiGetPositions{ + Stat: ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + }, + Response: rpcPositions, + } + + // STAT NODE + stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, position := range r.Response { + owner, err := gnft.OwnerOf(tokenIdFrom(position.LpTokenId)) + if err != nil { + owner = consts.ZERO_ADDRESS + } + positionNode := json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(position.LpTokenId)), + "burned": json.BoolNode("burned", position.Burned), + "owner": json.StringNode("owner", owner.String()), + "operator": json.StringNode("operator", position.Operator), + "poolKey": json.StringNode("poolKey", position.PoolKey), + "tickLower": json.NumberNode("tickLower", float64(position.TickLower)), + "tickUpper": json.NumberNode("tickUpper", float64(position.TickUpper)), + "liquidity": json.StringNode("liquidity", position.Liquidity), + "feeGrowthInside0LastX128": json.StringNode("feeGrowthInside0LastX128", position.FeeGrowthInside0LastX128), + "feeGrowthInside1LastX128": json.StringNode("feeGrowthInside1LastX128", position.FeeGrowthInside1LastX128), + "token0Owed": json.StringNode("token0Owed", position.TokensOwed0), + "token1Owed": json.StringNode("token1Owed", position.TokensOwed1), + "token0Balance": json.StringNode("token0Balance", position.Token0Balance), + "token1Balance": json.StringNode("token1Balance", position.Token1Balance), + "fee0Unclaimed": json.StringNode("fee0Unclaimed", position.FeeUnclaimed0), + "fee1Unclaimed": json.StringNode("fee1Unclaimed", position.FeeUnclaimed1), + }) + responses.AppendArray(positionNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetPositionsUnclaimedFee() string { + rpcUnclaimedFee := []RpcUnclaimedFee{} + for lpTokenId := uint64(1); lpTokenId < nextId; lpTokenId++ { + unclaimedFee0, unclaimedFee1 := unclaimedFee(lpTokenId) + rpcUnclaimedFee = append(rpcUnclaimedFee, RpcUnclaimedFee{ + LpTokenId: lpTokenId, + Fee0: unclaimedFee0.ToString(), + Fee1: unclaimedFee1.ToString(), + }) + } + + // STAT NODE + stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, unclaimedFee := range rpcUnclaimedFee { + unclaimedFeeNode := json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(unclaimedFee.LpTokenId)), + "fee0": json.StringNode("fee0", unclaimedFee.Fee0), + "fee1": json.StringNode("fee1", unclaimedFee.Fee1), + }) + responses.AppendArray(unclaimedFeeNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetPositionUnclaimedFeeByLpTokenId(lpTokenId uint64) string { + rpcUnclaimedFee := []RpcUnclaimedFee{} + + unclaimedFee0, unclaimedFee1 := unclaimedFee(lpTokenId) + rpcUnclaimedFee = append(rpcUnclaimedFee, RpcUnclaimedFee{ + LpTokenId: lpTokenId, + Fee0: unclaimedFee0.ToString(), + Fee1: unclaimedFee1.ToString(), + }) + + // STAT NODE + stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, unclaimedFee := range rpcUnclaimedFee { + unclaimedFeeNode := json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(unclaimedFee.LpTokenId)), + "fee0": json.StringNode("fee0", unclaimedFee.Fee0), + "fee1": json.StringNode("fee1", unclaimedFee.Fee1), + }) + responses.AppendArray(unclaimedFeeNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func rpcMakePosition(lpTokenId uint64) RpcPosition { + position := MustGetPosition(lpTokenId) + + burned := isBurned(lpTokenId) + + pool := pl.GetPoolFromPoolPath(position.poolKey) + currentX96 := pool.Slot0SqrtPriceX96() + lowerX96 := common.TickMathGetSqrtRatioAtTick(position.tickLower) + upperX96 := common.TickMathGetSqrtRatioAtTick(position.tickUpper) + + token0Balance, token1Balance := common.GetAmountsForLiquidity( + currentX96, + lowerX96, + upperX96, + position.liquidity, + ) + + unclaimedFee0 := i256.Zero() + unclaimedFee1 := i256.Zero() + if !burned { + unclaimedFee0, unclaimedFee1 = unclaimedFee(lpTokenId) + } + + owner, err := gnft.OwnerOf(tokenIdFrom(lpTokenId)) + if err != nil { + owner = consts.ZERO_ADDRESS + } + + return RpcPosition{ + LpTokenId: lpTokenId, + Burned: burned, + Owner: owner.String(), + Operator: position.operator.String(), + PoolKey: position.poolKey, + TickLower: position.tickLower, + TickUpper: position.tickUpper, + Liquidity: position.liquidity.ToString(), + FeeGrowthInside0LastX128: position.feeGrowthInside0LastX128.ToString(), + FeeGrowthInside1LastX128: position.feeGrowthInside1LastX128.ToString(), + TokensOwed0: position.tokensOwed0.ToString(), + TokensOwed1: position.tokensOwed1.ToString(), + Token0Balance: token0Balance, + Token1Balance: token1Balance, + FeeUnclaimed0: unclaimedFee0.ToString(), + FeeUnclaimed1: unclaimedFee1.ToString(), + } +} + +func unclaimedFee(tokenId uint64) (*i256.Int, *i256.Int) { + // ref: https://blog.uniswap.org/uniswap-v3-math-primer-2#calculating-uncollected-fees + + position := MustGetPosition(tokenId) + + liquidityU256 := position.liquidity + liquidity := i256.FromUint256(liquidityU256) + + tickLower := position.tickLower + tickUpper := position.tickUpper + + poolKey := position.poolKey + pool := pl.GetPoolFromPoolPath(poolKey) + + currentTick := pool.Slot0Tick() + + feeGrowthGlobal0X128 := i256.FromUint256(pool.FeeGrowthGlobal0X128()) + feeGrowthGlobal1X128 := i256.FromUint256(pool.FeeGrowthGlobal1X128()) + + tickUpperFeeGrowthOutside0X128 := i256.FromUint256(pool.GetTickFeeGrowthOutside0X128(tickUpper)) + tickUpperFeeGrowthOutside1X128 := i256.FromUint256(pool.GetTickFeeGrowthOutside1X128(tickUpper)) + + tickLowerFeeGrowthOutside0X128 := i256.FromUint256(pool.GetTickFeeGrowthOutside0X128(tickLower)) + tickLowerFeeGrowthOutside1X128 := i256.FromUint256(pool.GetTickFeeGrowthOutside1X128(tickLower)) + + feeGrowthInside0LastX128 := i256.FromUint256(position.feeGrowthInside0LastX128) + feeGrowthInside1LastX128 := i256.FromUint256(position.feeGrowthInside1LastX128) + + var tickLowerFeeGrowthBelow0, tickLowerFeeGrowthBelow1, tickUpperFeeGrowthAbove0, tickUpperFeeGrowthAbove1 *i256.Int + + if currentTick >= tickUpper { + tickUpperFeeGrowthAbove0 = subIn256(feeGrowthGlobal0X128, tickUpperFeeGrowthOutside0X128) + tickUpperFeeGrowthAbove1 = subIn256(feeGrowthGlobal1X128, tickUpperFeeGrowthOutside1X128) + } else { + tickUpperFeeGrowthAbove0 = tickUpperFeeGrowthOutside0X128 + tickUpperFeeGrowthAbove1 = tickUpperFeeGrowthOutside1X128 + } + + if currentTick >= tickLower { + tickLowerFeeGrowthBelow0 = tickLowerFeeGrowthOutside0X128 + tickLowerFeeGrowthBelow1 = tickLowerFeeGrowthOutside1X128 + } else { + tickLowerFeeGrowthBelow0 = subIn256(feeGrowthGlobal0X128, tickLowerFeeGrowthOutside0X128) + tickLowerFeeGrowthBelow1 = subIn256(feeGrowthGlobal1X128, tickLowerFeeGrowthOutside1X128) + } + + feeGrowthInside0X128 := subIn256(feeGrowthGlobal0X128, tickLowerFeeGrowthBelow0) + feeGrowthInside0X128 = subIn256(feeGrowthInside0X128, tickUpperFeeGrowthAbove0) + + feeGrowthInside1X128 := subIn256(feeGrowthGlobal1X128, tickLowerFeeGrowthBelow1) + feeGrowthInside1X128 = subIn256(feeGrowthInside1X128, tickUpperFeeGrowthAbove1) + + value01 := subIn256(feeGrowthInside0X128, feeGrowthInside0LastX128) + value02 := i256.Zero().Mul(liquidity, value01) + unclaimedFee0 := i256.Zero().Div(value02, i256.MustFromDecimal(consts.Q128)) + + value11 := subIn256(feeGrowthInside1X128, feeGrowthInside1LastX128) + value12 := i256.Zero().Mul(liquidity, value11) + unclaimedFee1 := i256.Zero().Div(value12, i256.MustFromDecimal(consts.Q128)) + + return unclaimedFee0, unclaimedFee1 +} + +func subIn256(x, y *i256.Int) *i256.Int { + value := i256.Zero() + diff := value.Sub(x, y) + + if diff.IsNeg() { + q256 := i256.MustFromDecimal(consts.MAX_UINT256) + return diff.Add(diff, q256) + } + + return diff +} + +func isBurned(tokenId uint64) bool { + position := MustGetPosition(tokenId) + return position.burned +} diff --git a/contract/r/gnoswap/position/api_test.gno b/contract/r/gnoswap/position/api_test.gno new file mode 100644 index 000000000..fa15de44c --- /dev/null +++ b/contract/r/gnoswap/position/api_test.gno @@ -0,0 +1,455 @@ +package position + +import ( + "std" + "testing" + + "gno.land/p/demo/json" + "gno.land/p/demo/uassert" +) + +func setupPositions(t *testing.T) { + t.Helper() + + MakeMintPositionWithoutFee(t) +} + +func TestApiGetPositions(t *testing.T) { + setupPositions(t) + std.TestSetRealm(std.NewUserRealm(admin)) + + result := ApiGetPositions() + + root, err := json.Unmarshal([]byte(result)) + if err != nil { + panic(err.Error()) + } + + response, err := root.GetKey("response") + if err != nil { + panic(err.Error()) + } + uassert.NoError(t, err) + uassert.Equal(t, 1, response.Size()) + + response, err = response.GetIndex(0) + if err != nil { + panic(err.Error()) + } + + lpTokenId, err := response.GetKey("lpTokenId") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "1", lpTokenId.String()) + + burned, err := response.GetKey("burned") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "false", burned.String()) + + owner, err := response.GetKey("owner") + if err != nil { + panic(err.Error()) + } + ownerAddr := "\"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d\"" + uassert.Equal(t, ownerAddr, owner.String()) + + operator, err := response.GetKey("operator") + if err != nil { + panic(err.Error()) + } + operatorAddr := "\"g100000000000000000000000000000000dnmcnx\"" + uassert.Equal(t, operatorAddr, operator.String()) + + poolKey, err := response.GetKey("poolKey") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500\"", poolKey.String()) + + tickLower, err := response.GetKey("tickLower") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "-887270", tickLower.String()) + + tickUpper, err := response.GetKey("tickUpper") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "887270", tickUpper.String()) + + liquidity, err := response.GetKey("liquidity") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"50000\"", liquidity.String()) + + feeGrowthInside0LastX128, err := response.GetKey("feeGrowthInside0LastX128") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", feeGrowthInside0LastX128.String()) + + feeGrowthInside1LastX128, err := response.GetKey("feeGrowthInside1LastX128") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", feeGrowthInside1LastX128.String()) + + token0Owed, err := response.GetKey("token0Owed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", token0Owed.String()) + + token1Owed, err := response.GetKey("token1Owed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", token1Owed.String()) + + token0Balance, err := response.GetKey("token0Balance") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"49999\"", token0Balance.String()) + + token1Balance, err := response.GetKey("token1Balance") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"49999\"", token1Balance.String()) + + fee0Unclaimed, err := response.GetKey("fee0Unclaimed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee0Unclaimed.String()) + + fee1Unclaimed, err := response.GetKey("fee1Unclaimed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee1Unclaimed.String()) +} + +func TestApiGetPositionsByPoolPath(t *testing.T) { + result := ApiGetPositionsByPoolPath("gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500") + + root, err := json.Unmarshal([]byte(result)) + if err != nil { + panic(err.Error()) + } + + response, err := root.GetKey("response") + if err != nil { + panic(err.Error()) + } + uassert.NoError(t, err) + uassert.Equal(t, 1, response.Size()) + + response, err = response.GetIndex(0) + if err != nil { + panic(err.Error()) + } + + lpTokenId, err := response.GetKey("lpTokenId") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "1", lpTokenId.String()) + + burned, err := response.GetKey("burned") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "false", burned.String()) + + owner, err := response.GetKey("owner") + if err != nil { + panic(err.Error()) + } + ownerAddr := "\"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d\"" + uassert.Equal(t, ownerAddr, owner.String()) + + operator, err := response.GetKey("operator") + if err != nil { + panic(err.Error()) + } + operatorAddr := "\"g100000000000000000000000000000000dnmcnx\"" + uassert.Equal(t, operatorAddr, operator.String()) + + poolKey, err := response.GetKey("poolKey") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500\"", poolKey.String()) + + tickLower, err := response.GetKey("tickLower") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "-887270", tickLower.String()) + + tickUpper, err := response.GetKey("tickUpper") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "887270", tickUpper.String()) + + liquidity, err := response.GetKey("liquidity") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"50000\"", liquidity.String()) + + feeGrowthInside0LastX128, err := response.GetKey("feeGrowthInside0LastX128") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", feeGrowthInside0LastX128.String()) + + feeGrowthInside1LastX128, err := response.GetKey("feeGrowthInside1LastX128") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", feeGrowthInside1LastX128.String()) + + token0Owed, err := response.GetKey("token0Owed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", token0Owed.String()) + + token1Owed, err := response.GetKey("token1Owed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", token1Owed.String()) + + token0Balance, err := response.GetKey("token0Balance") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"49999\"", token0Balance.String()) + + token1Balance, err := response.GetKey("token1Balance") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"49999\"", token1Balance.String()) + + fee0Unclaimed, err := response.GetKey("fee0Unclaimed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee0Unclaimed.String()) + + fee1Unclaimed, err := response.GetKey("fee1Unclaimed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee1Unclaimed.String()) +} + +func TestApiGetPositionsByAddress(t *testing.T) { + address := admin + result := ApiGetPositionsByAddress(address) + + root, err := json.Unmarshal([]byte(result)) + if err != nil { + panic(err.Error()) + } + + response, err := root.GetKey("response") + if err != nil { + panic(err.Error()) + } + uassert.NoError(t, err) + uassert.Equal(t, 1, response.Size()) + + response, err = response.GetIndex(0) + if err != nil { + panic(err.Error()) + } + + lpTokenId, err := response.GetKey("lpTokenId") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "1", lpTokenId.String()) + + burned, err := response.GetKey("burned") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "false", burned.String()) + + owner, err := response.GetKey("owner") + if err != nil { + panic(err.Error()) + } + ownerAddr := "\"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d\"" + uassert.Equal(t, ownerAddr, owner.String()) + + operator, err := response.GetKey("operator") + if err != nil { + panic(err.Error()) + } + operatorAddr := "\"g100000000000000000000000000000000dnmcnx\"" + uassert.Equal(t, operatorAddr, operator.String()) + + poolKey, err := response.GetKey("poolKey") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500\"", poolKey.String()) + + tickLower, err := response.GetKey("tickLower") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "-887270", tickLower.String()) + + tickUpper, err := response.GetKey("tickUpper") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "887270", tickUpper.String()) + + liquidity, err := response.GetKey("liquidity") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"50000\"", liquidity.String()) + + feeGrowthInside0LastX128, err := response.GetKey("feeGrowthInside0LastX128") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", feeGrowthInside0LastX128.String()) + + feeGrowthInside1LastX128, err := response.GetKey("feeGrowthInside1LastX128") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", feeGrowthInside1LastX128.String()) + + token0Owed, err := response.GetKey("token0Owed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", token0Owed.String()) + + token1Owed, err := response.GetKey("token1Owed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", token1Owed.String()) + + token0Balance, err := response.GetKey("token0Balance") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"49999\"", token0Balance.String()) + + token1Balance, err := response.GetKey("token1Balance") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"49999\"", token1Balance.String()) + + fee0Unclaimed, err := response.GetKey("fee0Unclaimed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee0Unclaimed.String()) + + fee1Unclaimed, err := response.GetKey("fee1Unclaimed") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee1Unclaimed.String()) +} + +func TestApiGetPositionsUnclaimedFee(t *testing.T) { + result := ApiGetPositionsUnclaimedFee() + + root, err := json.Unmarshal([]byte(result)) + if err != nil { + panic(err.Error()) + } + + response, err := root.GetKey("response") + if err != nil { + panic(err.Error()) + } + uassert.NoError(t, err) + uassert.Equal(t, 1, response.Size()) + + response, err = response.GetIndex(0) + if err != nil { + panic(err.Error()) + } + + lpTokenId, err := response.GetKey("lpTokenId") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "1", lpTokenId.String()) + + fee0Unclaimed, err := response.GetKey("fee0") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee0Unclaimed.String()) + + fee1Unclaimed, err := response.GetKey("fee1") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee1Unclaimed.String()) +} + +func TestApiGetPositionUnclaimedFeeByLpTokenId(t *testing.T) { + setupPositions(t) + result := ApiGetPositionUnclaimedFeeByLpTokenId(1) + + root, err := json.Unmarshal([]byte(result)) + if err != nil { + panic(err.Error()) + } + + response, err := root.GetKey("response") + if err != nil { + panic(err.Error()) + } + uassert.NoError(t, err) + uassert.Equal(t, 1, response.Size()) + + response, err = response.GetIndex(0) + if err != nil { + panic(err.Error()) + } + + lpTokenId, err := response.GetKey("lpTokenId") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "1", lpTokenId.String()) + + fee0Unclaimed, err := response.GetKey("fee0") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee0Unclaimed.String()) + + fee1Unclaimed, err := response.GetKey("fee1") + if err != nil { + panic(err.Error()) + } + uassert.Equal(t, "\"0\"", fee1Unclaimed.String()) +} diff --git a/contract/r/gnoswap/position/doc.gno b/contract/r/gnoswap/position/doc.gno new file mode 100644 index 000000000..4073013b7 --- /dev/null +++ b/contract/r/gnoswap/position/doc.gno @@ -0,0 +1,2 @@ +// Package position manages user-specific assets and liquidity, allowing users to define their asset boundaries and manage their exposure. +package position diff --git a/contract/r/gnoswap/position/errors.gno b/contract/r/gnoswap/position/errors.gno new file mode 100644 index 000000000..4a2b8c086 --- /dev/null +++ b/contract/r/gnoswap/position/errors.gno @@ -0,0 +1,41 @@ +package position + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-POSITION-001] caller has no permission") + errSlippage = errors.New("[GNOSWAP-POSITION-002] slippage failed") + errWrapUnwrap = errors.New("[GNOSWAP-POSITION-003] wrap, unwrap failed") + errOutOfRange = errors.New("[GNOSWAP-POSITION-004] out of range for numeric value") + errInvalidInput = errors.New("[GNOSWAP-POSITION-005] invalid input data") + errDataNotFound = errors.New("[GNOSWAP-POSITION-006] requested data not found") + errExpired = errors.New("[GNOSWAP-POSITION-007] transaction expired") + errWugnotMinimum = errors.New("[GNOSWAP-POSITION-008] can not wrap less than minimum amount") + errNotClear = errors.New("[GNOSWAP-POSITION-009] position is not clear") + errZeroLiquidity = errors.New("[GNOSWAP-POSITION-010] zero liquidity") + errPositionExist = errors.New("[GNOSWAP-POSITION-011] position with same tokenId already exists") + errInvalidAddress = errors.New("[GNOSWAP-POSITION-012] invalid address") + errPositionDoesNotExist = errors.New("[GNOSWAP-POSITION-013] position does not exist") + errZeroUGNOT = errors.New("[GNOSWAP-POSITION-014] No UGNOTs were sent") + errInsufficientUGNOT = errors.New("[GNOSWAP-POSITION-015] Insufficient UGNOT provided") + errInvalidTokenPath = errors.New("[GNOSWAP-POSITION-016] invalid token address") + errInvalidLiquidityRatio = errors.New("[GNOSWAP-POSITION-017] invalid liquidity ratio") + errUnderflow = errors.New("[GNOSWAP-POSITION-018] underflow") + errInvalidLiquidity = errors.New("[GNOSWAP-POSITION-019] invalid liquidity") +) + +// newErrorWithDetail appends additional context or details to an existing error message. +// +// Parameters: +// - err: The original error (error). +// - detail: Additional context or detail to append to the error message (string). +// +// Returns: +// - string: The combined error message in the format " || ". +func newErrorWithDetail(err error, detail string) string { + return ufmt.Errorf("%s || %s", err.Error(), detail).Error() +} diff --git a/contract/r/gnoswap/position/getter.gno b/contract/r/gnoswap/position/getter.gno new file mode 100644 index 000000000..c4964e7a3 --- /dev/null +++ b/contract/r/gnoswap/position/getter.gno @@ -0,0 +1,118 @@ +package position + +import ( + "std" + + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/gnft" + pl "gno.land/r/gnoswap/v1/pool" +) + +func PositionGetPosition(tokenId uint64) Position { + position, _ := GetPosition(tokenId) + return position +} + +func PositionGetPositionNonce(tokenId uint64) *u256.Uint { + position := MustGetPosition(tokenId) + return position.nonce +} + +func PositionGetPositionOperator(tokenId uint64) std.Address { + position := MustGetPosition(tokenId) + return position.operator +} + +func PositionGetPositionPoolKey(tokenId uint64) string { + position := MustGetPosition(tokenId) + return position.poolKey +} + +func PositionGetPositionTickLower(tokenId uint64) int32 { + position := MustGetPosition(tokenId) + return position.tickLower +} + +func PositionGetPositionTickUpper(tokenId uint64) int32 { + position := MustGetPosition(tokenId) + return position.tickUpper +} + +func PositionGetPositionLiquidity(tokenId uint64) *u256.Uint { + position := MustGetPosition(tokenId) + return position.liquidity +} + +func PositionGetPositionFeeGrowthInside0LastX128(tokenId uint64) *u256.Uint { + position := MustGetPosition(tokenId) + return position.feeGrowthInside0LastX128 +} + +func PositionGetPositionFeeGrowthInside1LastX128(tokenId uint64) *u256.Uint { + position := MustGetPosition(tokenId) + return position.feeGrowthInside1LastX128 +} + +func PositionGetPositionTokensOwed0(tokenId uint64) *u256.Uint { + position := MustGetPosition(tokenId) + return position.tokensOwed0 +} + +func PositionGetPositionTokensOwed1(tokenId uint64) *u256.Uint { + position := MustGetPosition(tokenId) + return position.tokensOwed1 +} + +func PositionGetPositionIsBurned(tokenId uint64) bool { + position := MustGetPosition(tokenId) + return position.burned +} + +func PositionIsInRange(tokenId uint64) bool { + position := MustGetPosition(tokenId) + poolPath := position.poolKey + poolCurrentTick := pl.PoolGetSlot0Tick(poolPath) + + if position.tickLower <= poolCurrentTick && poolCurrentTick <= position.tickUpper { + return true + } + return false +} + +func PositionGetPositionOwner(tokenId uint64) std.Address { + owner, err := gnft.OwnerOf(tokenIdFrom(tokenId)) + if err != nil { + panic(newErrorWithDetail( + errDataNotFound, err.Error())) + } + return owner +} + +func PositionGetPositionNonceStr(tokenId uint64) string { + return MustGetPosition(tokenId).nonce.ToString() +} + +func PositionGetPositionOperatorStr(tokenId uint64) string { + return MustGetPosition(tokenId).operator.String() +} + +func PositionGetPositionLiquidityStr(tokenId uint64) string { + return MustGetPosition(tokenId).liquidity.ToString() +} + +func PositionGetPositionFeeGrowthInside0LastX128Str(tokenId uint64) string { + return MustGetPosition(tokenId).feeGrowthInside0LastX128.ToString() +} + +func PositionGetPositionFeeGrowthInside1LastX128Str(tokenId uint64) string { + return MustGetPosition(tokenId).feeGrowthInside1LastX128.ToString() +} + +func PositionGetPositionTokensOwed0Str(tokenId uint64) string { + return MustGetPosition(tokenId).tokensOwed0.ToString() +} + +func PositionGetPositionTokensOwed1Str(tokenId uint64) string { + return MustGetPosition(tokenId).tokensOwed1.ToString() +} diff --git a/contract/r/gnoswap/position/getter_test.gno b/contract/r/gnoswap/position/getter_test.gno new file mode 100644 index 000000000..9d3eb7936 --- /dev/null +++ b/contract/r/gnoswap/position/getter_test.gno @@ -0,0 +1,180 @@ +package position + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" + u256 "gno.land/p/gnoswap/uint256" +) + +func setupPositionGetter(t *testing.T) { + t.Helper() + + CreatePoolWithoutFee(t) + std.TestSetRealm(std.NewUserRealm(admin)) + position := Position{ + nonce: u256.Zero(), + operator: consts.POSITION_ADDR, + poolKey: "gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500", + tickLower: -10000, + tickUpper: 10000, + liquidity: u256.NewUint(1000000), + feeGrowthInside0LastX128: u256.Zero(), + feeGrowthInside1LastX128: u256.Zero(), + tokensOwed0: u256.Zero(), + tokensOwed1: u256.Zero(), + burned: false, + } + tokenId := getNextId() + setPosition(tokenId, position) +} + +func TestPositionGetPosition(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + position := PositionGetPosition(tokenId) + uassert.Equal(t, "gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500", position.poolKey) + uassert.Equal(t, int32(-10000), position.tickLower) + uassert.Equal(t, int32(10000), position.tickUpper) + uassert.Equal(t, "1000000", position.liquidity.ToString()) + uassert.Equal(t, "0", position.feeGrowthInside0LastX128.ToString()) + uassert.Equal(t, "0", position.feeGrowthInside1LastX128.ToString()) + uassert.Equal(t, "0", position.tokensOwed0.ToString()) + uassert.Equal(t, "0", position.tokensOwed1.ToString()) + uassert.Equal(t, false, position.burned) +} + +func TestPositionGetPositionNonce(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + nonce := PositionGetPositionNonce(tokenId) + uassert.Equal(t, "0", nonce.ToString()) +} + +func TestPositionGetPositionOperator(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + operator := PositionGetPositionOperator(tokenId) + uassert.Equal(t, consts.POSITION_ADDR, operator) +} + +func TestPositionGetPositionPoolKey(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + poolKey := PositionGetPositionPoolKey(tokenId) + uassert.Equal(t, "gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500", poolKey) +} + +func TestPositionGetPositionTickLower(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + tickLower := PositionGetPositionTickLower(tokenId) + uassert.Equal(t, int32(-10000), tickLower) +} + +func TestPositionGetPositionTickUpper(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + tickUpper := PositionGetPositionTickUpper(tokenId) + uassert.Equal(t, int32(10000), tickUpper) +} + +func TestPositionGetPositionLiquidity(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + liquidity := PositionGetPositionLiquidity(tokenId) + uassert.Equal(t, "1000000", liquidity.ToString()) +} + +func TestPositionGetPositionFeeGrowthInside0LastX128(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + feeGrowth0 := PositionGetPositionFeeGrowthInside0LastX128(tokenId) + uassert.Equal(t, "0", feeGrowth0.ToString()) +} + +func TestPositionGetPositionFeeGrowthInside1LastX128(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + feeGrowth1 := PositionGetPositionFeeGrowthInside1LastX128(tokenId) + uassert.Equal(t, "0", feeGrowth1.ToString()) +} + +func TestPositionGetPositionTokensOwed0(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + tokenOwed0 := PositionGetPositionTokensOwed0(tokenId) + uassert.Equal(t, "0", tokenOwed0.ToString()) +} + +func TestPositionGetPositionTokensOwed1(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + tokenOwed1 := PositionGetPositionTokensOwed1(tokenId) + uassert.Equal(t, "0", tokenOwed1.ToString()) +} + +func TestPositionGetPositionIsBurned(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + isBurn := PositionGetPositionIsBurned(tokenId) + uassert.Equal(t, false, isBurn) +} + +func TestPositionIsInRange(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + inRange := PositionIsInRange(tokenId) + uassert.Equal(t, true, inRange) +} + +func TestPositionGetPositionNonceStr(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + nonceStr := PositionGetPositionNonceStr(tokenId) + uassert.Equal(t, "0", nonceStr) +} + +func TestPositionGetPositionOperatorStr(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + operator := PositionGetPositionOperatorStr(tokenId) + uassert.Equal(t, string(consts.POSITION_ADDR), operator) +} + +func TestPositionGetPositionLiquidityStr(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + liquidity := PositionGetPositionLiquidityStr(tokenId) + uassert.Equal(t, "1000000", liquidity) +} + +func TestPositionGetPositionFeeGrowthInside0LastX128Str(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + feeGrowth0 := PositionGetPositionFeeGrowthInside0LastX128Str(tokenId) + uassert.Equal(t, "0", feeGrowth0) +} + +func TestPositionGetPositionFeeGrowthInside1LastX128Str(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + feeGrowth1 := PositionGetPositionFeeGrowthInside1LastX128Str(tokenId) + uassert.Equal(t, "0", feeGrowth1) +} + +func TestPositionGetPositionTokensOwed0Str(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + tokenOwed0 := PositionGetPositionTokensOwed0Str(tokenId) + uassert.Equal(t, "0", tokenOwed0) +} + +func TestPositionGetPositionTokensOwed1Str(t *testing.T) { + setupPositionGetter(t) + tokenId := getNextId() + tokenOwed1 := PositionGetPositionTokensOwed1Str(tokenId) + uassert.Equal(t, "0", tokenOwed1) +} diff --git a/contract/r/gnoswap/position/gno.mod b/contract/r/gnoswap/position/gno.mod new file mode 100644 index 000000000..7278afb2f --- /dev/null +++ b/contract/r/gnoswap/position/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/position diff --git a/contract/r/gnoswap/position/liquidity_management.gno b/contract/r/gnoswap/position/liquidity_management.gno new file mode 100644 index 000000000..6e902dd8a --- /dev/null +++ b/contract/r/gnoswap/position/liquidity_management.gno @@ -0,0 +1,118 @@ +package position + +import ( + "std" + + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + pl "gno.land/r/gnoswap/v1/pool" +) + +type AddLiquidityParams struct { + poolKey string // poolPath of the pool which has the position + tickLower int32 // lower end of the tick range for the position + tickUpper int32 // upper end of the tick range for the position + amount0Desired *u256.Uint // desired amount of token0 to be minted + amount1Desired *u256.Uint // desired amount of token1 to be minted + amount0Min *u256.Uint // minimum amount of token0 to be minted + amount1Min *u256.Uint // minimum amount of token1 to be minted + caller std.Address // address to call the function +} + +// addLiquidity calculates the liquidity to be added to a pool and mints the corresponding tokens. +// +// This function interacts with the specified pool to add liquidity for a given price range, specified by +// `tickLower` and `tickUpper`, and desired token amounts. It ensures that the resulting token amounts meet +// minimum thresholds to prevent excessive slippage. +// +// Parameters: +// - params (AddLiquidityParams): Contains the following fields: +// - poolKey: The unique identifier for the pool (string). +// - tickLower: The lower tick boundary of the liquidity range (int32). +// - tickUpper: The upper tick boundary of the liquidity range (int32). +// - amount0Desired: The desired amount of token0 to provide as liquidity (*u256.Uint). +// - amount1Desired: The desired amount of token1 to provide as liquidity (*u256.Uint). +// - amount0Min: The minimum acceptable amount of token0 to prevent slippage (*u256.Uint). +// - amount1Min: The minimum acceptable amount of token1 to prevent slippage (*u256.Uint). +// - caller: The address of the entity adding liquidity (std.Address). +// +// Returns: +// - *u256.Uint: The calculated liquidity amount to be added. +// - *u256.Uint: The actual amount of token0 used. +// - *u256.Uint: The actual amount of token1 used. +// +// Behavior: +// 1. Retrieves the pool information and current square root price (`sqrtPriceX96`). +// 2. Calculates the square root ratios (`sqrtRatioAX96` and `sqrtRatioBX96`) for the given tick boundaries. +// 3. Computes the liquidity to be added based on desired token amounts and the square root ratios. +// 4. Mints liquidity tokens to the position contract and determines the actual amounts of token0 and token1 used. +// 5. Ensures the actual token amounts used meet or exceed the specified minimum thresholds (`amount0Min` and `amount1Min`). +// - If the conditions are not met, the function panics with a slippage error. +// +// Panics: +// - If the actual token amounts used do not meet the minimum thresholds, a slippage error is raised. +// +// Notes: +// - The function relies on the `GetLiquidityForAmounts` function to calculate liquidity based on token amounts +// and price ratios. +// - Ensures the pool interactions use the caller-provided parameters to add liquidity safely. +// +// Example: +// +// liquidity, usedAmount0, usedAmount1 := addLiquidity(AddLiquidityParams{ +// poolKey: "gno.land/r/demo/pool", +// tickLower: -60000, +// tickUpper: 60000, +// amount0Desired: u256.MustFromDecimal("1000000000"), +// amount1Desired: u256.MustFromDecimal("2000000000"), +// amount0Min: u256.MustFromDecimal("950000000"), +// amount1Min: u256.MustFromDecimal("1900000000"), +// caller: userAddress, +// }) +func addLiquidity(params AddLiquidityParams) (*u256.Uint, *u256.Uint, *u256.Uint) { + pool := pl.GetPoolFromPoolPath(params.poolKey) + + sqrtPriceX96 := new(u256.Uint).Set(pool.Slot0SqrtPriceX96()) + sqrtRatioAX96 := common.TickMathGetSqrtRatioAtTick(params.tickLower) + sqrtRatioBX96 := common.TickMathGetSqrtRatioAtTick(params.tickUpper) + + liquidity := common.GetLiquidityForAmounts( + sqrtPriceX96, + sqrtRatioAX96, + sqrtRatioBX96, + params.amount0Desired, + params.amount1Desired, + ) + + token0, token1, fee := splitOf(params.poolKey) + amount0Str, amount1Str := pl.Mint( + token0, + token1, + fee, + consts.POSITION_ADDR, + params.tickLower, + params.tickUpper, + liquidity.ToString(), + params.caller, + ) + + amount0 := u256.MustFromDecimal(amount0Str) + amount1 := u256.MustFromDecimal(amount1Str) + + amount0Cond := amount0.Gte(params.amount0Min) + amount1Cond := amount1.Gte(params.amount1Min) + + if !(amount0Cond && amount1Cond) { + panic(newErrorWithDetail( + errSlippage, + ufmt.Sprintf( + "Price Slippage Check(amount0(%s) >= amount0Min(%s), amount1(%s) >= amount1Min(%s))", + amount0Str, params.amount0Min.ToString(), amount1Str, params.amount1Min.ToString()), + )) + } + + return liquidity, amount0, amount1 +} diff --git a/contract/r/gnoswap/position/liquidity_management_test.gno b/contract/r/gnoswap/position/liquidity_management_test.gno new file mode 100644 index 000000000..955988b9a --- /dev/null +++ b/contract/r/gnoswap/position/liquidity_management_test.gno @@ -0,0 +1,108 @@ +package position + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" + u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/common" + pl "gno.land/r/gnoswap/v1/pool" +) + +func InitialSetup(t *testing.T) { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + CreatePool(t, gnsPath, barPath, fee3000, common.TickMathGetSqrtRatioAtTick(0).ToString(), admin) + TokenFaucet(t, gnsPath, alice) + TokenFaucet(t, barPath, alice) +} + +func TestAddLiquidity(t *testing.T) { + poolKey := computePoolPath(gnsPath, barPath, fee3000) + + tests := []struct { + name string + params AddLiquidityParams + expectPanic bool + expectedAmount0 string + expectedAmount1 string + }{ + { + name: "Successful Liquidity Addition", + params: AddLiquidityParams{ + poolKey: poolKey, + tickLower: -600, + tickUpper: 600, + amount0Desired: u256.MustFromDecimal("1000000"), + amount1Desired: u256.MustFromDecimal("2000000"), + amount0Min: u256.MustFromDecimal("400000"), + amount1Min: u256.MustFromDecimal("800000"), + caller: alice, + }, + expectPanic: false, + expectedAmount0: "1000000", + expectedAmount1: "1000000", + }, + { + name: "Slippage Panic", + params: AddLiquidityParams{ + poolKey: poolKey, + tickLower: -600, + tickUpper: 600, + amount0Desired: u256.MustFromDecimal("1000000"), + amount1Desired: u256.MustFromDecimal("2000000"), + amount0Min: u256.MustFromDecimal("1100000"), + amount1Min: u256.MustFromDecimal("2200000"), + caller: alice, + }, + expectPanic: true, + }, + { + name: "Zero Liquidity", + params: AddLiquidityParams{ + poolKey: poolKey, + tickLower: -100, + tickUpper: 100, + amount0Desired: u256.Zero(), + amount1Desired: u256.Zero(), + amount0Min: u256.Zero(), + amount1Min: u256.Zero(), + caller: alice, + }, + expectPanic: true, + expectedAmount0: "0", + expectedAmount1: "0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.expectPanic { + t.Errorf("unexpected panic: %v", r) + } + } else { + if tt.expectPanic { + t.Errorf("expected panic but did not occur") + } + } + }() + + InitialSetup(t) + std.TestSetRealm(std.NewUserRealm(alice)) + TokenApprove(t, gnsPath, alice, consts.POOL_ADDR, tt.params.amount0Desired.Uint64()) + TokenApprove(t, barPath, alice, consts.POOL_ADDR, tt.params.amount1Desired.Uint64()) + + _, amount0, amount1 := addLiquidity(tt.params) + + if !tt.expectPanic { + uassert.Equal(t, tt.expectedAmount0, amount0.ToString()) + uassert.Equal(t, tt.expectedAmount1, amount1.ToString()) + } + }) + } +} diff --git a/contract/r/gnoswap/position/native_token.gno b/contract/r/gnoswap/position/native_token.gno new file mode 100644 index 000000000..a70118b79 --- /dev/null +++ b/contract/r/gnoswap/position/native_token.gno @@ -0,0 +1,172 @@ +package position + +import ( + "std" + "strconv" + + "gno.land/r/demo/wugnot" + + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +// wrap wraps the specified amount of the native token `ugnot` into the wrapped token `wugnot`. +// +// Parameters: +// - ugnotAmount (uint64): The amount of `ugnot` tokens to wrap into `wugnot`. +// - to (std.Address): The recipient's address to receive the wrapped tokens. +// +// Returns: +// - error: An error if the `ugnot` amount is zero, below the minimum wrapping threshold, or any other issue occurs. +// +// Example: +// +// wrap(1000, userAddress) +// - Wraps 1000 UGNOT into WUGNOT and transfers the WUGNOT to `userAddress`. +// +// Errors: +// - Returns an error if `ugnotAmount` is zero or less than the minimum deposit threshold. +func wrap(ugnotAmount uint64, to std.Address) error { + if ugnotAmount == 0 || ugnotAmount < consts.UGNOT_MIN_DEPOSIT_TO_WRAP { + return ufmt.Errorf("amount(%d) < minimum(%d)", ugnotAmount, consts.UGNOT_MIN_DEPOSIT_TO_WRAP) + } + + wugnotAddr := std.DerivePkgAddr(consts.WRAPPED_WUGNOT) + transferUGNOT(consts.POSITION_ADDR, wugnotAddr, ugnotAmount) + + wugnot.Deposit() // POSITION HAS WUGNOT + wugnot.Transfer(common.AddrToUser(to), ugnotAmount) // SEND WUGNOT: POSITION -> USER + + return nil +} + +// unwrap converts a specified amount of `WUGNOT` tokens into `UGNOT` tokens +// and transfers the resulting `UGNOT` back to the specified recipient address. +// +// Parameters: +// - `wugnotAmount`: The amount of `WUGNOT` tokens to unwrap (uint64). +// - `to`: The recipient's address (std.Address) to receive the unwrapped `UGNOT`. +// +// Example: +// unwrap(100, userAddress) +// - Converts 100 WUGNOT into UGNOT and sends the resulting UGNOT to `userAddress`. +func unwrap(wugnotAmount uint64, to std.Address) error { + if wugnotAmount == 0 { + return ufmt.Errorf("amount(%d) is zero", wugnotAmount) + } + + wugnot.TransferFrom(common.AddrToUser(to), common.AddrToUser(consts.POSITION_ADDR), wugnotAmount) // SEND WUGNOT: USER -> POSITION + wugnot.Withdraw(wugnotAmount) // POSITION HAS UGNOT + transferUGNOT(consts.POSITION_ADDR, to, wugnotAmount) // SEND UGNOT: POSITION -> USER + return nil +} + +// transferUGNOT transfers a specified amount of `UGNOT` tokens from one address to another. +// The function ensures that no transaction occurs if the transfer amount is zero. +// It uses the `std.BankerTypeRealmSend` banker type to facilitate the transfer. +// +// Parameters: +// - `from`: The sender's address (std.Address). +// - `to`: The recipient's address (std.Address). +// - `amount`: The amount of UGNOT tokens to transfer (uint64). +// +// Example: +// transferUGNOT(sender, receiver, 100) // Transfers 100 UGNOT from `sender` to `receiver`. +func transferUGNOT(from, to std.Address, amount uint64) { + if amount == 0 { + return + } + + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(from, to, std.Coins{ + {Denom: consts.UGNOT, Amount: int64(amount)}, + }) +} + +// refundUGNOT refunds a specified amount of `UGNOT` tokens to the provided address. +// This function uses `transferUGNOT` to perform the transfer from the contract's position address +// (`POSITION_ADDR`) to the recipient. +// +// Parameters: +// - `to`: The recipient's address (std.Address) who will receive the refund. +// - `amount`: The amount of `UGNOT` tokens to refund (uint64). +func refundUGNOT(to std.Address, amount uint64) { + transferUGNOT(consts.POSITION_ADDR, to, amount) +} + +// isNative checks whether the given token is a native token. +func isNative(token string) bool { + return token == consts.GNOT +} + +// isWrappedToken checks whether the tokenPath is wrapped token +func isWrappedToken(tokenPath string) bool { + return tokenPath == consts.WRAPPED_WUGNOT +} + +// safeWrapNativeToken safely wraps the native token `ugnot` into the wrapped token `wugnot` for a user. +// +// Parameters: +// - amountDesired: The desired amount of `ugnot` to be wrapped, provided as a string. +// - userAddress: The address of the user initiating the wrapping process. +// +// Returns: +// - uint64: The amount of `ugnot` that was successfully wrapped into `wugnot`. +// +// Panics: +// - If the sent `ugnot` amount is zero. +// - If `amountDesired` cannot be parsed into a valid uint64 value. +// - If the sent `ugnot` amount is less than `amountDesired`. +// - If the `wrap` function fails to wrap the tokens. +// - If there is a mismatch between the expected wrapped token amount and the user's balance after wrapping. +func safeWrapNativeToken(amountDesired string, userAddress std.Address) uint64 { + beforeWugnotBalance := wugnot.BalanceOf(common.AddrToUser(userAddress)) + sentNative := std.GetOrigSend() + sentUgnotAmount := uint64(sentNative.AmountOf(consts.UGNOT)) + + if sentUgnotAmount <= 0 { + panic(newErrorWithDetail(errZeroUGNOT, "amount of ugnot is zero")) + } + + amount, err := strconv.ParseUint(amountDesired, 10, 64) + if err != nil { + panic(newErrorWithDetail(errWrapUnwrap, err.Error())) + } + + if sentUgnotAmount < amount { + panic(newErrorWithDetail(errInsufficientUGNOT, "amount of ugnot is less than desired amount")) + } + + if sentUgnotAmount > amount { + exceed := sentUgnotAmount - amount + refundUGNOT(userAddress, exceed) + + sentUgnotAmount = amount + } + + if err = wrap(sentUgnotAmount, userAddress); err != nil { + panic(newErrorWithDetail(errWugnotMinimum, err.Error())) + } + + afterWugnotBalance := wugnot.BalanceOf(common.AddrToUser(userAddress)) + diff := afterWugnotBalance - beforeWugnotBalance + + if diff != sentUgnotAmount { + panic(newErrorWithDetail( + errWrapUnwrap, + ufmt.Sprintf("amount of ugnot (%d) is not equal to amount of wugnot. (diff: %d)", sentUgnotAmount, diff), + )) + } + return sentUgnotAmount +} + +func handleUnwrap(pToken0, pToken1 string, unwrapResult bool, userOldWugnotBalance uint64, to std.Address) { + if (pToken0 == consts.WRAPPED_WUGNOT || pToken1 == consts.WRAPPED_WUGNOT) && unwrapResult { + userNewWugnotBalance := wugnot.BalanceOf(common.AddrToUser(to)) + leftOver := userNewWugnotBalance - userOldWugnotBalance + if leftOver > 0 { + unwrap(leftOver, to) + } + } +} diff --git a/contract/r/gnoswap/position/native_token_test.gno b/contract/r/gnoswap/position/native_token_test.gno new file mode 100644 index 000000000..badd2df77 --- /dev/null +++ b/contract/r/gnoswap/position/native_token_test.gno @@ -0,0 +1,429 @@ +package position + +import ( + "std" + "strconv" + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" +) + +func TestTransferUGNOT(t *testing.T) { + tests := []struct { + name string + action func(t *testing.T, from, to std.Address) + verify func(t *testing.T, to std.Address) uint64 + from std.Address + to std.Address + expected string + shouldPanic bool + }{ + { + name: "Success - Zero amount", + action: func(t *testing.T, from, to std.Address) { + transferUGNOT(from, to, 0) + }, + verify: func(t *testing.T, to std.Address) uint64 { + return ugnotBalanceOf(t, to) + }, + from: alice, + to: bob, + expected: "0", + shouldPanic: false, + }, + { + name: "Success - Valid transfer", + action: func(t *testing.T, from, to std.Address) { + ugnotFaucet(t, from, 100) + std.TestSetRealm(std.NewUserRealm(from)) + transferUGNOT(from, to, 100) + }, + verify: func(t *testing.T, to std.Address) uint64 { + return ugnotBalanceOf(t, to) + }, + from: consts.POSITION_ADDR, + to: bob, + expected: "100", + shouldPanic: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + tc.action(t, tc.from, tc.to) + if tc.verify != nil { + balance := tc.verify(t, tc.to) + uassert.Equal(t, tc.expected, strconv.FormatUint(balance, 10)) + } + } else { + tc.action(t, tc.from, tc.to) + } + }) + } +} + +func TestWrapInPosition(t *testing.T) { + tests := []struct { + name string + action func(t *testing.T, from, to std.Address) error + verify func(t *testing.T, to std.Address) uint64 + from std.Address + to std.Address + expected string + shouldPanic bool + }{ + { + name: "Failure - Amount less than minimum", + action: func(t *testing.T, from, to std.Address) error { + return wrap(999, to) + }, + verify: nil, + from: alice, + to: bob, + expected: "amount(999) < minimum(1000)", + shouldPanic: true, + }, + { + name: "Failure - Zero amount", + action: func(t *testing.T, from, to std.Address) error { + return wrap(0, to) + }, + verify: nil, + from: alice, + to: bob, + expected: "amount(0) < minimum(1000)", + shouldPanic: true, + }, + { + name: "Success - Valid amount", + action: func(t *testing.T, from, to std.Address) error { + ugnotFaucet(t, from, 1000) + std.TestSetRealm(std.NewUserRealm(from)) + return wrap(1000, to) + }, + verify: func(t *testing.T, to std.Address) uint64 { + return TokenBalance(t, wugnotPath, to) + }, + from: consts.POSITION_ADDR, + to: bob, + expected: "1000", + shouldPanic: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r != nil { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + err := tc.action(t, tc.from, tc.to) + if err == nil && tc.verify != nil { + balance := tc.verify(t, tc.to) + uassert.Equal(t, tc.expected, strconv.FormatUint(balance, 10)) + } + } else { + err := tc.action(t, tc.from, tc.to) + if err != nil { + uassert.Equal(t, tc.expected, err.Error()) + } else { + t.Errorf(">>> %s: expected panic but got none", tc.name) + } + } + }) + } +} + +func TestUnWrap(t *testing.T) { + tests := []struct { + name string + action func(t *testing.T, from, to std.Address) error + verify func(t *testing.T, to std.Address) uint64 + from std.Address + to std.Address + expected string + shouldPanic bool + }{ + { + name: "Failure - Zero amount", + action: func(t *testing.T, from, to std.Address) error { + return unwrap(0, to) + }, + verify: nil, + from: alice, + to: bob, + expected: "amount(0) is zero", + shouldPanic: true, + }, + { + name: "Success - Valid amount", + action: func(t *testing.T, from, to std.Address) error { + ugnotFaucet(t, from, 1000) + std.TestSetRealm(std.NewUserRealm(from)) + wrap(1000, to) + std.TestSetRealm(std.NewUserRealm(to)) + TokenApprove(t, wugnotPath, to, from, 1000) + return unwrap(1000, to) + }, + verify: func(t *testing.T, to std.Address) uint64 { + return TokenBalance(t, wugnotPath, to) + }, + from: consts.POSITION_ADDR, + to: bob, + expected: "1000", + shouldPanic: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r != nil { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + err := tc.action(t, tc.from, tc.to) + if err == nil && tc.verify != nil { + balance := tc.verify(t, tc.to) + uassert.Equal(t, tc.expected, strconv.FormatUint(balance, 10)) + } + } else { + err := tc.action(t, tc.from, tc.to) + if err != nil { + uassert.Equal(t, tc.expected, err.Error()) + } else { + t.Errorf(">>> %s: expected panic but got none", tc.name) + } + } + }) + } +} + +func TestIsNative(t *testing.T) { + tests := []struct { + name string + token string + expected bool + }{ + { + name: "Native Token - GNOT", + token: "gnot", + expected: true, + }, + { + name: "Non-Native Token", + token: "usdt", + expected: false, + }, + { + name: "Empty Token", + token: "", + expected: false, + }, + { + name: "Similar but Different Token", + token: "GNOT", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isNative(tt.token) + uassert.Equal(t, tt.expected, result, "Unexpected result for token: "+tt.token) + }) + } +} + +func TestIsWrappedToken(t *testing.T) { + tests := []struct { + name string + tokenPath string + expected bool + }{ + { + name: "Success - Token is Wrapped WUGNOT", + tokenPath: consts.WRAPPED_WUGNOT, + expected: true, + }, + { + name: "Fail - Token is not Wrapped WUGNOT", + tokenPath: "gno.land/r/demo/ugnot", + expected: false, + }, + { + name: "Fail - Empty tokenPath", + tokenPath: "", + expected: false, + }, + { + name: "Fail - Similar but Different Token Path", + tokenPath: "gno.land/r/demo/Wugnot", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isWrappedToken(tt.tokenPath) + if result != tt.expected { + t.Errorf( + "expected %s but got %s", + strconv.FormatBool(tt.expected), + strconv.FormatBool(result), + ) + } + }) + } +} + +func TestSafeWrapNativeToken(t *testing.T) { + tests := []struct { + name string + amountDesired string + userAddress std.Address + sentAmount uint64 + expectPanic bool + expectedWrap uint64 + }{ + { + name: "Panic - Zero UGNOT", + amountDesired: "50", + userAddress: alice, + sentAmount: 0, + expectPanic: true, + }, + { + name: "Panic - Insufficient UGNOT", + amountDesired: "150", + userAddress: alice, + sentAmount: 100, + expectPanic: true, + }, + { + name: "Panic - Invalid Desired Amount", + amountDesired: "invalid", + userAddress: alice, + sentAmount: 200, + expectPanic: true, + }, + { + name: "Successful wrap - Exact Amount", + amountDesired: "1050", + userAddress: alice, + sentAmount: 1050, + expectPanic: false, + expectedWrap: 1050, + }, + { + name: "Excess Refund", + amountDesired: "1000", + userAddress: alice, + sentAmount: 1500, + expectPanic: false, + expectedWrap: 1000, + }, + { + name: "Boundary Test - Exact Match", + amountDesired: "1000", + userAddress: alice, + sentAmount: 1000, + expectPanic: false, + expectedWrap: 1000, + }, + { + name: "Zero Desired Amount", + amountDesired: "0", + userAddress: alice, + sentAmount: 100, + expectPanic: true, + }, + { + name: "Wrap Error Test", + amountDesired: "100", + userAddress: alice, + sentAmount: 100, + expectPanic: true, // Simulate wrap error internally + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.expectPanic { + t.Errorf("unexpected panic: %v", r) + } + } else { + if tt.expectPanic { + t.Errorf("expected panic but did not occur") + } + } + }() + + amount, _ := strconv.ParseUint(tt.amountDesired, 10, 64) + ugnotFaucet(t, consts.POSITION_ADDR, amount) + std.TestSetRealm(std.NewUserRealm(consts.POSITION_ADDR)) + transferUGNOT(consts.POSITION_ADDR, consts.POSITION_ADDR, amount) + + // Perform wrapping + wrappedAmount := safeWrapNativeToken(tt.amountDesired, tt.userAddress) + + // Verify wrapped amount + if !tt.expectPanic { + uassert.Equal(t, tt.expectedWrap, wrappedAmount) + } + }) + } +} diff --git a/contract/r/gnoswap/position/position.gno b/contract/r/gnoswap/position/position.gno new file mode 100644 index 000000000..f70c429ea --- /dev/null +++ b/contract/r/gnoswap/position/position.gno @@ -0,0 +1,1192 @@ +package position + +import ( + "encoding/base64" + "std" + "strconv" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gnft" + + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" + "gno.land/p/gnoswap/consts" + + en "gno.land/r/gnoswap/v1/emission" + pl "gno.land/r/gnoswap/v1/pool" +) + +const ( + ZERO_LIQUIDITY_FOR_FEE_COLLECTION = "0" +) + +var ( + positions = avl.NewTree() // tokenId[uint64] -> Position + nextId = uint64(1) +) + +// MustGetPosition returns a position for a given tokenId +// panics if position doesn't exist +func MustGetPosition(tokenId uint64) Position { + position, exist := GetPosition(tokenId) + if !exist { + panic(newErrorWithDetail( + errPositionDoesNotExist, + ufmt.Sprintf("position with tokenId(%d) doesn't exist", tokenId), + )) + } + return position +} + +// mustUpdatePosition updates a position for a given tokenId +func mustUpdatePosition(tokenId uint64, position Position) { + update := setPosition(tokenId, position) + if !update { + panic(newErrorWithDetail( + errPositionDoesNotExist, + ufmt.Sprintf("position with tokenId(%d) doesn't exist", tokenId), + )) + } +} + +// GetPosition returns a position for a given tokenId +// Returns false if position doesn't exist +func GetPosition(tokenId uint64) (Position, bool) { + tokenIdStr := strconv.FormatUint(tokenId, 10) + iPosition, exist := positions.Get(tokenIdStr) + if !exist { + return Position{}, false + } + + return iPosition.(Position), true +} + +// setPosition sets a position for a given tokenId +// Returns true if position is newly created, false if position already exists and just updated. +func setPosition(tokenId uint64, position Position) bool { + tokenIdStr := strconv.FormatUint(tokenId, 10) + return positions.Set(tokenIdStr, position) +} + +// removePosition removes a position for a given tokenId +func removePosition(tokenId uint64) { + tokenIdStr := strconv.FormatUint(tokenId, 10) + positions.Remove(tokenIdStr) +} + +func existPosition(tokenId uint64) bool { + _, exist := GetPosition(tokenId) + return exist +} + +// computePositionKey generates a unique base64-encoded key for a liquidity position. +// +// This function takes an owner's address and the lower and upper tick bounds of a position, +// and generates a unique key by concatenating the parameters into a string. The resulting +// string is base64 encoded to ensure it is compact and unique. +// +// Parameters: +// - owner (std.Address): The address of the position owner. +// - tickLower (int32): The lower tick boundary of the position. +// - tickUpper (int32): The upper tick boundary of the position. +// +// Returns: +// - string: A base64-encoded string representing the unique key for the position. +// +// Notes: +// - This function is useful in scenarios where unique identifiers for liquidity positions +// are required (e.g., decentralized exchange positions). +// - The key format follows the pattern "ownerAddress__tickLower__tickUpper" to ensure uniqueness. +func computePositionKey( + owner std.Address, + tickLower int32, + tickUpper int32, +) string { + key := ufmt.Sprintf("%s__%d__%d", owner.String(), tickLower, tickUpper) + encoded := base64.StdEncoding.EncodeToString([]byte(key)) + return encoded +} + +// nextId is the next tokenId to be minted +func getNextId() uint64 { + return nextId +} + +// incrementNextId increments the next tokenId to be minted +func incrementNextId() { + nextId++ +} + +// Mint creates a new liquidity position by depositing token pairs into the pool and minting a new LP token. +// +// Parameters: +// - token0: The address of token0. +// - token1: The address of token1. +// - fee: The fee tier of the pool, in basis points. +// - tickLower: The lower tick boundary of the position. +// - tickUpper: The upper tick boundary of the position. +// - amount0Desired: Desired amount of token0 to add as liquidity, as a string. +// - amount1Desired: Desired amount of token1 to add as liquidity, as a string. +// - amount0Min: Minimum acceptable amount of token0 to add as liquidity, as a string. +// - amount1Min: Minimum acceptable amount of token1 to add as liquidity, as a string. +// - deadline: Expiration timestamp for the transaction. +// - mintTo: Address to receive the minted LP token. +// - caller: The address of the entity (contract or user) providing liquidity; assets will be withdrawn from this address. +// +// Returns: +// - uint64: The ID of the newly minted liquidity position. +// - string: The amount of liquidity provided to the position. +// - string: The amount of token0 used in the mint. +// - string: The amount of token1 used in the mint. +// +// Behavior: +// 1. **Validation**: +// - Ensures the contract is not halted. +// - Validates that the caller is either a user or a staker contract. +// - If the caller is a user, validates the `mintTo` and `caller` addresses to ensure they match. +// - Checks the transaction's deadline to prevent expired transactions. +// 2. **Pre-Mint Setup**: +// - Calls `MintAndDistributeGns` to handle GNS emissions. +// - Processes the input parameters for minting (`processMintInput`) to standardize and validate the inputs. +// 3. **Mint Execution**: +// - Executes the mint operation using the processed parameters. +// - Withdraws the required token amounts (`token0` and `token1`) from the `caller` address. +// - Mints a new LP token, and the resulting LP token is sent to the `mintTo` address. +// 4. **Post-Mint Cleanup**: +// - If native tokens were used (e.g., `ugnot`), unwraps any leftover wrapped tokens (`wugnot`) and refunds them to the `caller` address. +// 5. **Event Emission**: +// - Emits a "Mint" event containing detailed information about the mint operation. +// +// Panics: +// - If the contract is halted. +// - If the caller is not authorized. +// - If the transaction deadline has passed. +// - If input validation fails. +// - If errors occur during the minting process or leftover token unwrapping. +// +// ref: https://docs.gnoswap.io/contracts/position/position.gno#mint +func Mint( + token0 string, + token1 string, + fee uint32, + tickLower int32, + tickUpper int32, + amount0Desired string, + amount1Desired string, + amount0Min string, + amount1Min string, + deadline int64, + mintTo std.Address, + caller std.Address, +) (uint64, string, string, string) { + assertOnlyNotHalted() + prevCaller := getPrevRealm() + assertOnlyUserOrStaker(prevCaller) + if isUserCall() { + // if user called, validate the prev address and input addresses(mintTo, caller) + assertOnlyValidAddressWith(prevCaller.Addr(), mintTo) + assertOnlyValidAddressWith(prevCaller.Addr(), caller) + } + checkDeadline(deadline) + + en.MintAndDistributeGns() + + mintInput := MintInput{ + token0: token0, + token1: token1, + fee: fee, + tickLower: tickLower, + tickUpper: tickUpper, + amount0Desired: amount0Desired, + amount1Desired: amount1Desired, + amount0Min: amount0Min, + amount1Min: amount1Min, + deadline: deadline, + mintTo: mintTo, + caller: caller, + } + + processedInput, err := processMintInput(mintInput) + if err != nil { + panic(newErrorWithDetail(errInvalidInput, err.Error())) + } + + mintParams := newMintParams(processedInput, mintInput) + tokenId, liquidity, amount0, amount1 := mint(mintParams) + if processedInput.tokenPair.token0IsNative && processedInput.tokenPair.wrappedAmount > amount0.Uint64() { + // unwrap leftover wugnot + err = unwrap(processedInput.tokenPair.wrappedAmount-amount0.Uint64(), caller) + if err != nil { + panic(newErrorWithDetail(errWrapUnwrap, err.Error())) + } + } + if processedInput.tokenPair.token1IsNative && processedInput.tokenPair.wrappedAmount > amount1.Uint64() { + // unwrap leftover wugnot + err = unwrap(processedInput.tokenPair.wrappedAmount-amount1.Uint64(), caller) + if err != nil { + panic(newErrorWithDetail(errWrapUnwrap, err.Error())) + } + } + poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(processedInput.poolPath) + + prevAddr, prevPkgPath := getPrevAsString() + std.Emit( + "Mint", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "tickLower", formatInt(int64(processedInput.tickLower)), + "tickUpper", formatInt(int64(processedInput.tickUpper)), + "poolPath", processedInput.poolPath, + "mintTo", mintTo.String(), + "caller", caller.String(), + "lpTokenId", formatUint(tokenId), + "liquidity", liquidity.ToString(), + "amount0", amount0.ToString(), + "amount1", amount1.ToString(), + "sqrtPriceX96", poolSqrtPriceX96, + ) + + return tokenId, liquidity.ToString(), amount0.ToString(), amount1.ToString() +} + +// processMintInput processes and validates user input for minting liquidity. +// +// This function standardizes and verifies minting parameters by ensuring token order, parsing desired amounts, +// and handling native token wrapping. It returns a structured `ProcessedMintInput` that can be used for further minting operations. +// +// Parameters: +// - input (MintInput): Raw user input containing token addresses, tick bounds, and liquidity amounts. +// +// Returns: +// - ProcessedMintInput: A structured and validated version of the minting input. +// - error: Returns an error if input parsing, token processing, or number validation fails. +// +// Behavior: +// +// 1. **Number Validation**: +// - Validates `amount0Desired`, `amount1Desired`, `amount0Min`, and `amount1Min` to ensure they are valid numeric strings. +// - If validation fails, the function panics immediately with an error indicating invalid input. +// +// 2. **Token Processing**: +// - Calls `processTokens` to validate and convert tokens (`token0` and `token1`) into their final forms. +// - Handles wrapping of native tokens (e.g., UGNOT to WUGNOT) if necessary. +// - Stores token metadata in a `TokenPair` struct, including wrapped amounts and native token status. +// +// 3. **Amount Parsing**: +// - Converts `amount0Desired`, `amount1Desired`, `amount0Min`, and `amount1Min` from string to `u256.Uint` using `parseAmounts`. +// - Ensures accurate representation of liquidity amounts for further processing. +// +// 4. **Token Order Enforcement**: +// - If `token1` is lexicographically smaller than `token0`, the function swaps their order to enforce consistent pool identification. +// - Along with token swaps, the tick bounds (`tickLower`, `tickUpper`) are inverted to preserve correct price boundaries. +// - This step guarantees pool uniqueness by ensuring `token0 < token1`. +// +// 5. **Pool Path Calculation**: +// - Computes the pool path (`poolPath`) using the finalized token addresses and fee tier. +// - The pool path uniquely identifies the pool in which liquidity will be minted. +// +// 6. **Return**: +// - Returns a populated `ProcessedMintInput` struct containing the finalized minting parameters. +// +// Panics: +// - If any of the provided amount strings are invalid (non-numeric or empty). +// - If `processTokens` encounters errors during token validation or wrapping. +// +// Notes: +// - This function enforces token order and validates amounts to ensure the integrity of liquidity minting operations. +func processMintInput(input MintInput) (ProcessedMintInput, error) { + assertValidNumberString(input.amount0Desired) + assertValidNumberString(input.amount1Desired) + assertValidNumberString(input.amount0Min) + assertValidNumberString(input.amount1Min) + var result ProcessedMintInput + + // process tokens + token0, token1, token0IsNative, token1IsNative, wrappedAmount := processTokens(input.token0, input.token1, input.amount0Desired, input.amount1Desired, input.caller) + pair := TokenPair{ + token0: token0, + token1: token1, + token0IsNative: token0IsNative, + token1IsNative: token1IsNative, + wrappedAmount: wrappedAmount, + } + + // parse amounts + amount0Desired, amount1Desired, amount0Min, amount1Min := parseAmounts(input.amount0Desired, input.amount1Desired, input.amount0Min, input.amount1Min) + + tickLower, tickUpper := input.tickLower, input.tickUpper + + // swap if token1 < token0 + if token1 < token0 { + pair.token0, pair.token1 = pair.token1, pair.token0 + amount0Desired, amount1Desired = amount1Desired, amount0Desired + amount0Min, amount1Min = amount1Min, amount0Min + tickLower, tickUpper = -tickUpper, -tickLower + pair.token0IsNative, pair.token1IsNative = pair.token1IsNative, pair.token0IsNative + } + + poolPath := computePoolPath(pair.token0, pair.token1, input.fee) + + result = ProcessedMintInput{ + tokenPair: pair, + amount0Desired: new(u256.Uint).Set(amount0Desired), + amount1Desired: new(u256.Uint).Set(amount1Desired), + amount0Min: new(u256.Uint).Set(amount0Min), + amount1Min: new(u256.Uint).Set(amount1Min), + tickLower: tickLower, + tickUpper: tickUpper, + poolPath: poolPath, + } + + return result, nil +} + +// processTokens processes two token paths, validates them, and handles the wrapping of native tokens into wrapped tokens if applicable. +// +// Parameters: +// - token0: The first token path to process. +// - token1: The second token path to process. +// - caller: The address of the user initiating the token processing. +// +// Returns: +// - string: Processed token0 path (potentially modified if it was a native token). +// - string: Processed token1 path (potentially modified if it was a native token). +// - bool: Indicates whether token0 was a native token (`true` if native, `false` otherwise). +// - bool: Indicates whether token1 was a native token (`true` if native, `false` otherwise). +// - uint64: The amount of the native token that was wrapped into the wrapped token. +// +// Behavior: +// 1. Validates the token paths using `validateTokenPath`. +// - Panics with a detailed error if validation fails. +// 2. Checks if `token0` or `token1` is a native token using `isNative`. +// - If a token is native, it is replaced with the wrapped token path (`WRAPPED_WUGNOT`). +// - The native token is then wrapped into the wrapped token using `safeWrapNativeToken`. +// 3. Returns the processed token paths, flags indicating if the tokens were native, and the wrapped amount. +// +// Panics: +// - If `validateTokenPath` fails validation. +// - If wrapping the native token using `safeWrapNativeToken` encounters an issue. +func processTokens( + token0 string, + token1 string, + amount0Desired string, + amount1Desired string, + caller std.Address, +) (string, string, bool, bool, uint64) { + err := validateTokenPath(token0, token1) + if err != nil { + panic(newErrorWithDetail(err, ufmt.Sprintf("token0(%s), token1(%s)", token0, token1))) + } + + token0IsNative := false + token1IsNative := false + wrappedAmount := uint64(0) + + if isNative(token0) { + token0 = consts.WRAPPED_WUGNOT + token0IsNative = true + + wrappedAmount = safeWrapNativeToken(amount0Desired, caller) + } else if isNative(token1) { + token1 = consts.WRAPPED_WUGNOT + token1IsNative = true + + wrappedAmount = safeWrapNativeToken(amount1Desired, caller) + } + + return token0, token1, token0IsNative, token1IsNative, wrappedAmount +} + +// validateTokenPath validates the relationship and format of token paths. +// Ensures that token paths are not identical, not conflicting (e.g., GNOT and WUGNOT), +// and each token path is in a valid format. +// +// Parameters: +// - token0: The first token path to validate. +// - token1: The second token path to validate. +// +// Returns: +// - error: Returns `errInvalidTokenPath` or nil +// +// Example: +// +// validateTokenPath("tokenA", "tokenB") -> nil +// validateTokenPath("tokenA", "tokenA") -> errInvalidTokenPath +// validateTokenPath(GNOT, WUGNOT) -> errInvalidTokenPath +func validateTokenPath(token0, token1 string) error { + if token0 == token1 { + return errInvalidTokenPath + } + if (token0 == consts.GNOT && token1 == consts.WRAPPED_WUGNOT) || + (token0 == consts.WRAPPED_WUGNOT && token1 == consts.GNOT) { + return errInvalidTokenPath + } + if (!isNative(token0) && !isValidTokenPath(token0)) || + (!isNative(token1) && !isValidTokenPath(token1)) { + return errInvalidTokenPath + } + return nil +} + +// isValidTokenPath checks if the provided token path is registered. +// +// This function verifies if the specified token path exists in the system registry +// by invoking the `IsRegistered` method. A path is considered valid if no error +// is returned during the registration check. +// +// Parameters: +// - tokenPath: The string representing the token path to validate. +// +// Returns: +// - bool: Returns `true` if the token path is registered; otherwise, `false`. +func isValidTokenPath(tokenPath string) bool { + return common.IsRegistered(tokenPath) == nil +} + +// parseAmounts converts strings to u256.Uint values for amount0Desired, amount1Desired, amount0Min, and amount1Min. +func parseAmounts(amount0Desired, amount1Desired, amount0Min, amount1Min string) (*u256.Uint, *u256.Uint, *u256.Uint, *u256.Uint) { + return u256.MustFromDecimal(amount0Desired), u256.MustFromDecimal(amount1Desired), u256.MustFromDecimal(amount0Min), u256.MustFromDecimal(amount1Min) +} + +// computePoolPath returns the pool path based on the token pair and fee tier. +// +// This function constructs a unique pool path identifier by utilizing the two token addresses +// and the pool fee tier. It helps identify the specific liquidity pool on the platform. +// +// Parameters: +// - token0: The address of the first token (string). +// - token1: The address of the second token (string). +// - fee: The fee for the liquidity pool (uint32). +// +// Returns: +// - string: A unique path string representing the liquidity pool. +func computePoolPath(token0, token1 string, fee uint32) string { + return pl.GetPoolPath(token0, token1, fee) +} + +// mint creates a new liquidity position by adding liquidity to a pool and minting an NFT representing the position. +// +// This function handles the entire lifecycle of creating a new liquidity position, including adding liquidity +// to a pool, minting an NFT to represent the position, and storing the position's state. +// +// Parameters: +// - params (MintParams): A struct containing all necessary parameters to mint a new liquidity position, including: +// - token0, token1: The addresses of the token pair. +// - fee: The fee tier of the pool. +// - tickLower, tickUpper: The price range (ticks) for the liquidity position. +// - amount0Desired, amount1Desired: Desired amounts of token0 and token1 to provide. +// - amount0Min, amount1Min: Minimum acceptable amounts to prevent slippage. +// - caller: The address initiating the mint. The required token amounts (token0 and token1) will be withdrawn +// from the caller's balance and deposited into the pool. +// - mintTo: The address to receive the newly minted NFT. +// +// Returns: +// - uint64: The token ID of the minted liquidity position NFT. +// - *u256.Uint: The amount of liquidity added to the pool. +// - *u256.Uint: The actual amount of token0 used in the liquidity addition. +// - *u256.Uint: The actual amount of token1 used in the liquidity addition. +// +// Panics: +// - If the liquidity position (tokenId) already exists. +// - If adding liquidity fails due to insufficient amounts or invalid tick ranges. +// +// Notes: +// - This function relies on `addLiquidity` to perform the liquidity calculation and ensure proper slippage checks. +// - The NFT minted is critical for tracking the user's liquidity in the pool. +// - Position state management is handled by `setPosition`, ensuring the uniqueness of the tokenId. +func mint(params MintParams) (uint64, *u256.Uint, *u256.Uint, *u256.Uint) { + poolKey := pl.GetPoolPath(params.token0, params.token1, params.fee) + liquidity, amount0, amount1 := addLiquidity( + AddLiquidityParams{ + poolKey: poolKey, + tickLower: params.tickLower, + tickUpper: params.tickUpper, + amount0Desired: params.amount0Desired, + amount1Desired: params.amount1Desired, + amount0Min: params.amount0Min, + amount1Min: params.amount1Min, + caller: params.caller, + }, + ) + // Ensure liquidity is not zero before minting NFT + if liquidity.IsZero() { + panic(newErrorWithDetail( + errZeroLiquidity, + "Liquidity is zero, cannot mint position.", + )) + } + + tokenId := getNextId() + gnft.Mint(params.mintTo, tokenIdFrom(tokenId)) // owner, tokenId + + pool := pl.GetPoolFromPoolPath(poolKey) + positionKey := computePositionKey(GetOrigPkgAddr(), params.tickLower, params.tickUpper) + + position := Position{ + nonce: u256.Zero(), + operator: consts.ZERO_ADDRESS, + poolKey: poolKey, + tickLower: params.tickLower, + tickUpper: params.tickUpper, + liquidity: liquidity, + feeGrowthInside0LastX128: new(u256.Uint).Set(pool.PositionFeeGrowthInside0LastX128(positionKey)), + feeGrowthInside1LastX128: new(u256.Uint).Set(pool.PositionFeeGrowthInside1LastX128(positionKey)), + tokensOwed0: u256.Zero(), + tokensOwed1: u256.Zero(), + burned: false, + } + + // The tokenId should not exist at the time of minting + updated := setPosition(tokenId, position) + if updated { + panic(newErrorWithDetail( + errPositionExist, + ufmt.Sprintf("tokenId(%d) already exists", tokenId), + )) + } + incrementNextId() + + return tokenId, liquidity, amount0, amount1 +} + +// IncreaseLiquidity increases liquidity of the existing position +// Returns tokenId, liquidity, amount0, amount1, poolPath +// ref: https://docs.gnoswap.io/contracts/position/position.gno#increaseliquidity +func IncreaseLiquidity( + tokenId uint64, + amount0DesiredStr string, + amount1DesiredStr string, + amount0MinStr string, + amount1MinStr string, + deadline int64, +) (uint64, string, string, string, string) { + assertOnlyNotHalted() + assertValidNumberString(amount0DesiredStr) + assertValidNumberString(amount1DesiredStr) + assertValidNumberString(amount0MinStr) + assertValidNumberString(amount1MinStr) + checkDeadline(deadline) + + en.MintAndDistributeGns() + + position := MustGetPosition(tokenId) + token0, token1, _ := splitOf(position.poolKey) + err := validateTokenPath(token0, token1) + if err != nil { + panic(newErrorWithDetail(err, ufmt.Sprintf("token0(%s), token1(%s)", token0, token1))) + } + caller := getPrevAddr() + + wrappedAmount := uint64(0) + if isWrappedToken(token0) { + wrappedAmount = safeWrapNativeToken(amount0DesiredStr, caller) + } else if isWrappedToken(token1) { + wrappedAmount = safeWrapNativeToken(amount1DesiredStr, caller) + } + + amount0Desired, amount1Desired, amount0Min, amount1Min := parseAmounts(amount0DesiredStr, amount1DesiredStr, amount0MinStr, amount1MinStr) + increaseLiquidityParams := IncreaseLiquidityParams{ + tokenId: tokenId, + amount0Desired: amount0Desired, + amount1Desired: amount1Desired, + amount0Min: amount0Min, + amount1Min: amount1Min, + deadline: deadline, + } + + _, liquidity, amount0, amount1, poolPath := increaseLiquidity(increaseLiquidityParams) + + if isWrappedToken(token0) && wrappedAmount > amount0.Uint64() { + // unwrap leftover wugnot + err = unwrap(wrappedAmount-amount0.Uint64(), caller) + if err != nil { + panic(newErrorWithDetail(errWrapUnwrap, err.Error())) + } + } + if isWrappedToken(token1) && wrappedAmount > amount1.Uint64() { + // unwrap leftover wugnot + err = unwrap(wrappedAmount-amount1.Uint64(), caller) + if err != nil { + panic(newErrorWithDetail(errWrapUnwrap, err.Error())) + } + } + + poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(poolPath) + token0Balance := pl.PoolGetBalanceToken0(poolPath) + token1Balance := pl.PoolGetBalanceToken1(poolPath) + + prevAddr, prevPkgPath := getPrevAsString() + + std.Emit( + "IncreaseLiquidity", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "poolPath", poolPath, + "caller", caller.String(), + "lpTokenId", formatUint(tokenId), + "liquidity", liquidity.ToString(), + "amount0", amount0.ToString(), + "amount1", amount1.ToString(), + "sqrtPriceX96", poolSqrtPriceX96, + "positionLiquidity", position.liquidity.ToString(), + "token0Balance", token0Balance.ToString(), + "token1Balance", token1Balance.ToString(), + ) + + return tokenId, liquidity.ToString(), amount0.ToString(), amount1.ToString(), poolPath +} + +// FeeGrowthInside represents fee growth inside ticks +type FeeGrowthInside struct { + feeGrowthInside0LastX128 *u256.Uint + feeGrowthInside1LastX128 *u256.Uint +} + +// PositionFeeUpdate represents fee update calculation result +type PositionFeeUpdate struct { + tokensOwed0 *u256.Uint + tokensOwed1 *u256.Uint + feeGrowthInside0LastX128 *u256.Uint + feeGrowthInside1LastX128 *u256.Uint +} + +func calculatePositionFeeUpdate( + position Position, + currentFeeGrowth FeeGrowthInside, +) PositionFeeUpdate { + tokensOwed0 := calculateTokensOwed( + currentFeeGrowth.feeGrowthInside0LastX128, + position.feeGrowthInside0LastX128, + position.liquidity, + ) + + tokensOwed1 := calculateTokensOwed( + currentFeeGrowth.feeGrowthInside1LastX128, + position.feeGrowthInside1LastX128, + position.liquidity, + ) + + return PositionFeeUpdate{ + tokensOwed0: new(u256.Uint).Add(position.tokensOwed0, tokensOwed0), + tokensOwed1: new(u256.Uint).Add(position.tokensOwed1, tokensOwed1), + feeGrowthInside0LastX128: new(u256.Uint).Set(currentFeeGrowth.feeGrowthInside0LastX128), + feeGrowthInside1LastX128: new(u256.Uint).Set(currentFeeGrowth.feeGrowthInside1LastX128), + } +} + +// updatePosition updates the position with new liquidity and fee data +func updatePosition( + position Position, + feeUpdate PositionFeeUpdate, + newLiquidity *u256.Uint, +) Position { + position.tokensOwed0 = feeUpdate.tokensOwed0 + position.tokensOwed1 = feeUpdate.tokensOwed1 + position.feeGrowthInside0LastX128 = feeUpdate.feeGrowthInside0LastX128 + position.feeGrowthInside1LastX128 = feeUpdate.feeGrowthInside1LastX128 + position.liquidity = new(u256.Uint).Add(position.liquidity, newLiquidity) + position.burned = false + + return position +} + +func increaseLiquidity(params IncreaseLiquidityParams) (uint64, *u256.Uint, *u256.Uint, *u256.Uint, string) { + // verify tokenId exists + assertTokenExists(params.tokenId) + caller := getPrevAddr() + assertOnlyOwnerOfToken(params.tokenId, caller) + + position := MustGetPosition(params.tokenId) + liquidity, amount0, amount1 := addLiquidity( + AddLiquidityParams{ + poolKey: position.poolKey, + tickLower: position.tickLower, + tickUpper: position.tickUpper, + amount0Desired: params.amount0Desired, + amount1Desired: params.amount1Desired, + amount0Min: params.amount0Min, + amount1Min: params.amount1Min, + caller: caller, + }, + ) + + pool := pl.GetPoolFromPoolPath(position.poolKey) + positionKey := computePositionKey(GetOrigPkgAddr(), position.tickLower, position.tickUpper) + feeGrowthInside0LastX128 := new(u256.Uint).Set(pool.PositionFeeGrowthInside0LastX128(positionKey)) + feeGrowthInside1LastX128 := new(u256.Uint).Set(pool.PositionFeeGrowthInside1LastX128(positionKey)) + + { + diff := new(u256.Uint).Sub(feeGrowthInside0LastX128, position.feeGrowthInside0LastX128) + mulDiv := u256.MulDiv(diff, new(u256.Uint).Set(position.liquidity), u256.MustFromDecimal(consts.Q128)) + + position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, mulDiv) + } + + { + diff := new(u256.Uint).Sub(feeGrowthInside1LastX128, position.feeGrowthInside1LastX128) + mulDiv := u256.MulDiv(diff, new(u256.Uint).Set(position.liquidity), u256.MustFromDecimal(consts.Q128)) + + position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, mulDiv) + } + + position.feeGrowthInside0LastX128 = feeGrowthInside0LastX128 + position.feeGrowthInside1LastX128 = feeGrowthInside1LastX128 + position.liquidity = new(u256.Uint).Add(new(u256.Uint).Set(position.liquidity), liquidity) + position.burned = false + + updated := setPosition(params.tokenId, position) + if !updated { + panic(newErrorWithDetail( + errPositionDoesNotExist, + ufmt.Sprintf("can not increase liquidity for non-existent position(%d)", params.tokenId), + )) + } + + return params.tokenId, liquidity, amount0, amount1, position.poolKey +} + +// DecreaseLiquidity decreases liquidity of the existing position +// It also handles the conversion between GNOT and WUGNOT transparently for the user. +// Returns tokenId, liquidity, fee0, fee1, amount0, amount1, poolPath +// ref: https://docs.gnoswap.io/contracts/position/position.gno#decreaseliquidity +func DecreaseLiquidity( + tokenId uint64, + liquidityStr string, + amount0MinStr string, + amount1MinStr string, + deadline int64, + unwrapResult bool, +) (uint64, string, string, string, string, string, string) { + assertOnlyNotHalted() + isAuthorizedForToken(tokenId) + checkDeadline(deadline) + assertValidLiquidityAmount(liquidityStr) + + en.MintAndDistributeGns() + + amount0Min := u256.MustFromDecimal(amount0MinStr) + amount1Min := u256.MustFromDecimal(amount1MinStr) + decreaseLiquidityParams := DecreaseLiquidityParams{ + tokenId: tokenId, + liquidity: liquidityStr, + amount0Min: amount0Min, + amount1Min: amount1Min, + deadline: deadline, + unwrapResult: unwrapResult, + } + + tokenId, liquidity, fee0, fee1, amount0, amount1, poolPath := decreaseLiquidity(decreaseLiquidityParams) + + poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(poolPath) + token0Balance := pl.PoolGetBalanceToken0(poolPath) + token1Balance := pl.PoolGetBalanceToken1(poolPath) + + prevAddr, prevPkgPath := getPrevAsString() + + std.Emit( + "DecreaseLiquidity", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "lpTokenId", formatUint(tokenId), + "poolPath", poolPath, + "decreasedLiquidity", liquidity.ToString(), + "feeAmount0", fee0.ToString(), + "feeAmount1", fee1.ToString(), + "amount0", amount0.ToString(), + "amount1", amount1.ToString(), + "unwrapResult", formatBool(unwrapResult), + "sqrtPriceX96", poolSqrtPriceX96, + "positionLiquidity", PositionGetPositionLiquidityStr(tokenId), + "token0Balance", token0Balance.ToString(), + "token1Balance", token1Balance.ToString(), + ) + + return tokenId, liquidity.ToString(), fee0.ToString(), fee1.ToString(), amount0.ToString(), amount1.ToString(), poolPath +} + +// decreaseLiquidity reduces the liquidity of a given position and collects the corresponding tokens. +// If unwrapResult is true and the position involves WUGNOT, any leftover WUGNOT will be +// unwrapped to GNOT at the end of the operation. +// Returns tokenId, liquidity, fee0, fee1, amount0, amount1, poolPath +func decreaseLiquidity(params DecreaseLiquidityParams) (uint64, *u256.Uint, *u256.Uint, *u256.Uint, *u256.Uint, *u256.Uint, string) + verifyTokenIdAndOwnership(params.tokenId) + + // BEFORE DECREASE LIQUIDITY, COLLECT FEE FIRST + _, fee0Str, fee1Str, _, _, _ := CollectFee(params.tokenId, params.unwrapResult) + fee0 := u256.MustFromDecimal(fee0Str) + fee1 := u256.MustFromDecimal(fee1Str) + + position := MustGetPosition(params.tokenId) + positionLiquidity := position.liquidity + if positionLiquidity.IsZero() { + panic(newErrorWithDetail( + errZeroLiquidity, + ufmt.Sprintf("position(tokenId:%d) has 0 liquidity", params.tokenId), + )) + } + + liquidityToRemove := u256.MustFromDecimal(params.liquidity) + if liquidityToRemove.Gt(positionLiquidity) { + panic(newErrorWithDetail( + errInvalidLiquidity, + ufmt.Sprintf("Liquidity requested(%s) is greater than liquidity held(%s)", liquidityToRemove.ToString(), positionLiquidity.ToString()), + )) + } + + caller := getPrevAddr() + // NOTE: XXX + // Currently, wugnot is using users.AddressOrName so we have no choice but to change type for it. + // When, wugnot is using std.Address, we can remove converting. + // This applies to every wugnot related calls. + beforeWugnotBalance := wugnot.BalanceOf(common.AddrToUser(caller)) // before unwrap + + pToken0, pToken1, pFee := splitOf(position.poolKey) + burn0, burn1 := pl.Burn(pToken0, pToken1, pFee, position.tickLower, position.tickUpper, liquidityToRemove.ToString()) + + burnedAmount0 := u256.MustFromDecimal(burn0) + burnedAmount1 := u256.MustFromDecimal(burn1) + verifySlippageAmounts(burnedAmount0, burnedAmount1, params.amount0Min, params.amount1Min) + + positionKey := computePositionKey(GetOrigPkgAddr(), position.tickLower, position.tickUpper) + pool := pl.GetPoolFromPoolPath(position.poolKey) + feeGrowthInside0LastX128 := new(u256.Uint).Set(pool.PositionFeeGrowthInside0LastX128(positionKey)) + feeGrowthInside1LastX128 := new(u256.Uint).Set(pool.PositionFeeGrowthInside1LastX128(positionKey)) + + position.tokensOwed0 = updateTokensOwed( + feeGrowthInside0LastX128, + position.feeGrowthInside0LastX128, + position.liquidity, + burnedAmount0, + position.tokensOwed0, + ) + + position.tokensOwed1 = updateTokensOwed( + feeGrowthInside1LastX128, + position.feeGrowthInside1LastX128, + position.liquidity, + burnedAmount1, + position.tokensOwed1, + ) + + position.feeGrowthInside0LastX128 = feeGrowthInside0LastX128 + position.feeGrowthInside1LastX128 = feeGrowthInside1LastX128 + position.liquidity = new(u256.Uint).Sub(positionLiquidity, liquidityToRemove) + mustUpdatePosition(params.tokenId, position) + + collect0, collect1 := pl.Collect( + pToken0, + pToken1, + pFee, + caller, + position.tickLower, + position.tickUpper, + burn0, + burn1, + ) + + collectAmount0 := u256.MustFromDecimal(collect0) + collectAmount1 := u256.MustFromDecimal(collect1) + + underflow := false + position.tokensOwed0, underflow = new(u256.Uint).SubOverflow(position.tokensOwed0, collectAmount0) + if underflow { + panic(newErrorWithDetail( + errUnderflow, + "tokensOwed0 underflow", + )) + } + position.tokensOwed1, underflow = new(u256.Uint).SubOverflow(position.tokensOwed1, collectAmount1) + if underflow { + panic(newErrorWithDetail( + errUnderflow, + "tokensOwed1 underflow", + )) + } + mustUpdatePosition(params.tokenId, position) + + if position.isClear() { + burnPosition(params.tokenId) // just update flag (we don't want to burn actual position) + } + + // NO UNWRAP + if !params.unwrapResult { + return params.tokenId, liquidityToRemove, fee0, fee1, collectAmount0, collectAmount1, position.poolKey + } + + handleUnwrap(pToken0, pToken1, params.unwrapResult, beforeWugnotBalance, caller) + + return params.tokenId, liquidityToRemove, fee0, fee1, collectAmount0, collectAmount1, position.poolKey + + +// CollectFee collects swap fee from the position +// Returns tokenId, afterFee0, afterFee1, poolPath, origFee0, origFee1 +// ref: https://docs.gnoswap.io/contracts/position/position.gno#collectfee +func CollectFee(tokenId uint64, unwrapResult bool) (uint64, string, string, string, string, string) { + assertOnlyNotHalted() + assertTokenExists(tokenId) + isAuthorizedForToken(tokenId) + + en.MintAndDistributeGns() + + // verify position + position := MustGetPosition(tokenId) + token0, token1, fee := splitOf(position.poolKey) + + pl.Burn( + token0, + token1, + fee, + position.tickLower, + position.tickUpper, + ZERO_LIQUIDITY_FOR_FEE_COLLECTION, // burn '0' liquidity to collect fee + ) + + currentFeeGrowth, err := getCurrentFeeGrowth(position, token0, token1, fee) + if err != nil { + panic(newErrorWithDetail(err, "failed to get current fee growth")) + } + + tokensOwed0, tokensOwed1 := calculateFees(position, currentFeeGrowth) + + position.feeGrowthInside0LastX128 = new(u256.Uint).Set(currentFeeGrowth.feeGrowthInside0LastX128) + position.feeGrowthInside1LastX128 = new(u256.Uint).Set(currentFeeGrowth.feeGrowthInside1LastX128) + + // check user wugnot amount + // need this value to unwrap fee + caller := getPrevAddr() + userWugnot := wugnot.BalanceOf(common.AddrToUser(caller)) + + // collect fee + amount0, amount1 := pl.Collect( + token0, token1, fee, + caller, + position.tickLower, position.tickUpper, + tokensOwed0.ToString(), tokensOwed1.ToString(), + ) + + // sometimes there will be a few less uBase amount than expected due to rounding down in core, but we just subtract the full amount expected + // instead of the actual amount so we can burn the token + position.tokensOwed0 = new(u256.Uint).Sub(tokensOwed0, u256.MustFromDecimal(amount0)) + position.tokensOwed1 = new(u256.Uint).Sub(tokensOwed1, u256.MustFromDecimal(amount1)) + mustUpdatePosition(tokenId, position) + + // handle withdrawal fee + withoutFee0, withoutFee1 := pl.HandleWithdrawalFee( + tokenId, + token0, + amount0, + token1, + amount1, + position.poolKey, + std.PrevRealm().Addr(), + ) + + // UNWRAP + pToken0, pToken1, _ := splitOf(position.poolKey) + handleUnwrap(pToken0, pToken1, unwrapResult, userWugnot, caller) + + prevAddr, prevPkgPath := getPrevAsString() + + std.Emit( + "CollectSwapFee", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "lpTokenId", formatUint(tokenId), + "feeAmount0", withoutFee0, + "feeAmount1", withoutFee1, + "poolPath", position.poolKey, + "unwrapResult", formatBool(unwrapResult), + ) + + return tokenId, withoutFee0, withoutFee1, position.poolKey, amount0, amount1 +} + +// calculateFees calculates the fees for the current position. +func calculateFees(position Position, currentFeeGrowth FeeGrowthInside) (*u256.Uint, *u256.Uint) { + fee0 := calculateTokensOwed( + currentFeeGrowth.feeGrowthInside0LastX128, + position.feeGrowthInside0LastX128, + position.liquidity, + ) + + fee1 := calculateTokensOwed( + currentFeeGrowth.feeGrowthInside1LastX128, + position.feeGrowthInside1LastX128, + position.liquidity, + ) + + tokensOwed0 := new(u256.Uint).Add(new(u256.Uint).Set(position.tokensOwed0), fee0) + tokensOwed1 := new(u256.Uint).Add(new(u256.Uint).Set(position.tokensOwed1), fee1) + + return tokensOwed0, tokensOwed1 +} + +func getCurrentFeeGrowth(postion Position, token0, token1 string, fee uint32) (FeeGrowthInside, error) { + pool := pl.GetPoolFromPoolPath(postion.poolKey) + positionKey := computePositionKey(GetOrigPkgAddr(), postion.tickLower, postion.tickUpper) + + feeGrowthInside0 := new(u256.Uint).Set(pool.PositionFeeGrowthInside0LastX128(positionKey)) + feeGrowthInside1 := new(u256.Uint).Set(pool.PositionFeeGrowthInside1LastX128(positionKey)) + + feeGrowthInside := FeeGrowthInside{ + feeGrowthInside0LastX128: feeGrowthInside0, + feeGrowthInside1LastX128: feeGrowthInside1, + } + + return feeGrowthInside, nil +} + +// Repositiomn adjusts the position of an existing liquidity token +// by changing its price range and liquidity amount. +// Returns tokenId, liquidity, tickLower, tickUpper, amount0, amount1 +// ref: https://docs.gnoswap.io/contracts/position/position.gno#reposition +func Reposition( + tokenId uint64, + tickLower int32, + tickUpper int32, + amount0DesiredStr string, + amount1DesiredStr string, + amount0MinStr string, + amount1MinStr string, +) (uint64, string, int32, int32, string, string) { + assertOnlyNotHalted() + verifyTokenIdAndOwnership(tokenId) + + en.MintAndDistributeGns() + + // position should be burned to reposition + position := MustGetPosition(tokenId) + oldTickLower := position.tickLower + oldTickUpper := position.tickUpper + + if !(position.isClear()) { + panic(newErrorWithDetail( + errNotClear, + ufmt.Sprintf("position(%d) isn't clear(liquidity:%s, tokensOwed0:%s, tokensOwed1:%s)", tokenId, position.liquidity.ToString(), position.tokensOwed0.ToString(), position.tokensOwed1.ToString()), + )) + } + + caller := getPrevAddr() + token0, token1, fee := splitOf(position.poolKey) + token0, token1, _, _, _ = processTokens(token0, token1, amount0DesiredStr, amount1DesiredStr, caller) + + liquidity, amount0, amount1 := addLiquidity( + AddLiquidityParams{ + poolKey: position.poolKey, + tickLower: tickLower, + tickUpper: tickUpper, + amount0Desired: u256.MustFromDecimal(amount0DesiredStr), + amount1Desired: u256.MustFromDecimal(amount1DesiredStr), + amount0Min: u256.MustFromDecimal(amount0MinStr), + amount1Min: u256.MustFromDecimal(amount1MinStr), + caller: caller, + }, + ) + + // update position tickLower, tickUpper to new value + // because getCurrentFeeGrowth() uses tickLower, tickUpper + position.tickLower = tickLower + position.tickUpper = tickUpper + + currentFeeGrowth, err := getCurrentFeeGrowth(position, token0, token1, fee) + if err != nil { + panic(newErrorWithDetail(err, "failed to get current fee growth")) + } + position.feeGrowthInside0LastX128 = currentFeeGrowth.feeGrowthInside0LastX128 + position.feeGrowthInside1LastX128 = currentFeeGrowth.feeGrowthInside1LastX128 + + position.liquidity = liquidity + // OBS: do not reset feeGrowthInside1LastX128 and feeGrowthInside1LastX128 to zero + // if so, ( decrease 100% -> reposition ) + // > at this point, that position will have unclaimedFee which isn't intended + position.tokensOwed0 = u256.Zero() + position.tokensOwed1 = u256.Zero() + position.burned = false + mustUpdatePosition(tokenId, position) + + poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(position.poolKey) + token0Balance := pl.PoolGetBalanceToken0(poolPath) + token1Balance := pl.PoolGetBalanceToken1(poolPath) + prevAddr, prevPkgPath := getPrevAsString() + + std.Emit( + "Reposition", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "lpTokenId", formatUint(tokenId), + "tickLower", formatInt(int64(tickLower)), + "tickUpper", formatInt(int64(tickUpper)), + "liquidity", liquidity.ToString(), + "amount0", amount0.ToString(), + "amount1", amount1.ToString(), + "prevTickLower", formatInt(int64(oldTickLower)), + "prevTickUpper", formatInt(int64(oldTickUpper)), + "poolPath", position.poolKey, + "sqrtPriceX96", poolSqrtPriceX96, + "positionLiquidity", PositionGetPositionLiquidityStr(tokenId), + "token0Balance", token0Balance.ToString(), + "token1Balance", token1Balance.ToString(), + ) + + return tokenId, liquidity.ToString(), tickLower, tickUpper, amount0.ToString(), amount1.ToString() +} + +func calculateTokensOwed( + feeGrowthInsideLastX128 *u256.Uint, + positionFeeGrowthInsideLastX128 *u256.Uint, + positionLiquidity *u256.Uint, +) *u256.Uint { + diff := new(u256.Uint).Sub(feeGrowthInsideLastX128, positionFeeGrowthInsideLastX128) + return u256.MulDiv(diff, positionLiquidity, u256.MustFromDecimal(consts.Q128)) +} + +func updateTokensOwed( + feeGrowthInsideLastX128 *u256.Uint, + positionFeeGrowthInsideLastX128 *u256.Uint, + positionLiquidity *u256.Uint, + burnedAmount *u256.Uint, + tokensOwed *u256.Uint, +) *u256.Uint { + additionalTokensOwed := calculateTokensOwed(feeGrowthInsideLastX128, positionFeeGrowthInsideLastX128, positionLiquidity) + add := new(u256.Uint).Add(burnedAmount, additionalTokensOwed) + return new(u256.Uint).Add(tokensOwed, add) +} + +func burnPosition(tokenId uint64) { + position := MustGetPosition(tokenId) + if !(position.isClear()) { + panic(newErrorWithDetail( + errNotClear, + ufmt.Sprintf("position(%d) isn't clear(liquidity:%s, tokensOwed0:%s, tokensOwed1:%s)", tokenId, position.liquidity.ToString(), position.tokensOwed0.ToString(), position.tokensOwed1.ToString()), + )) + } + + position.burned = true + mustUpdatePosition(tokenId, position) +} + +func calculateLiquidityToRemove(positionLiquidity *u256.Uint, liquidityRatio uint64) *u256.Uint { + liquidityToRemove := new(u256.Uint).Mul(positionLiquidity, u256.NewUint(liquidityRatio)) + liquidityToRemove = new(u256.Uint).Div(liquidityToRemove, u256.NewUint(100)) + if positionLiquidity.Lt(liquidityToRemove) || liquidityRatio == 100 { + return positionLiquidity + } + return liquidityToRemove +} + +func SetPositionOperator(tokenId uint64, operator std.Address) { + caller := getPrevRealm().PkgPath() + if caller != consts.STAKER_PATH { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf("caller(%s) is not staker", caller), + )) + } + + position := MustGetPosition(tokenId) + position.operator = operator + mustUpdatePosition(tokenId, position) +} diff --git a/contract/r/gnoswap/position/position_test.gno b/contract/r/gnoswap/position/position_test.gno new file mode 100644 index 000000000..6fe14e956 --- /dev/null +++ b/contract/r/gnoswap/position/position_test.gno @@ -0,0 +1,1421 @@ +package position + +import ( + "std" + "strconv" + "testing" + "time" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gnft" +) + +var ( + mockRegistry = make(map[string]bool) +) + +// MockRegister registers a token path in the mock registry. +func MockRegister(t *testing.T, tokenPath string) { + t.Helper() + mockRegistry[tokenPath] = true +} + +// MockUnregister unregisters a token path in the mock registry. +func MockUnregister(t *testing.T, tokenPath string) { + t.Helper() + delete(mockRegistry, tokenPath) +} + +// IsRegistered checks if a token path is registered in the mock registry. +func IsRegistered(t *testing.T, tokenPath string) error { + t.Helper() + if mockRegistry[tokenPath] { + return nil + } + return errInvalidTokenPath +} + +func newDummyPosition(tokenId uint64) Position { + return Position{ + nonce: u256.NewUint(tokenId), + operator: "user1", + poolKey: "poolKey1", + tickLower: -500, + tickUpper: 500, + liquidity: u256.NewUint(1000), + feeGrowthInside0LastX128: u256.NewUint(10), + feeGrowthInside1LastX128: u256.NewUint(20), + tokensOwed0: u256.NewUint(30), + tokensOwed1: u256.NewUint(40), + burned: false, + } +} + +func TestMustGetPosition(t *testing.T) { + tokenId := uint64(1) + position := newDummyPosition(tokenId) + setPosition(tokenId, position) + + t.Run("Success - Get existing position", func(t *testing.T) { + result := MustGetPosition(tokenId) + uassert.Equal(t, position.nonce.ToString(), result.nonce.ToString()) + uassert.Equal(t, position.operator, result.operator) + uassert.Equal(t, position.poolKey, result.poolKey) + uassert.Equal(t, position.tickLower, result.tickLower) + uassert.Equal(t, position.tickUpper, result.tickUpper) + uassert.Equal(t, position.liquidity.ToString(), result.liquidity.ToString()) + uassert.Equal(t, position.feeGrowthInside0LastX128.ToString(), result.feeGrowthInside0LastX128.ToString()) + uassert.Equal(t, position.feeGrowthInside1LastX128.ToString(), result.feeGrowthInside1LastX128.ToString()) + uassert.Equal(t, position.tokensOwed0.ToString(), result.tokensOwed0.ToString()) + uassert.Equal(t, position.tokensOwed1.ToString(), result.tokensOwed1.ToString()) + uassert.Equal(t, position.burned, result.burned) + }) + + t.Run("Fail - Non-existent position", func(t *testing.T) { + uassert.PanicsWithMessage(t, + "[GNOSWAP-POSITION-013] position does not exist || position with tokenId(999) doesn't exist", + func() { + MustGetPosition(999) + }) + }) +} + +func TestGetPosition(t *testing.T) { + tokenId := uint64(1) + position := newDummyPosition(tokenId) + t.Run("Success - Get existing position", func(t *testing.T) { + result := MustGetPosition(tokenId) + uassert.Equal(t, position.nonce.ToString(), result.nonce.ToString()) + uassert.Equal(t, position.operator, result.operator) + uassert.Equal(t, position.poolKey, result.poolKey) + uassert.Equal(t, position.tickLower, result.tickLower) + uassert.Equal(t, position.tickUpper, result.tickUpper) + uassert.Equal(t, position.liquidity.ToString(), result.liquidity.ToString()) + uassert.Equal(t, position.feeGrowthInside0LastX128.ToString(), result.feeGrowthInside0LastX128.ToString()) + uassert.Equal(t, position.feeGrowthInside1LastX128.ToString(), result.feeGrowthInside1LastX128.ToString()) + uassert.Equal(t, position.tokensOwed0.ToString(), result.tokensOwed0.ToString()) + uassert.Equal(t, position.tokensOwed1.ToString(), result.tokensOwed1.ToString()) + uassert.Equal(t, position.burned, result.burned) + removePosition(tokenId) + }) +} + +func TestSetPosition(t *testing.T) { + tokenId := uint64(300) + position := newDummyPosition(tokenId) + + t.Run("Success - Create new position", func(t *testing.T) { + isNew := setPosition(tokenId, position) + uassert.False(t, isNew) + + storedPosition, _ := GetPosition(tokenId) + uassert.Equal(t, position.nonce.ToString(), storedPosition.nonce.ToString()) + uassert.Equal(t, position.operator, storedPosition.operator) + uassert.Equal(t, position.poolKey, storedPosition.poolKey) + uassert.Equal(t, position.tickLower, storedPosition.tickLower) + uassert.Equal(t, position.tickUpper, storedPosition.tickUpper) + uassert.Equal(t, position.liquidity.ToString(), storedPosition.liquidity.ToString()) + uassert.Equal(t, position.feeGrowthInside0LastX128.ToString(), storedPosition.feeGrowthInside0LastX128.ToString()) + uassert.Equal(t, position.feeGrowthInside1LastX128.ToString(), storedPosition.feeGrowthInside1LastX128.ToString()) + uassert.Equal(t, position.tokensOwed0.ToString(), storedPosition.tokensOwed0.ToString()) + uassert.Equal(t, position.tokensOwed1.ToString(), storedPosition.tokensOwed1.ToString()) + uassert.Equal(t, position.burned, storedPosition.burned) + }) + + t.Run("Success - Update existing position", func(t *testing.T) { + position.tickUpper = 1000 + isNew := setPosition(tokenId, position) + uassert.True(t, isNew) + + updatedPosition, _ := GetPosition(tokenId) + uassert.Equal(t, int32(1000), updatedPosition.tickUpper) + }) +} + +func TestRemovePosition(t *testing.T) { + tokenId := uint64(2) + position := newDummyPosition(tokenId) + + t.Run("Success - Remove existing position", func(t *testing.T) { + setPosition(tokenId, position) + removePosition(tokenId) + _, found := GetPosition(tokenId) + uassert.False(t, found) + }) + + t.Run("Success - Remove non-existent position", func(t *testing.T) { + removePosition(tokenId) + _, found := GetPosition(tokenId) + uassert.False(t, found) + }) +} + +func TestExistPosition(t *testing.T) { + tokenId := uint64(2) + position := newDummyPosition(tokenId) + + t.Run("False - Position does not exist", func(t *testing.T) { + uassert.False(t, existPosition(tokenId)) + }) + + t.Run("True - Position exists", func(t *testing.T) { + setPosition(tokenId, position) + uassert.True(t, existPosition(tokenId)) + removePosition(tokenId) + }) +} + +func TestComputePositionKey(t *testing.T) { + tests := []struct { + name string + owner std.Address + tickLower int32 + tickUpper int32 + expected string + }{ + { + name: "Basic Position Key", + owner: alice, + tickLower: -100, + tickUpper: 200, + expected: "ZzF2OWt4amNtOXRhMDQ3aDZsdGEwNDdoNmx0YTA0N2g2bHpkNDBnaF9fLTEwMF9fMjAw", // Base64 of "g1v9kxjcm9ta047h6lta047h6lta047h6lzd40gh__-100__200" + }, + { + name: "Zero Ticks", + owner: alice, + tickLower: 0, + tickUpper: 0, + expected: "ZzF2OWt4amNtOXRhMDQ3aDZsdGEwNDdoNmx0YTA0N2g2bHpkNDBnaF9fMF9fMA==", // Base64 of "g1v9kxjcm9ta047h6lta047h6lta047h6lzd40gh__0__0" + }, + { + name: "Negative Lower Tick", + owner: alice, + tickLower: -50, + tickUpper: 150, + expected: "ZzF2OWt4amNtOXRhMDQ3aDZsdGEwNDdoNmx0YTA0N2g2bHpkNDBnaF9fLTUwX18xNTA=", // Base64 of "g1v9kxjcm9ta047h6lta047h6lta047h6lzd40gh__-50__150" + }, + { + name: "Same Tick Bounds", + owner: alice, + tickLower: 300, + tickUpper: 300, + expected: "ZzF2OWt4amNtOXRhMDQ3aDZsdGEwNDdoNmx0YTA0N2g2bHpkNDBnaF9fMzAwX18zMDA=", // Base64 of "g1v9kxjcm9ta047h6lta047h6lta047h6lzd40gh__300__300" + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := computePositionKey(tt.owner, tt.tickLower, tt.tickUpper) + + if result != tt.expected { + t.Errorf("expected %s but got %s", tt.expected, result) + } + }) + } +} + +func TestNextIdFunctions(t *testing.T) { + nextId = uint64(1) + + t.Run("Initial nextId should return 1", func(t *testing.T) { + uassert.Equal(t, uint64(1), getNextId(), "expected nextId to start at 1") + }) + + t.Run("After mint nextId should return 2", func(t *testing.T) { + incrementNextId() + uassert.Equal(t, getNextId(), uint64(2), "expected nextId to be 2 after mint") + }) + + t.Run("Increment nextId once", func(t *testing.T) { + incrementNextId() + uassert.Equal(t, uint64(3), getNextId(), "expected nextId to increment to 3") + }) + + t.Run("Increment nextId multiple times", func(t *testing.T) { + for i := 0; i < 2; i++ { + incrementNextId() + } + uassert.Equal(t, uint64(5), getNextId(), "expected nextId to increment to 5 after 3 more increments") + }) + + t.Run("Ensure no overflow on normal increments", func(t *testing.T) { + for i := uint64(5); i < 100; i++ { + incrementNextId() + } + uassert.Equal(t, uint64(100), getNextId(), "expected nextId to reach 100 after continuous increments") + }) +} + +func TestIsValidTokenPath(t *testing.T) { + tests := []struct { + name string + tokenPath string + register bool + expected bool + }{ + { + name: "Valid Token Path", + tokenPath: gnsPath, + register: true, + expected: true, + }, + { + name: "Invalid Token Path", + tokenPath: "invalid/path", + register: false, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidTokenPath(tt.tokenPath) + uassert.Equal(t, tt.expected, result) + }) + } +} + +func TestValidateTokenPath(t *testing.T) { + tests := []struct { + name string + token0 string + token1 string + register0 bool + register1 bool + expected error + }{ + { + name: "Valid Token Path", + token0: gnsPath, + token1: barPath, + register0: true, + register1: true, + expected: nil, + }, + { + name: "Same Token Path", + token0: "tokenA", + token1: "tokenA", + register0: true, + register1: true, + expected: errInvalidTokenPath, + }, + { + name: "Conflicting Tokens (GNOT ↔ WUGNOT)", + token0: consts.GNOT, + token1: consts.WRAPPED_WUGNOT, + register0: true, + register1: true, + expected: errInvalidTokenPath, + }, + { + name: "Invalid Token Path", + token0: "tokenX", + token1: "tokenY", + register0: false, + register1: true, + expected: errInvalidTokenPath, + }, + { + name: "Both Invalid Tokens", + token0: "invalidA", + token1: "invalidB", + register0: false, + register1: false, + expected: errInvalidTokenPath, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.register0 { + MockRegister(t, tt.token0) + } else { + MockUnregister(t, tt.token0) + } + + if tt.register1 { + MockRegister(t, tt.token1) + } else { + MockUnregister(t, tt.token1) + } + + if tt.expected != nil { + err := validateTokenPath(tt.token0, tt.token1) + uassert.Equal(t, tt.expected.Error(), err.Error()) + } else { + uassert.NotPanics(t, func() { + err := validateTokenPath(tt.token0, tt.token1) + if err != nil { + t.Errorf("expected no error but got %s", err.Error()) + } + }) + } + }) + } +} + +func TestProcessTokens(t *testing.T) { + tests := []struct { + name string + token0 string + token1 string + amount0Desired string + amount1Desired string + caller std.Address + expected0 string + expected1 string + isNative0 bool + isNative1 bool + expectedWrap uint64 + expectPanic bool + expectMsg string + }{ + { + name: "Both tokens valid and not native", + token0: gnsPath, + token1: "tokenB", + amount0Desired: "100", + amount1Desired: "200", + caller: alice, + expected0: "tokenA", + expected1: "tokenB", + isNative0: false, + isNative1: false, + expectedWrap: 0, + expectPanic: true, + expectMsg: "[GNOSWAP-POSITION-016] invalid token address || token0(gno.land/r/gnoswap/v1/gns), token1(tokenB)", + }, + { + name: "token0 is native", + token0: consts.GNOT, + token1: gnsPath, + amount0Desired: "1300", + amount1Desired: "200", + caller: alice, + expected0: consts.WRAPPED_WUGNOT, + expected1: "gno.land/r/gnoswap/v1/gns", + isNative0: true, + isNative1: false, + expectedWrap: 1300, + expectPanic: false, + expectMsg: "[GNOSWAP-POSITION-016] invalid token address || token0(gnot), token1(gno.land/r/gnoswap/v1/gns)", + }, + { + name: "token1 is native", + token0: gnsPath, + token1: consts.GNOT, + amount0Desired: "150", + amount1Desired: "1250", + caller: testutils.TestAddress("user3"), + expected0: "gno.land/r/gnoswap/v1/gns", + expected1: consts.WRAPPED_WUGNOT, + isNative0: false, + isNative1: true, + expectedWrap: 1250, + expectPanic: false, + }, + { + name: "Both tokens are native", + token0: consts.GNOT, + token1: consts.GNOT, + amount0Desired: "1100", + amount1Desired: "1200", + caller: testutils.TestAddress("user4"), + expected0: consts.WRAPPED_WUGNOT, + expected1: consts.WRAPPED_WUGNOT, + isNative0: true, + isNative1: true, + expectedWrap: 2300, + expectPanic: true, + expectMsg: "[GNOSWAP-POSITION-016] invalid token address || token0(gnot), token1(gnot)", + }, + { + name: "Invalid token path", + token0: "invalidToken", + token1: "tokenB", + amount0Desired: "150", + amount1Desired: "200", + caller: testutils.TestAddress("user5"), + expectPanic: true, + expectMsg: "[GNOSWAP-POSITION-016] invalid token address || token0(invalidToken), token1(tokenB)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + MockRegister(t, tt.token0) + MockRegister(t, tt.token1) + + defer func() { + if r := recover(); r != nil { + if !tt.expectPanic { + t.Errorf("unexpected panic: %v", r) + } + } + }() + + if tt.token0 == consts.GNOT { + amount, _ := strconv.ParseUint(tt.amount0Desired, 10, 64) + ugnotFaucet(t, consts.POSITION_ADDR, amount) + std.TestSetRealm(std.NewUserRealm(consts.POSITION_ADDR)) + transferUGNOT(consts.POSITION_ADDR, consts.POSITION_ADDR, amount) + } + if tt.token1 == consts.GNOT { + amount, _ := strconv.ParseUint(tt.amount1Desired, 10, 64) + ugnotFaucet(t, consts.POSITION_ADDR, amount) + std.TestSetRealm(std.NewUserRealm(consts.POSITION_ADDR)) + transferUGNOT(consts.POSITION_ADDR, consts.POSITION_ADDR, amount) + } + + if !tt.expectPanic { + token0, token1, native0, native1, wrapped := processTokens( + tt.token0, + tt.token1, + tt.amount0Desired, + tt.amount1Desired, + tt.caller, + ) + + uassert.Equal(t, tt.expected0, token0) + uassert.Equal(t, tt.expected1, token1) + uassert.Equal(t, tt.isNative0, native0) + uassert.Equal(t, tt.isNative1, native1) + uassert.Equal(t, tt.expectedWrap, wrapped) + } else { + uassert.PanicsWithMessage(t, tt.expectMsg, func() { + processTokens(tt.token0, tt.token1, tt.amount0Desired, tt.amount1Desired, tt.caller) + }) + } + }) + } +} + +func TestProcessMintInput(t *testing.T) { + tests := []struct { + name string + input MintInput + expectedToken0 string + expectedToken1 string + expectedAmount0 string + expectedAmount1 string + expectedTickL int32 + expectedTickU int32 + expectError bool + }{ + { + name: "Standard Mint - Token0 < Token1", + input: MintInput{ + token0: gnsPath, + token1: barPath, + amount0Desired: "1000", + amount1Desired: "2000", + amount0Min: "800", + amount1Min: "1800", + tickLower: -10000, + tickUpper: 10000, + caller: alice, + }, + expectedToken0: gnsPath, + expectedToken1: barPath, + expectedAmount0: "1000", + expectedAmount1: "2000", + expectedTickL: -10000, + expectedTickU: 10000, + expectError: false, + }, + { + name: "Token Swap - Token1 < Token0", + input: MintInput{ + token0: barPath, + token1: gnsPath, + amount0Desired: "2000", + amount1Desired: "1000", + amount0Min: "1800", + amount1Min: "800", + tickLower: -20000, + tickUpper: 20000, + caller: alice, + }, + expectedToken0: gnsPath, + expectedToken1: barPath, + expectedAmount0: "1000", + expectedAmount1: "2000", + expectedTickL: -20000, + expectedTickU: 20000, + expectError: false, + }, + { + name: "Error Case - Invalid Amounts", + input: MintInput{ + token0: gnsPath, + token1: barPath, + amount0Desired: "invalid", + amount1Desired: "2000", + amount0Min: "800", + amount1Min: "1800", + tickLower: -5000, + tickUpper: 5000, + caller: alice, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.expectError { + t.Errorf("unexpected panic: %v", r) + } else { + uassert.Equal(t, "[GNOSWAP-POSITION-005] invalid input data || input string : invalid", r) + } + } + }() + processed, err := processMintInput(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got nil") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + uassert.Equal(t, tt.expectedToken0, processed.tokenPair.token0) + uassert.Equal(t, tt.expectedToken1, processed.tokenPair.token1) + uassert.Equal(t, tt.expectedAmount0, processed.amount0Desired.ToString()) + uassert.Equal(t, tt.expectedAmount1, processed.amount1Desired.ToString()) + uassert.Equal(t, tt.expectedTickL, processed.tickLower) + uassert.Equal(t, tt.expectedTickU, processed.tickUpper) + } + }) + } +} + +func TestMintInternal(t *testing.T) { + MakeMintPositionWithoutFee(t) + TokenFaucet(t, fooPath, alice) + TokenFaucet(t, barPath, alice) + TokenApprove(t, fooPath, alice, pool, maxApprove) + TokenApprove(t, barPath, alice, pool, maxApprove) + + tests := []struct { + name string + params MintParams + expectedTokenId uint64 + expectedLiquidity string + expectedAmount0 string + expectedAmount1 string + expectPanic bool + expectedError string + tokenId uint64 + }{ + { + name: "Successful Mint", + params: MintParams{ + token0: barPath, + token1: fooPath, + fee: fee500, + tickLower: -100, + tickUpper: 100, + amount0Desired: u256.MustFromDecimal("10000"), + amount1Desired: u256.MustFromDecimal("10000"), + amount0Min: u256.MustFromDecimal("10"), + amount1Min: u256.MustFromDecimal("10"), + caller: alice, + mintTo: alice, + }, + expectedTokenId: 101, + expectedLiquidity: "2005104", + expectedAmount0: "10000", + expectedAmount1: "10000", + expectPanic: false, + }, + { + name: "Position Exists", + params: MintParams{ + token0: barPath, + token1: fooPath, + fee: fee500, + tickLower: -100, + tickUpper: 100, + amount0Desired: u256.MustFromDecimal("10000"), + amount1Desired: u256.MustFromDecimal("10000"), + amount0Min: u256.MustFromDecimal("10"), + amount1Min: u256.MustFromDecimal("10"), + caller: alice, + mintTo: alice, + }, + expectPanic: true, + expectedError: "token id already exists", + tokenId: 1, + }, + { + name: "Zero Liquidity Mint", + params: MintParams{ + token0: barPath, + token1: fooPath, + fee: fee500, + tickLower: -200, + tickUpper: 200, + amount0Desired: u256.Zero(), + amount1Desired: u256.Zero(), + amount0Min: u256.NewUint(5), + amount1Min: u256.NewUint(5), + caller: alice, + mintTo: alice, + }, + expectPanic: true, + expectedError: "[GNOSWAP-POOL-010] zero liquidity", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.tokenId != 0 { + nextId = tt.tokenId + } + + if !tt.expectPanic { + std.TestSetRealm(std.NewUserRealm(tt.params.mintTo)) + tokenId, liquidity, amount0, amount1 := mint(tt.params) + uassert.Equal(t, tt.expectedTokenId, tokenId) + uassert.Equal(t, tt.expectedLiquidity, liquidity.ToString()) + uassert.Equal(t, tt.expectedAmount0, amount0.ToString()) + uassert.Equal(t, tt.expectedAmount1, amount1.ToString()) + } else { + uassert.PanicsWithMessage(t, tt.expectedError, func() { + mint(tt.params) + }) + } + }) + } +} + +func TestMint(t *testing.T) { + nextId = gnft.TotalSupply() + 1 + tests := []struct { + name string + token0 string + token1 string + fee uint32 + tickLower int32 + tickUpper int32 + amount0 string + amount1 string + minAmount0 string + minAmount1 string + deadline int64 + mintTo std.Address + caller std.Address + expectPanic bool + expectedError string + expectedAmount0 uint64 + expectedAmount1 uint64 + }{ + { + name: "Success - Valid mint request", + token0: barPath, + token1: fooPath, + fee: fee500, + tickLower: -500, + tickUpper: 500, + amount0: "1000000", + amount1: "2000000", + minAmount0: "950000", + minAmount1: "900000", + deadline: time.Now().Add(10 * time.Minute).Unix(), + mintTo: alice, + caller: alice, + expectPanic: false, + expectedAmount0: 1000000, + expectedAmount1: 1000000, + }, + { + name: "Fail - Deadline exceeded", + token0: barPath, + token1: fooPath, + fee: fee500, + tickLower: -500, + tickUpper: 500, + amount0: "1000000", + amount1: "2000000", + minAmount0: "950000", + minAmount1: "1900000", + deadline: time.Now().Add(-10 * time.Minute).Unix(), + mintTo: alice, + caller: alice, + expectPanic: true, + expectedError: "[GNOSWAP-POSITION-007] transaction expired || transaction too old, now(1234567890) > deadline(1234567290)", + expectedAmount0: 950000, + expectedAmount1: 1900000, + }, + { + name: "Fail - Invalid tick range", + token0: barPath, + token1: fooPath, + fee: fee500, + tickLower: 600, + tickUpper: 500, + amount0: "1000000", + amount1: "2000000", + minAmount0: "950000", + minAmount1: "1900000", + deadline: time.Now().Add(10 * time.Minute).Unix(), + mintTo: alice, + caller: alice, + expectPanic: true, + expectedError: "[GNOSWAP-POOL-024] tickLower is greater than or equal to tickUpper || tickLower(600), tickUpper(500)", + expectedAmount0: 950000, + expectedAmount1: 1900000, + }, + { + name: "Fail - Caller not authorized", + token0: barPath, + token1: fooPath, + fee: fee500, + tickLower: -500, + tickUpper: 500, + amount0: "1000000", + amount1: "2000000", + minAmount0: "950000", + minAmount1: "1900000", + deadline: time.Now().Add(10 * time.Minute).Unix(), + mintTo: admin, + caller: alice, + expectPanic: true, + expectedError: "[GNOSWAP-POSITION-012] invalid address || (g1v9kxjcm9ta047h6lta047h6lta047h6lzd40gh, g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d)", + expectedAmount0: 950000, + expectedAmount1: 1900000, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(tc.caller)) + if tc.expectPanic { + uassert.PanicsWithMessage(t, + tc.expectedError, + func() { + Mint( + tc.token0, + tc.token1, + tc.fee, + tc.tickLower, + tc.tickUpper, + tc.amount0, + tc.amount1, + tc.minAmount0, + tc.minAmount1, + tc.deadline, + tc.mintTo, + tc.caller, + ) + }) + } else { + tokenId, liquidity, amount0, amount1 := Mint( + tc.token0, + tc.token1, + tc.fee, + tc.tickLower, + tc.tickUpper, + tc.amount0, + tc.amount1, + tc.minAmount0, + tc.minAmount1, + tc.deadline, + tc.mintTo, + tc.caller, + ) + + uassert.Equal(t, uint64(5), tokenId) + uassert.NotEmpty(t, liquidity) + t0, _ := strconv.ParseUint(amount0, 10, 64) + t1, _ := strconv.ParseUint(amount1, 10, 64) + uassert.Equal(t, tc.expectedAmount0, t0) + uassert.Equal(t, tc.expectedAmount1, t1) + } + }) + } +} + +func TestIncreaseLiquidityInternal(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(admin)) + position := Position{ + poolKey: "gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500", + tickLower: -10000, + tickUpper: 10000, + liquidity: u256.NewUint(1000000), + feeGrowthInside0LastX128: u256.Zero(), + feeGrowthInside1LastX128: u256.Zero(), + tokensOwed0: u256.Zero(), + tokensOwed1: u256.Zero(), + burned: false, + } + id := getNextId() + setPosition(id, position) + incrementNextId() + + MakeMintPositionWithoutFee(t) + + tests := []struct { + name string + params IncreaseLiquidityParams + expectedLiquidity string + expectedAmount0 string + expectedAmount1 string + expectPanic bool + expectedErrorMsg string + }{ + { + name: "Fail - Position Does Not Exist", + params: IncreaseLiquidityParams{ + tokenId: 999, + amount0Desired: u256.NewUint(1000000), + amount1Desired: u256.NewUint(2000000), + amount0Min: u256.NewUint(900000), + amount1Min: u256.NewUint(1800000), + deadline: time.Now().Add(10 * time.Minute).Unix(), + }, + expectPanic: true, + expectedErrorMsg: "[GNOSWAP-POSITION-006] requested data not found || tokenId(999) doesn't exist", + }, + { + name: "Fail - Zero Liquidity", + params: IncreaseLiquidityParams{ + tokenId: 7, + amount0Desired: u256.Zero(), + amount1Desired: u256.Zero(), + amount0Min: u256.NewUint(900000), + amount1Min: u256.NewUint(1800000), + deadline: time.Now().Add(10 * time.Minute).Unix(), + }, + expectPanic: true, + expectedErrorMsg: "[GNOSWAP-POOL-010] zero liquidity", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expectPanic { + uassert.PanicsWithMessage(t, + tc.expectedErrorMsg, + func() { + increaseLiquidity(tc.params) + }) + } else { + tokenId, liquidity, amount0, amount1, _ := increaseLiquidity(tc.params) + uassert.Equal(t, tc.params.tokenId, tokenId) + uassert.Equal(t, tc.expectedLiquidity, liquidity.ToString()) + uassert.Equal(t, tc.expectedAmount0, amount0.ToString()) + uassert.Equal(t, tc.expectedAmount1, amount1.ToString()) + } + }) + } +} + +func TestIncreaseLiquidity(t *testing.T) { + std.TestSetRealm(std.NewUserRealm(admin)) + tests := []struct { + name string + tokenId uint64 + amount0Desired string + amount1Desired string + amount0Min string + amount1Min string + deadline int64 + expectedAmount0 string + expectedAmount1 string + expectPanic bool + expectedErrorMsg string + }{ + { + name: "Success - Valid Increase", + tokenId: 7, + amount0Desired: "1000000", + amount1Desired: "2000000", + amount0Min: "950000", + amount1Min: "900000", + deadline: time.Now().Add(10 * time.Minute).Unix(), + expectedAmount0: "1000000", + expectedAmount1: "1000000", + expectPanic: false, + }, + { + name: "Fail - Deadline Exceeded", + tokenId: 7, + amount0Desired: "1000000", + amount1Desired: "2000000", + amount0Min: "950000", + amount1Min: "900000", + deadline: time.Now().Add(-10 * time.Minute).Unix(), + expectPanic: true, + expectedErrorMsg: "[GNOSWAP-POSITION-007] transaction expired || transaction too old, now(1234567890) > deadline(1234567290)", + }, + { + name: "Fail - Invalid Amount String", + tokenId: 7, + amount0Desired: "invalid_amount", + amount1Desired: "2000000", + amount0Min: "950000", + amount1Min: "900000", + deadline: time.Now().Add(10 * time.Minute).Unix(), + expectPanic: true, + expectedErrorMsg: "[GNOSWAP-POSITION-005] invalid input data || input string : invalid_amount", + }, + { + name: "Fail - Position Does Not Exist", + tokenId: 999, + amount0Desired: "1000000", + amount1Desired: "2000000", + amount0Min: "950000", + amount1Min: "900000", + deadline: time.Now().Add(10 * time.Minute).Unix(), + expectPanic: true, + expectedErrorMsg: "[GNOSWAP-POSITION-013] position does not exist || position with tokenId(999) doesn't exist", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expectPanic { + uassert.PanicsWithMessage(t, + tc.expectedErrorMsg, + func() { + IncreaseLiquidity( + tc.tokenId, + tc.amount0Desired, + tc.amount1Desired, + tc.amount0Min, + tc.amount1Min, + tc.deadline, + ) + }) + } else { + tokenId, _, amount0, amount1, _ := IncreaseLiquidity( + tc.tokenId, + tc.amount0Desired, + tc.amount1Desired, + tc.amount0Min, + tc.amount1Min, + tc.deadline, + ) + uassert.Equal(t, tc.tokenId, tokenId) + uassert.Equal(t, tc.expectedAmount0, amount0) + uassert.Equal(t, tc.expectedAmount1, amount1) + } + }) + } +} + +func TestDecreaseLiquidity(t *testing.T) { + tests := []struct { + name string + beforeIncrease bool + increaseAmount0 string + increaseAmount1 string + liquidityToRemove string + amount0Min string + amount1Min string + deadlineOffset time.Duration + unwrapResult bool + expectPanic bool + expectedErrorMsg string + expectedLiquidity string + expectedFee0 string + expectedFee1 string + }{ + { + name: "Success - Decrease 50% Liquidity", + liquidityToRemove: "400000", + amount0Min: "8000", + amount1Min: "15000", + deadlineOffset: time.Hour, + unwrapResult: false, + expectPanic: false, + expectedLiquidity: "400000", + expectedFee0: "0", + expectedFee1: "0", + }, + { + name: "Success - Decrease 100% Liquidity", + liquidityToRemove: "600000", + amount0Min: "10000", + amount1Min: "20000", + deadlineOffset: time.Hour, + unwrapResult: false, + expectPanic: false, + expectedLiquidity: "600000", + expectedFee0: "0", + expectedFee1: "0", + expectedErrorMsg: "", + }, + { + name: "Fail - Zero Liquidity", + liquidityToRemove: "0", + amount0Min: "10000", + amount1Min: "20000", + deadlineOffset: time.Hour, + unwrapResult: false, + expectPanic: true, + expectedErrorMsg: "[GNOSWAP-POSITION-010] zero liquidity || liquidity amount must be greater than 0, got 0", + }, + { + name: "Fail - Underflow", + liquidityToRemove: "1541592", + amount0Min: "200000", + amount1Min: "400000", + deadlineOffset: time.Hour, + unwrapResult: false, + expectPanic: true, + expectedErrorMsg: "[GNOSWAP-POOL-010] zero liquidity || both liquidityDelta and current position's liquidity are zero", + }, + { + name: "Fail - Deadline Exceeded", + liquidityToRemove: "50", + amount0Min: "10000", + amount1Min: "20000", + deadlineOffset: -time.Hour, + unwrapResult: false, + expectPanic: true, + expectedErrorMsg: "[GNOSWAP-POSITION-007] transaction expired || transaction too old, now(1234567890) > deadline(1234564290)", + }, + { + name: "Success - Increase and Decrease", + beforeIncrease: true, + increaseAmount0: "500000", + increaseAmount1: "1000000", + liquidityToRemove: "1270796", + amount0Min: "12000", + amount1Min: "25000", + deadlineOffset: time.Hour, + unwrapResult: false, + expectPanic: false, + expectedLiquidity: "1270796", + expectedFee0: "0", + expectedFee1: "0", + }, + } + + CreatePoolWithoutFee(t) + std.TestSetRealm(std.NewUserRealm(admin)) + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, fooPath, admin, pool, consts.UINT64_MAX) + tokenId, _, _, _ := Mint( + barPath, + fooPath, + fee500, + -10000, + 10000, + "1000000", + "1000000", + "0", + "0", + time.Now().Add(time.Hour).Unix(), + admin, + admin, + ) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tc.expectPanic { + t.Errorf("unexpected panic: %v", r) + } else { + uassert.Equal(t, tc.expectedErrorMsg, r) + } + } + }() + + deadline := time.Now().Add(tc.deadlineOffset).Unix() + if tc.beforeIncrease { + _, _, _, _, _ = IncreaseLiquidity( + tokenId, + tc.increaseAmount0, + tc.increaseAmount1, + "0", + "0", + deadline, + ) + } + + if tc.expectPanic { + DecreaseLiquidity( + tokenId, + tc.liquidityToRemove, + tc.amount0Min, + tc.amount1Min, + deadline, + tc.unwrapResult, + ) + } else { + _, liquidity, fee0, fee1, _, _, _ := DecreaseLiquidity( + tokenId, + tc.liquidityToRemove, + tc.amount0Min, + tc.amount1Min, + deadline, + tc.unwrapResult, + ) + uassert.Equal(t, tc.expectedLiquidity, liquidity) + uassert.Equal(t, tc.expectedFee0, fee0) + uassert.Equal(t, tc.expectedFee1, fee1) + } + }) + } +} + +func TestCollectFees(t *testing.T) { + tests := []struct { + name string + mintAmount0 string + mintAmount1 string + increaseAmount0 string + increaseAmount1 string + unwrapResult bool + liquidityToRemove string + expectedFee0 string + expectedFee1 string + expectedAmount0 string + expectedAmount1 string + expectPanic bool + expectedPanicError string + }{ + { + name: "Success - Collect fee after multiple mints", + mintAmount0: "1000000", + mintAmount1: "2000000", + increaseAmount0: "500000", + increaseAmount1: "1000000", + liquidityToRemove: "50", + unwrapResult: false, + expectedFee0: "0", + expectedFee1: "0", + expectedAmount0: "0", + expectedAmount1: "0", + expectPanic: false, + }, + { + name: "Success - Collect fee after full liquidity removal", + mintAmount0: "1000000", + mintAmount1: "2000000", + liquidityToRemove: "100", + unwrapResult: true, + expectedFee0: "0", + expectedFee1: "0", + expectedAmount0: "0", + expectedAmount1: "0", + expectPanic: false, + }, + { + name: "Fail - No liquidity to collect", + mintAmount0: "100000", + mintAmount1: "200000", + liquidityToRemove: "0", + unwrapResult: false, + expectPanic: true, + expectedPanicError: "[GNOSWAP-POSITION-010] zero liquidity || liquidity amount must be greater than 0, got 0", + }, + } + + CreatePoolWithoutFee(t) + std.TestSetRealm(std.NewUserRealm(admin)) + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, fooPath, admin, pool, consts.UINT64_MAX) + + // Mint liquidity + tokenId, _, _, _ := Mint( + barPath, + fooPath, + fee500, + -10000, + 10000, + "1000000", + "1000000", + "0", + "0", + time.Now().Add(time.Hour).Unix(), + admin, + admin, + ) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.expectPanic { + t.Errorf("unexpected panic: %v", r) + } else { + uassert.Equal(t, tt.expectedPanicError, r) + } + } + }() + + // Increase liquidity if specified + if tt.increaseAmount0 != "" && tt.increaseAmount1 != "" { + _, _, _, _, _ = IncreaseLiquidity( + tokenId, + tt.increaseAmount0, + tt.increaseAmount1, + "0", + "0", + time.Now().Add(time.Hour).Unix(), + ) + } + + // Decrease liquidity + _, _, _, _, _, _, _ = DecreaseLiquidity( + tokenId, + tt.liquidityToRemove, + "0", + "0", + time.Now().Add(time.Hour).Unix(), + false, + ) + + if tt.expectPanic { + uassert.PanicsWithMessage(t, tt.expectedPanicError, func() { + CollectFee(tokenId, tt.unwrapResult) + }) + } else { + _, fee0, fee1, _, amount0, amount1 := CollectFee(tokenId, tt.unwrapResult) + uassert.Equal(t, tt.expectedFee0, fee0) + uassert.Equal(t, tt.expectedFee1, fee1) + uassert.Equal(t, tt.expectedAmount0, amount0) + uassert.Equal(t, tt.expectedAmount1, amount1) + } + }) + } +} + +func TestReposition(t *testing.T) { + tests := []struct { + name string + mintAmount0 string + mintAmount1 string + increaseAmount0 string + increaseAmount1 string + liquidityToRemove string + tickLower int32 + tickUpper int32 + expectPanic bool + expectedPanicError string + }{ + { + name: "Fail - Reposition after full liquidity removal", + mintAmount0: "1000000", + mintAmount1: "2000000", + increaseAmount0: "500000", + increaseAmount1: "1000000", + liquidityToRemove: "100", + tickLower: -5000, + tickUpper: 5000, + expectPanic: true, + expectedPanicError: "[GNOSWAP-POSITION-009] position is not clear || position(10) isn't clear(liquidity:3812288, tokensOwed0:0, tokensOwed1:0)", + }, + { + name: "Fail - Reposition with partial liquidity removal", + mintAmount0: "1000000", + mintAmount1: "2000000", + increaseAmount0: "500000", + increaseAmount1: "1000000", + liquidityToRemove: "50", + tickLower: -3000, + tickUpper: 3000, + expectPanic: true, + expectedPanicError: "[GNOSWAP-POSITION-009] position is not clear || position(10) isn't clear(liquidity:5083034, tokensOwed0:0, tokensOwed1:0)", + }, + { + name: "Fail - Reposition on non-existent tokenId", + mintAmount0: "1000000", + mintAmount1: "2000000", + liquidityToRemove: "10", + tickLower: -4000, + tickUpper: 4000, + expectPanic: true, + expectedPanicError: "[GNOSWAP-POSITION-009] position is not clear || position(10) isn't clear(liquidity:5083024, tokensOwed0:0, tokensOwed1:0)", + }, + } + + CreatePoolWithoutFee(t) + std.TestSetRealm(std.NewUserRealm(admin)) + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, fooPath, admin, pool, consts.UINT64_MAX) + + // Step 1: Mint a position + tokenId, _, _, _ := Mint( + barPath, + fooPath, + fee500, + -10000, + 10000, + "1000000", + "1000000", + "0", + "0", + time.Now().Add(time.Hour).Unix(), + admin, + admin, + ) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Step 2: Increase liquidity + if tt.increaseAmount0 != "" && tt.increaseAmount1 != "" { + _, _, _, _, _ = IncreaseLiquidity( + tokenId, + tt.increaseAmount0, + tt.increaseAmount1, + "0", + "0", + time.Now().Add(time.Hour).Unix(), + ) + } + + // Step 3: Decrease liquidity to clear the position + _, _, _, _, _, _, _ = DecreaseLiquidity( + tokenId, + tt.liquidityToRemove, + "0", + "0", + time.Now().Add(time.Hour).Unix(), + false, + ) + + // Step 4: Attempt Reposition + if tt.expectPanic { + uassert.PanicsWithMessage(t, tt.expectedPanicError, func() { + Reposition( + tokenId, + tt.tickLower, + tt.tickUpper, + tt.mintAmount0, + tt.mintAmount1, + "0", + "0", + ) + }) + } else { + //tokenId, liquidity.ToString(), tickLower, tickUpper, amount0.ToString(), amount1.ToString() + tid, _, tickL, tickH, _, _ := Reposition( + tokenId, + tt.tickLower, + tt.tickUpper, + tt.mintAmount0, + tt.mintAmount1, + "0", + "0", + ) + uassert.Equal(t, tid, tokenId) + uassert.Equal(t, tickL, tt.tickLower) + uassert.Equal(t, tickH, tt.tickUpper) + } + }) + } +} diff --git a/contract/r/gnoswap/position/type.gno b/contract/r/gnoswap/position/type.gno new file mode 100644 index 000000000..73fadbf93 --- /dev/null +++ b/contract/r/gnoswap/position/type.gno @@ -0,0 +1,123 @@ +package position + +import ( + "std" + + u256 "gno.land/p/gnoswap/uint256" +) + +// Position represents a liquidity position in a pool. +// Each position tracks the amount of liquidity, fee growth, and tokens owed to the position owner. +type Position struct { + nonce *u256.Uint // nonce for permits + + operator std.Address // address that is approved for spending this token + + poolKey string // poolPath of the pool which this has lp token + + // the tick range of the position, bounds are included + tickLower int32 + tickUpper int32 + + liquidity *u256.Uint // liquidity of the position + + // fee growth of the aggregate position as of the last action on the individual position + feeGrowthInside0LastX128 *u256.Uint + feeGrowthInside1LastX128 *u256.Uint + + // how many uncollected tokens are owed to the position, as of the last computation + tokensOwed0 *u256.Uint + tokensOwed1 *u256.Uint + + burned bool // whether the position has been burned(≈ actuall we don't burn lp token even its empty, just update flag) +} + +// isClear reports whether the position is empty +func (p Position) isClear() bool { + return p.liquidity.IsZero() && p.tokensOwed0.IsZero() && p.tokensOwed1.IsZero() +} + +type MintParams struct { + token0 string // token0 path for a specific pool + token1 string // token1 path for a specific pool + fee uint32 // fee for a specific pool + tickLower int32 // lower end of the tick range for the position + tickUpper int32 // upper end of the tick range for the position + amount0Desired *u256.Uint // desired amount of token0 to be minted + amount1Desired *u256.Uint // desired amount of token1 to be minted + amount0Min *u256.Uint // minimum amount of token0 to be minted + amount1Min *u256.Uint // minimum amount of token1 to be minted + deadline int64 // time by which the transaction must be included to effect the change + mintTo std.Address // address to mint lpToken + caller std.Address // address to call the function +} + +// newMintParams creates `MintParams` from processed input data. +func newMintParams(input ProcessedMintInput, mintInput MintInput) MintParams { + return MintParams{ + token0: input.tokenPair.token0, + token1: input.tokenPair.token1, + fee: mintInput.fee, + tickLower: input.tickLower, + tickUpper: input.tickUpper, + amount0Desired: input.amount0Desired, + amount1Desired: input.amount1Desired, + amount0Min: input.amount0Min, + amount1Min: input.amount1Min, + deadline: mintInput.deadline, + mintTo: mintInput.mintTo, + caller: mintInput.caller, + } +} + +type IncreaseLiquidityParams struct { + tokenId uint64 // tokenId of the position to increase liquidity + amount0Desired *u256.Uint // desired amount of token0 to be minted + amount1Desired *u256.Uint // desired amount of token1 to be minted + amount0Min *u256.Uint // minimum amount of token0 to be minted + amount1Min *u256.Uint // minimum amount of token1 to be minted + deadline int64 // time by which the transaction must be included to effect the change +} + +type DecreaseLiquidityParams struct { + tokenId uint64 // tokenId of the position to decrease liquidity + liquidity string // amount of liquidity to decrease + amount0Min *u256.Uint // minimum amount of token0 to be minted + amount1Min *u256.Uint // minimum amount of token1 to be minted + deadline int64 // time by which the transaction must be included to effect the change + unwrapResult bool // whether to unwrap the token if it's wrapped native token +} + +type MintInput struct { + token0 string + token1 string + fee uint32 + tickLower int32 + tickUpper int32 + amount0Desired string + amount1Desired string + amount0Min string + amount1Min string + deadline int64 + mintTo std.Address + caller std.Address +} + +type TokenPair struct { + token0 string + token1 string + token0IsNative bool + token1IsNative bool + wrappedAmount uint64 +} + +type ProcessedMintInput struct { + tokenPair TokenPair + amount0Desired *u256.Uint + amount1Desired *u256.Uint + amount0Min *u256.Uint + amount1Min *u256.Uint + tickLower int32 + tickUpper int32 + poolPath string +} diff --git a/contract/r/gnoswap/position/utils.gno b/contract/r/gnoswap/position/utils.gno new file mode 100644 index 000000000..5457bf1cc --- /dev/null +++ b/contract/r/gnoswap/position/utils.gno @@ -0,0 +1,386 @@ +package position + +import ( + "std" + "strconv" + "time" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/gnft" + + u256 "gno.land/p/gnoswap/uint256" +) + +// GetOrigPkgAddr returns the original package address. +// In position contract, original package address is the position address. +func GetOrigPkgAddr() std.Address { + return consts.POSITION_ADDR +} + +// assertTokenExists checks if a token with the specified tokenId exists in the system. +// +// This function verifies the existence of a token by its tokenId. If the token does not exist, +// it triggers a panic with a descriptive error message. +// +// Parameters: +// - tokenId (uint64): The unique identifier of the token to check. +// +// Panics: +// - If the token does not exist, it panics with the error `errDataNotFound`. +func assertTokenExists(tokenId uint64) { + if !exists(tokenId) { + panic(newErrorWithDetail( + errDataNotFound, + ufmt.Sprintf("tokenId(%d) doesn't exist", tokenId), + )) + } +} + +// assertOnlyOwnerOfToken ensures that only the owner of the specified token can perform operations on it. +// +// This function checks if the caller is the owner of the token identified by `tokenId`. +// If the caller is not the owner, the function triggers a panic to prevent unauthorized access. +// +// Parameters: +// - tokenId (uint64): The unique identifier of the token to check ownership. +// - caller (std.Address): The address of the entity attempting to modify the token. +// +// Panics: +// - If the caller is not the owner of the token, the function panics with an `errNoPermission` error. +func assertOnlyOwnerOfToken(tokenId uint64, caller std.Address) { + owner, err := gnft.OwnerOf(tokenIdFrom(tokenId)) + if err != nil { + panic(newErrorWithDetail( + errDataNotFound, + ufmt.Sprintf("tokenId(%d) doesn't exist", tokenId), + )) + } + assertCallerIsOwner(tokenId, owner, caller) +} + +func assertCallerIsOwner(tokenId uint64, owner, caller std.Address) { + if owner != caller { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf("caller(%s) is not owner(%s) for tokenId(%d)", caller, owner, tokenId), + )) + } +} + +// assertOnlyUserOrStaker panics if the caller is not a user or staker. +func assertOnlyUserOrStaker(caller std.Realm) { + if !caller.IsUser() { + if err := common.StakerOnly(caller.Addr()); err != nil { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf("from (%s)", caller.Addr()), + )) + } + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertOnlyValidAddress panics if the address is invalid. +func assertOnlyValidAddress(addr std.Address) { + if !addr.IsValid() { + panic(newErrorWithDetail( + errInvalidAddress, + ufmt.Sprintf("(%s)", addr), + )) + } +} + +// assertOnlyValidAddress panics if the address is invalid or previous address is not +// different from the other address. +func assertOnlyValidAddressWith(prevAddr, otherAddr std.Address) { + assertOnlyValidAddress(prevAddr) + assertOnlyValidAddress(otherAddr) + + if prevAddr != otherAddr { + panic(newErrorWithDetail( + errInvalidAddress, + ufmt.Sprintf("(%s, %s)", prevAddr, otherAddr), + )) + } +} + +// assertValidNumberString verifies that the input string represents a valid integer. +// +// This function checks each character in the string to ensure it falls within the numeric ASCII range ('0' to '9'). +// The first character is allowed to be a negative sign ('-'), indicating a negative number. If the input does not +// meet these criteria, the function panics with a detailed error message. +// +// Parameters: +// - input (string): The string to validate. +// +// Panics: +// - If the input string is empty. +// - If the input contains non-numeric characters (excluding an optional leading '-'). +// - If the input is not a valid integer representation. +// +// Example: +// +// assertValidNumberString("12345") -> Pass (valid positive number) +// assertValidNumberString("-98765") -> Pass (valid negative number) +// assertValidNumberString("12a45") -> Panic (invalid character 'a') +// assertValidNumberString("") -> Panic (empty input) +// assertValidNumberString("++123") -> Panic (invalid leading '+') +func assertValidNumberString(input string) { + if len(input) == 0 { + panic(newErrorWithDetail( + errInvalidInput, + ufmt.Sprintf("input is empty"))) + } + + bytes := []byte(input) + for i, b := range bytes { + if i == 0 && b == '-' { + continue // Allow if the first character is a negative sign (-) + } + if b < '0' || b > '9' { + panic(newErrorWithDetail( + errInvalidInput, + ufmt.Sprintf("input string : %s", input))) + } + } +} + +func assertValidLiquidityAmount(liquidity string) { + if u256.MustFromDecimal(liquidity).IsZero() { + panic(newErrorWithDetail( + errZeroLiquidity, + ufmt.Sprintf("liquidity amount must be greater than 0, got %s", liquidity), + )) + } +} + +func assertValidLiquidityRatio(ratio uint64) { + if !(ratio >= 1 && ratio <= 100) { + panic(newErrorWithDetail( + errInvalidLiquidityRatio, + ufmt.Sprintf("liquidity ratio must in range 1 ~ 100(contain), got %d", ratio), + )) + } +} + +// derivePkgAddr derives the Realm address from it's pkgpath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) +} + +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrev returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { + prev := getPrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// checkDeadline checks if the deadline is expired. +// If the deadline is expired, it panics. +// The deadline is expired if the current time is greater than the deadline. +// Input: +// - deadline: the deadline to check +func checkDeadline(deadline int64) { + now := time.Now().Unix() + if now > deadline { + panic(newErrorWithDetail( + errExpired, + ufmt.Sprintf("transaction too old, now(%d) > deadline(%d)", now, deadline), + )) + } +} + +// tokenIdFrom converts tokenId to grc721.TokenID type +// NOTE: input parameter tokenId can be string, int, uint64, or grc721.TokenID +// if tokenId is nil or not supported, it will panic +// if tokenId is not found, it will panic +// input: tokenId interface{} +// output: grc721.TokenID +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic(newErrorWithDetail(errInvalidInput, "tokenId is nil")) + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic(newErrorWithDetail(errInvalidInput, "unsupported tokenId type")) + } +} + +// exists checks whether tokenId exists +// If tokenId doesn't exist, return false, otherwise return true +// input: tokenId uint64 +// output: bool +func exists(tokenId uint64) bool { + return gnft.Exists(tokenIdFrom(tokenId)) +} + +// isOwner checks whether the caller is the owner of the tokenId +// If the caller is the owner of the tokenId, return true, otherwise return false +// input: tokenId uint64, addr std.Address +// output: bool +func isOwner(tokenId uint64, addr std.Address) bool { + owner, err := gnft.OwnerOf(tokenIdFrom(tokenId)) + if err == nil { + if owner == addr { + return true + } + } + return false +} + +// isOperator checks whether the caller is the approved operator of the tokenId +// If the caller is the approved operator of the tokenId, return true, otherwise return false +// input: tokenId uint64, addr std.Address +// output: bool +func isOperator(tokenId uint64, addr std.Address) bool { + operator, err := gnft.GetApproved(tokenIdFrom(tokenId)) + if err == nil && operator == addr { + return true + } + return false +} + +// isStaked checks whether tokenId is staked +// If tokenId is staked, owner of tokenId is staker contract +// If tokenId is staked, return true, otherwise return false +// input: tokenId grc721.TokenID +// output: bool +func isStaked(tokenId grc721.TokenID) bool { + exist := gnft.Exists(tokenId) + if exist { + owner, err := gnft.OwnerOf(tokenId) + if err == nil && owner == consts.STAKER_ADDR { + return true + } + } + return false +} + +// isOwnerOrOperator checks whether the caller is the owner or approved operator of the tokenId +// If the caller is the owner or approved operator of the tokenId, return true, otherwise return false +// input: addr std.Address, tokenId uint64 +// output: bool +func isOwnerOrOperator(addr std.Address, tokenId uint64) bool { + assertOnlyValidAddress(addr) + if !exists(tokenId) { + return false + } + if isOwner(tokenId, addr) || isOperator(tokenId, addr) { + return true + } + if isStaked(tokenIdFrom(tokenId)) { + position, exist := GetPosition(tokenId) + if exist && addr == position.operator { + return true + } + } + return false +} + +func isAuthorizedForToken(tokenId uint64) { + caller := getPrevAddr() + if !(isOwnerOrOperator(caller, tokenId)) { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf("caller(%s) is not approved or owner of tokenId(%d)", caller, tokenId), + )) + } +} + +// splitOf divides poolKey into pToken0, pToken1, and pFee +// If poolKey is invalid, it will panic +// +// input: poolKey string +// output: +// - token0Path string +// - token1Path string +// - fee uint32 +func splitOf(poolKey string) (string, string, uint32) { + res, err := common.Split(poolKey, ":", 3) + if err != nil { + panic(newErrorWithDetail(errInvalidInput, ufmt.Sprintf("invalid poolKey(%s)", poolKey))) + } + pToken0, pToken1, pFeeStr := res[0], res[1], res[2] + + pFee, err := strconv.Atoi(pFeeStr) + if err != nil { + panic(newErrorWithDetail(errInvalidInput, ufmt.Sprintf("invalid fee(%s)", pFeeStr))) + } + return pToken0, pToken1, uint32(pFee) +} + +func verifyTokenIdAndOwnership(tokenId uint64) { + assertTokenExists(tokenId) + assertOnlyOwnerOfToken(tokenId, getPrevAddr()) +} + +func verifySlippageAmounts(amount0, amount1, amount0Min, amount1Min *u256.Uint) { + if !(amount0.Gte(amount0Min) && amount1.Gte(amount1Min)) { + panic(newErrorWithDetail( + errSlippage, + ufmt.Sprintf("amount0(%s) >= amount0Min(%s) && amount1(%s) >= amount1Min(%s)", + amount0.ToString(), amount0Min.ToString(), amount1.ToString(), amount1Min.ToString()), + )) + } +} + +func formatUint(v interface{}) string { + switch v := v.(type) { + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatInt(v interface{}) string { + switch v := v.(type) { + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case int: + return strconv.Itoa(v) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatBool(v bool) string { + return strconv.FormatBool(v) +} diff --git a/contract/r/gnoswap/position/utils_test.gno b/contract/r/gnoswap/position/utils_test.gno new file mode 100644 index 000000000..10d2df684 --- /dev/null +++ b/contract/r/gnoswap/position/utils_test.gno @@ -0,0 +1,949 @@ +package position + +import ( + "fmt" + "std" + "strings" + "testing" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +func assertPanic(t *testing.T, expectedMsg string, fn func()) { + t.Helper() + defer func() { + r := recover() + if r == nil { + t.Errorf("expected panic but got none") + } else if r != expectedMsg { + t.Errorf("expected panic %v, got %v", expectedMsg, r) + } + }() + fn() +} + +func TestGetOrigPkgAddr(t *testing.T) { + tests := []struct { + name string + expected std.Address + }{ + { + name: "Success - getOrigPkgAddr", + expected: consts.POSITION_ADDR, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := GetOrigPkgAddr() + uassert.Equal(t, got, tc.expected) + }) + } +} + +func TestAssertTokenExists(t *testing.T) { + // Mock setup + mockTokenId := uint64(1001) + mockInvalidTokenId := uint64(9999) + + oldExists := exists + // Mock the exists function + mockExists := func(tokenId uint64) bool { + if tokenId == mockTokenId { + return true + } + return false + } + + // Replace exists with mockExists for testing + exists = mockExists + + t.Run("Token Exists - No Panic", func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("unexpected panic for existing tokenId(%d): %v", mockTokenId, r) + } + }() + assertTokenExists(mockTokenId) // Should not panic + }) + + t.Run("Token Does Not Exist - Panic Expected", func(t *testing.T) { + expectedErr := ufmt.Sprintf("[GNOSWAP-POSITION-006] requested data not found || tokenId(%d) doesn't exist", mockInvalidTokenId) + uassert.PanicsWithMessage(t, expectedErr, func() { + assertTokenExists(mockInvalidTokenId) // Should panic + }) + }) + + t.Run("Boundary Case - TokenId Zero", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic for tokenId(0) but did not occur") + } + }() + assertTokenExists(0) // Assume tokenId 0 does not exist + }) + + exists = oldExists +} + +func TestAssertOnlyOwnerOfToken(t *testing.T) { + // Mock token ownership + mockTokenId := uint64(1) + mockOwner := admin + mockCaller := alice + //MakeMintPositionWithoutFee(t) + + t.Run("Token Owned by Caller - No Panic", func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("unexpected panic for tokenId(%d) called by owner(%s): %v", mockTokenId, mockOwner, r) + } + }() + assertOnlyOwnerOfToken(mockTokenId, mockOwner) // Should pass without panic + }) + + t.Run("Token Not Owned by Caller - Panic Expected", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic for tokenId(%d) called by unauthorized user(%s), but no panic occurred", mockTokenId, mockCaller) + } else { + expectedErr := ufmt.Sprintf("caller(%s) is not owner(%s) for tokenId(%d)", mockCaller, mockOwner, mockTokenId) + if !strings.Contains(fmt.Sprintf("%v", r), expectedErr) { + t.Errorf("unexpected error message: %v", r) + } + } + }() + assertOnlyOwnerOfToken(mockTokenId, mockCaller) // Should panic + }) + + t.Run("Non-Existent Token - Panic Expected", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic for non-existent tokenId(9999), but no panic occurred") + } + }() + assertOnlyOwnerOfToken(9999, mockCaller) // Token 9999 does not exist, should panic + }) +} + +func TestAssertOnlyUserOrStaker(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected bool + }{ + { + name: "Failure - Not User or Staker", + originCaller: consts.ROUTER_ADDR, + expected: false, + }, + { + name: "Success - User Call", + originCaller: consts.ADMIN, + expected: true, + }, + { + name: "Success - Staker Call", + originCaller: consts.STAKER_ADDR, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(tc.originCaller) + assertOnlyUserOrStaker(std.PrevRealm()) + }) + } +} + +func TestAssertOnlyNotHalted(t *testing.T) { + tests := []struct { + name string + expected bool + panicMsg string + }{ + { + name: "Failure - Halted", + expected: false, + panicMsg: "[GNOSWAP-COMMON-002] halted || GnoSwap is halted", + }, + { + name: "Success - Not Halted", + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyNotHalted() + }) + } else { + std.TestSetRealm(std.NewUserRealm(admin)) + common.SetHaltByAdmin(true) + uassert.PanicsWithMessage(t, tc.panicMsg, func() { + assertOnlyNotHalted() + }) + common.SetHaltByAdmin(false) + } + }) + } +} + +func TestAssertOnlyValidAddress(t *testing.T) { + tests := []struct { + name string + addr std.Address + expected bool + errorMsg string + }{ + { + name: "Success - valid address", + addr: consts.ADMIN, + expected: true, + }, + { + name: "Failure - invalid address", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", // invalid length + expected: false, + errorMsg: "[GNOSWAP-POSITION-012] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddress(tc.addr) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddress(tc.addr) + }) + } + }) + } +} + +func TestAssertOnlyValidAddressWith(t *testing.T) { + tests := []struct { + name string + addr std.Address + other std.Address + expected bool + errorMsg string + }{ + { + name: "Success - validation address check to compare with other address", + addr: consts.ADMIN, + other: std.Address("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d"), + expected: true, + }, + { + name: "Failure - two address is different", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", + other: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + expected: false, + errorMsg: "[GNOSWAP-POSITION-012] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddressWith(tc.addr, tc.other) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddressWith(tc.addr, tc.other) + }) + } + }) + } +} + +func TestAssertValidNumberString(t *testing.T) { + tests := []struct { + name string + input string + expectPanic bool + }{ + // Valid Cases + { + name: "Valid Positive Number", + input: "12345", + expectPanic: false, + }, + { + name: "Valid Negative Number", + input: "-98765", + expectPanic: false, + }, + { + name: "Zero", + input: "0", + expectPanic: false, + }, + { + name: "Negative Zero", + input: "-0", + expectPanic: false, + }, + + // Invalid Cases + { + name: "Empty String", + input: "", + expectPanic: true, + }, + { + name: "Alphabet in String", + input: "12a45", + expectPanic: true, + }, + { + name: "Special Characters", + input: "123@45", + expectPanic: true, + }, + { + name: "Leading Plus Sign", + input: "+12345", + expectPanic: true, + }, + { + name: "Multiple Negative Signs", + input: "--12345", + expectPanic: true, + }, + { + name: "Space in String", + input: "123 45", + expectPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.expectPanic { + t.Errorf("unexpected panic for input: %s, got: %v", tt.input, r) + } + } else { + if tt.expectPanic { + t.Errorf("expected panic but did not occur for input: %s", tt.input) + } + } + }() + + // Test function + assertValidNumberString(tt.input) + }) + } +} + +func TestDerivePkgAddr(t *testing.T) { + pkgPath := "gno.land/r/gnoswap/v1/position" + tests := []struct { + name string + input string + expected string + }{ + { + name: "Success - derivePkgAddr", + input: pkgPath, + expected: "g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := derivePkgAddr(tc.input) + uassert.Equal(t, got.String(), tc.expected) + }) + } +} + +func TestGetPrevRealm(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prevRealm is User", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevRealm() + uassert.Equal(t, got.Addr().String(), tc.expected[0]) + uassert.Equal(t, got.PkgPath(), tc.expected[1]) + }) + } +} + +func TestGetPrevAddr(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected std.Address + }{ + { + name: "Success - prev Address is User", + originCaller: consts.ADMIN, + expected: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevAddr() + uassert.Equal(t, got.String(), tc.expected.String()) + }) + } +} + +func TestGetPrevAsString(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prev Realm of user info as string", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got1, got2 := getPrevAsString() + uassert.Equal(t, got1, tc.expected[0]) + uassert.Equal(t, got2, tc.expected[1]) + }) + } +} + +func TestIsUserCall(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + originPkgPath string + expected bool + }{ + { + name: "Success - User Call", + originCaller: consts.ADMIN, + expected: true, + }, + { + name: "Failure - Not User Call", + originCaller: consts.ROUTER_ADDR, + originPkgPath: consts.ROUTER_PATH, + expected: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(tc.originCaller) + if !tc.expected { + std.TestSetRealm(std.NewCodeRealm(tc.originPkgPath)) + } + got := isUserCall() + uassert.Equal(t, got, tc.expected) + }) + } +} + +func TestCheckDeadline(t *testing.T) { + tests := []struct { + name string + deadline int64 + now int64 + expected string + }{ + { + name: "Success - checkDeadline", + deadline: 1234567890 + 100, + now: 1234567890, + expected: "", + }, + { + name: "Failure - checkDeadline", + deadline: 1234567890 - 100, + now: 1234567890, + expected: "[GNOSWAP-POSITION-007] transaction expired || transaction too old, now(1234567890) > deadline(1234567790)", + }, + { + name: "Success - deadline equals now", + deadline: 1234567890, + now: 1234567890, + expected: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected != "" { + uassert.PanicsWithMessage(t, tc.expected, func() { + checkDeadline(tc.deadline) + }) + } else { + uassert.NotPanics(t, func() { + checkDeadline(tc.deadline) + }) + } + }) + } +} + +func TestTokenIdFrom(t *testing.T) { + tests := []struct { + name string + input interface{} + expected string + shouldPanic bool + }{ + { + name: "Panic - nil", + input: nil, + expected: "[GNOSWAP-POSITION-005] invalid input data || tokenId is nil", + shouldPanic: true, + }, + { + name: "Panic - unsupported type", + input: float64(1), + expected: "[GNOSWAP-POSITION-005] invalid input data || unsupported tokenId type", + shouldPanic: true, + }, + { + name: "Success - string", + input: "1", + expected: "1", + shouldPanic: false, + }, + { + name: "Success - int", + input: int(1), + expected: "1", + shouldPanic: false, + }, + { + name: "Success - uint64", + input: uint64(1), + expected: "1", + shouldPanic: false, + }, + { + name: "Success - grc721.TokenID", + input: grc721.TokenID("1"), + expected: "1", + shouldPanic: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + got := tokenIdFrom(tc.input) + uassert.Equal(t, tc.expected, string(got)) + } else { + tokenIdFrom(tc.input) + } + }) + } +} + +func TestExists(t *testing.T) { + tests := []struct { + name string + tokenId uint64 + expected bool + }{ + { + name: "Fail - not exists", + tokenId: 300000, + expected: false, + }, + { + name: "Success - exists", + tokenId: 1, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := exists(tc.tokenId) + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestIsOwner(t *testing.T) { + tests := []struct { + name string + tokenId uint64 + addr std.Address + expected bool + }{ + { + name: "Fail - is not owner", + tokenId: 1, + addr: alice, + expected: false, + }, + { + name: "Success - is owner", + tokenId: 1, + addr: admin, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + //MakeMintPositionWithoutFee(t) + got := isOwner(tc.tokenId, tc.addr) + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestIsOperator(t *testing.T) { + tests := []struct { + name string + tokenId uint64 + addr std.Address + expected bool + }{ + { + name: "Fail - is not operator", + tokenId: 1, + addr: alice, + expected: false, + }, + { + name: "Success - is operator", + tokenId: 1, + addr: bob, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + LPTokenApprove(t, admin, tc.addr, tc.tokenId) + } + got := isOperator(tc.tokenId, tc.addr) + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestIsStaked(t *testing.T) { + tests := []struct { + name string + owner std.Address + operator std.Address + tokenId uint64 + expected bool + }{ + { + name: "Fail - is not staked", + owner: bob, + operator: alice, + tokenId: 1, + expected: false, + }, + { + name: "Fail - is not exist tokenId", + owner: admin, + operator: bob, + tokenId: 100, + expected: false, + }, + { + name: "Success - is staked", + owner: admin, + operator: admin, + tokenId: 11, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected && tc.owner == tc.operator { + MakeMintPositionWithoutFee(t) + LPTokenApprove(t, tc.owner, consts.STAKER_ADDR, tc.tokenId) + LPTokenStake(t, tc.owner, tc.tokenId) + } + got := isStaked(tokenIdFrom(tc.tokenId)) + uassert.Equal(t, tc.expected, got) + if tc.expected && tc.owner == tc.operator { + LPTokenUnStake(t, tc.owner, tc.tokenId, false) + } + }) + } +} + +func TestIsOwnerOrOperator(t *testing.T) { + tests := []struct { + name string + owner std.Address + operator std.Address + tokenId uint64 + expected bool + }{ + { + name: "Fail - is not owner or operator", + owner: admin, + operator: alice, + tokenId: 1, + expected: false, + }, + { + name: "Success - is operator", + owner: admin, + operator: bob, + tokenId: 1, + expected: true, + }, + { + name: "Success - is owner", + owner: admin, + operator: admin, + tokenId: 1, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected && tc.owner != tc.operator { + LPTokenApprove(t, tc.owner, tc.operator, tc.tokenId) + } + var got bool + if tc.owner == tc.operator { + got = isOwnerOrOperator(tc.owner, tc.tokenId) + } else { + got = isOwnerOrOperator(tc.operator, tc.tokenId) + } + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestIsOwnerOrOperatorWithStake(t *testing.T) { + tests := []struct { + name string + owner std.Address + operator std.Address + tokenId uint64 + isStake bool + expected bool + }{ + { + name: "Fail - is not token staked", + owner: admin, + operator: alice, + tokenId: 10, + isStake: false, + expected: false, + }, + { + name: "Success - is token staked (position operator)", + owner: admin, + operator: admin, + tokenId: 10, + isStake: true, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.isStake { + tokenId, _, _, _ := MakeMintPositionWithoutFee(t) + LPTokenApprove(t, tc.owner, consts.STAKER_ADDR, tc.tokenId) + LPTokenStake(t, tc.owner, tc.tokenId) + } + got := isOwnerOrOperator(tc.operator, tc.tokenId) + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestPoolKeyDivide(t *testing.T) { + tests := []struct { + name string + poolKey string + expectedPath0 string + expectedPath1 string + expectedFee uint32 + expectedError string + shouldPanic bool + }{ + { + name: "Fail - invalid poolKey", + poolKey: "gno.land/r/onbloc", + expectedError: "[GNOSWAP-POSITION-005] invalid input data || invalid poolKey(gno.land/r/onbloc)", + shouldPanic: true, + }, + { + name: "Success - split poolKey", + poolKey: "gno.land/r/gnoswap/v1/gns:gno.land/r/demo/wugnot:500", + expectedPath0: gnsPath, + expectedPath1: wugnotPath, + expectedFee: fee500, + shouldPanic: false, + }, + { + name: "Fail -empty poolKey", + poolKey: "", + expectedError: "[GNOSWAP-POSITION-005] invalid input data || invalid poolKey()", + shouldPanic: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expectedError { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expectedError) + } + case error: + if r.(error).Error() != tc.expectedError { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expectedError) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expectedError) + } + } + }() + + if !tc.shouldPanic { + gotToken0, gotToken1, gotFee := splitOf(tc.poolKey) + uassert.Equal(t, tc.expectedPath0, gotToken0) + uassert.Equal(t, tc.expectedPath1, gotToken1) + uassert.Equal(t, tc.expectedFee, gotFee) + } else { + splitOf(tc.poolKey) + } + }) + } +} + +func TestSplitOf(t *testing.T) { + tests := []struct { + name string + poolKey string + expectedPath0 string + expectedPath1 string + expectedFee uint32 + expectedError string + shouldPanic bool + }{ + { + name: "Fail - empty poolKey", + poolKey: "", + expectedError: "[GNOSWAP-POSITION-005] invalid input data || invalid poolKey()", + shouldPanic: true, + }, + { + name: "Fail - invalid delimiter", + poolKey: "gno.land/r/gnoswap:v1/gns:gno.land/r/demo/wugnot-500", + expectedError: "[GNOSWAP-POSITION-005] invalid input data || invalid fee(gno.land/r/demo/wugnot-500)", + shouldPanic: true, + }, + { + name: "Fail - non-numeric fee", + poolKey: "gno.land/r/gnoswap/v1/gns:gno.land/r/demo/wugnot:fee", + expectedError: "[GNOSWAP-POSITION-005] invalid input data || invalid fee(fee)", + shouldPanic: true, + }, + { + name: "Fail - missing fee part", + poolKey: "gno.land/r/gnoswap/v1/gns:gno.land/r/demo/wugnot:", + expectedError: "[GNOSWAP-POSITION-005] invalid input data || invalid fee()", + shouldPanic: true, + }, + { + name: "Fail - insufficient parts", + poolKey: "gno.land/r/gnoswap/v1/gns:gno.land/r/demo", + expectedError: "[GNOSWAP-POSITION-005] invalid input data || invalid poolKey(gno.land/r/gnoswap/v1/gns:gno.land/r/demo)", + shouldPanic: true, + }, + { + name: "Success - valid poolKey", + poolKey: "gno.land/r/gnoswap/v1/gns:gno.land/r/demo/wugnot:500", + expectedPath0: "gno.land/r/gnoswap/v1/gns", + expectedPath1: "gno.land/r/demo/wugnot", + expectedFee: 500, + shouldPanic: false, + }, + { + name: "Success - poolKey with large fee", + poolKey: "gno.land/r/gnoswap/v1/gns:gno.land/r/demo/wugnot:10000", + expectedPath0: "gno.land/r/gnoswap/v1/gns", + expectedPath1: "gno.land/r/demo/wugnot", + expectedFee: 10000, + shouldPanic: false, + }, + { + name: "Success - poolKey with minimal fee", + poolKey: "gno.land/r/gnoswap/v1/gns:gno.land/r/demo/wugnot:1", + expectedPath0: "gno.land/r/gnoswap/v1/gns", + expectedPath1: "gno.land/r/demo/wugnot", + expectedFee: 1, + shouldPanic: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.shouldPanic { + assertPanic(t, tc.expectedError, func() { + splitOf(tc.poolKey) + }) + } else { + gotToken0, gotToken1, gotFee := splitOf(tc.poolKey) + uassert.Equal(t, tc.expectedPath0, gotToken0, "Token0 mismatch") + uassert.Equal(t, tc.expectedPath1, gotToken1, "Token1 mismatch") + uassert.Equal(t, tc.expectedFee, gotFee, "Fee mismatch") + } + }) + } +} diff --git a/contract/r/gnoswap/protocol_fee/errors.gno b/contract/r/gnoswap/protocol_fee/errors.gno new file mode 100644 index 000000000..bc8e9fa7f --- /dev/null +++ b/contract/r/gnoswap/protocol_fee/errors.gno @@ -0,0 +1,18 @@ +package protocol_fee + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-PROTOCOL_FEE-001] caller has no permission") + errInvalidPct = errors.New("[GNOSWAP-PROTOCOL_FEE-002] invalid percentage") + errInvalidAmount = errors.New("[GNOSWAP-PROTOCOL_FEE-003] invalid amount") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/protocol_fee/gno.mod b/contract/r/gnoswap/protocol_fee/gno.mod new file mode 100644 index 000000000..4bac85f15 --- /dev/null +++ b/contract/r/gnoswap/protocol_fee/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/protocol_fee diff --git a/contract/r/gnoswap/protocol_fee/protocol_fee.gno b/contract/r/gnoswap/protocol_fee/protocol_fee.gno new file mode 100644 index 000000000..e6b94ad21 --- /dev/null +++ b/contract/r/gnoswap/protocol_fee/protocol_fee.gno @@ -0,0 +1,146 @@ +package protocol_fee + +import ( + "std" + "strconv" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +var ( + devOpsPct uint64 = 0 // 0% + gnsToDevOps uint64 + gnsToGovStaker uint64 + + accuToGovStaker = avl.NewTree() +) + +// DistributeProtocolFee distributes the protocol fee to devOps and gov/staker. +func DistributeProtocolFee() { + assertOnlyNotHalted() + + tokens := common.ListRegisteredTokens() + if len(tokens) == 0 { + return + } + + for _, token := range tokens { + // default distribute protocol fee percent + // govStaker 100% + // ... + + balance := common.BalanceOf(token, consts.PROTOCOL_FEE_ADDR) + if balance > 0 { + println("balance of", token, ":", balance) + toDevOps := balance * devOpsPct / 10000 // default 0% + toGovStaker := balance - toDevOps // default 100% + + if token == consts.GNS_PATH { + gnsToDevOps = toDevOps + gnsToGovStaker = toGovStaker + } + + addAccuToGovStaker(token, toGovStaker) + + tokenTeller := common.GetTokenTeller(token) + if toDevOps > 0 { + tokenTeller.Transfer(consts.DEV_OPS, toDevOps) + } + + if toGovStaker > 0 { + tokenTeller.Transfer(consts.GOV_STAKER_ADDR, toGovStaker) + } + } + } +} + +func GetDevOpsPct() uint64 { + return devOpsPct +} + +// SetDevOpsPctByAdmin sets the devOpsPct. +func SetDevOpsPctByAdmin(pct uint64) { + caller := std.PrevRealm().Addr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } + + setDevOpsPct(pct) +} + +// SetDevOpsPct sets the devOpsPct. +// Only governance contract can execute this function via proposal +func SetDevOpsPct(pct uint64) { + caller := std.PrevRealm().Addr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err) + } + + setDevOpsPct(pct) +} + +// setDevOpsPct sets the devOpsPct. +func setDevOpsPct(pct uint64) { + if pct > 10000 { + panic(addDetailToError( + errInvalidPct, + ufmt.Sprintf("pct(%d) should not be bigger than 10000", pct), + )) + } + + prevDevOpsPct := devOpsPct + devOpsPct = pct + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "SetDevOpsPct", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "newPct", strconv.FormatUint(pct, 10), + "prevPct", strconv.FormatUint(prevDevOpsPct, 10), + ) +} + +// GetLastTransferToDevOps returns the last transfer to devOps. +func GetLastTransferToDevOps() uint64 { + return gnsToDevOps +} + +// GetAccuTransferToGovStaker returns the accuToGovStaker. +func GetAccuTransferToGovStaker() *avl.Tree { + return accuToGovStaker +} + +// GetAccuTransferToGovStakerByTokenPath returns the accumulated transfer to gov/staker by token path. +func GetAccuTransferToGovStakerByTokenPath(path string) uint64 { + amountI, exists := accuToGovStaker.Get(path) + if !exists { + return 0 + } + + return amountI.(uint64) +} + +// ClearAccuTransferToGovStaker clears the accuToGovStaker. +// Only gov/staker can execute this function. +func ClearAccuTransferToGovStaker() { + assertOnlyNotHalted() + + caller := std.PrevRealm().Addr() + if err := common.GovStakerOnly(caller); err != nil { + panic(err) + } + + accuToGovStaker = avl.NewTree() +} + +// addAccuToGovStaker adds the amount to the accuToGovStaker by token path. +func addAccuToGovStaker(path string, amount uint64) { + before := GetAccuTransferToGovStakerByTokenPath(path) + after := before + amount + accuToGovStaker.Set(path, after) +} diff --git a/contract/r/gnoswap/protocol_fee/protocol_fee_test.gno b/contract/r/gnoswap/protocol_fee/protocol_fee_test.gno new file mode 100644 index 000000000..c7435a41a --- /dev/null +++ b/contract/r/gnoswap/protocol_fee/protocol_fee_test.gno @@ -0,0 +1,77 @@ +package protocol_fee + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/qux" +) + +var ( + adminAddr = consts.ADMIN + adminRealm = std.NewUserRealm(adminAddr) +) + +func TestDistributeProtocolFee(t *testing.T) { + // admin > protocol_fee + // send qux, bar for testing + std.TestSetRealm(adminRealm) + bar.Transfer(consts.PROTOCOL_FEE_ADDR, 1000) + qux.Transfer(consts.PROTOCOL_FEE_ADDR, 1000) + + uassert.Equal(t, bar.BalanceOf(consts.PROTOCOL_FEE_ADDR), uint64(1000)) + uassert.Equal(t, bar.BalanceOf(consts.DEV_OPS), uint64(0)) + uassert.Equal(t, bar.BalanceOf(consts.GOV_STAKER_ADDR), uint64(0)) + + uassert.Equal(t, qux.BalanceOf(consts.PROTOCOL_FEE_ADDR), uint64(1000)) + uassert.Equal(t, qux.BalanceOf(consts.DEV_OPS), uint64(0)) + uassert.Equal(t, qux.BalanceOf(consts.GOV_STAKER_ADDR), uint64(0)) + + DistributeProtocolFee() + + uassert.Equal(t, bar.BalanceOf(consts.PROTOCOL_FEE_ADDR), uint64(0)) + uassert.Equal(t, bar.BalanceOf(consts.DEV_OPS), uint64(0)) + uassert.Equal(t, bar.BalanceOf(consts.GOV_STAKER_ADDR), uint64(1000)) + + uassert.Equal(t, qux.BalanceOf(consts.PROTOCOL_FEE_ADDR), uint64(0)) + uassert.Equal(t, qux.BalanceOf(consts.DEV_OPS), uint64(0)) + uassert.Equal(t, qux.BalanceOf(consts.GOV_STAKER_ADDR), uint64(1000)) +} + +func TestSetDevOpsPctByAdminNoPermission(t *testing.T) { + dummy := testutils.TestAddress("dummy") + dummyRealm := std.NewUserRealm(dummy) + std.TestSetRealm(dummyRealm) + + uassert.PanicsWithMessage( + t, `caller(g1v36k6mteta047h6lta047h6lta047h6lz7gmv8) has no permission`, func() { SetDevOpsPctByAdmin(123) }, + ) +} + +func TestSetDevOpsPctByAdminInvalidFee(t *testing.T) { + std.TestSetRealm(adminRealm) + + uassert.PanicsWithMessage( + t, + `[GNOSWAP-PROTOCOL_FEE-002] invalid percentage || pct(100001) should not be bigger than 10000`, + func() { + SetDevOpsPctByAdmin(100001) + }, + ) +} + +func TestSetDevOpsPctByAdmin(t *testing.T) { + std.TestSetRealm(adminRealm) + + uassert.Equal(t, GetDevOpsPct(), uint64(0)) + + SetDevOpsPctByAdmin(123) + + uassert.Equal(t, GetDevOpsPct(), uint64(123)) +} diff --git a/contract/r/gnoswap/protocol_fee/token_list_with_amount.gno b/contract/r/gnoswap/protocol_fee/token_list_with_amount.gno new file mode 100644 index 000000000..ea52bd50c --- /dev/null +++ b/contract/r/gnoswap/protocol_fee/token_list_with_amount.gno @@ -0,0 +1,137 @@ +package protocol_fee + +import ( + "std" + "strings" + + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +var ( + tokenListWithAmount = make(map[string]uint64) // tokenPath -> amount +) + +// TokenList returns only the list of token path. +// If positive is true, it returns only the token path with amount > 0. +// If positive is false, it returns all the token path. +func TokenList(positive bool) []string { + tokens := []string{} + + for tokenPath, amount := range tokenListWithAmount { + if positive && amount == 0 { + continue + } + + tokens = append(tokens, tokenPath) + } + + return tokens +} + +// TokenListWithAmount returns the token path and amount. +func TokenListWithAmount() map[string]uint64 { + return tokenListWithAmount +} + +// AddToProtocolFee adds the amount to the tokenListWithAmount +// Only `pool + router + staker` can execute this function. +func AddToProtocolFee(tokenPath string, amount uint64) { + assertOnlyPoolRouterStaker() + tokenListWithAmount[tokenPath] += amount +} + +// ClearTokenListWithAmount clears the tokenListWithAmount. +// only `gov/staker` can execute this function. +func ClearTokenListWithAmount() { + assertOnlyGovStaker() + clearTokenListWithAmount() +} + +// TransferProtocolFee transfers the protocol fee to devOps and gov/staker. +// only `gov/staker` can execute this function. +// It returns list of token with amount has been sent to gov/staker. +func TransferProtocolFee() map[string]uint64 { + assertOnlyGovStaker() + + sentToDevOps := []string{} + sentToGovStaker := []string{} + toReturn := map[string]uint64{} + + for token, amount := range tokenListWithAmount { + balance := common.BalanceOf(token, consts.PROTOCOL_FEE_ADDR) + + // anyone can just send certain grc20 token to `protocol_fee` contract + // therefore, we don't need any guard logic to check whether protocol_fee's xxx token balance is equal to `amount` + // however, amount always should be less than or equal to balance + if amount > balance { + panic(addDetailToError( + errInvalidAmount, + ufmt.Sprintf("amount: %d should be less than or equal to balance: %d", amount, balance), + )) + } + + if amount > 0 { + toDevOps := balance * devOpsPct / 10000 // default 0% + toGovStaker := balance - toDevOps // default 100% + + tokenTeller := common.GetTokenTeller(token) + if toDevOps > 0 { + tokenTeller.Transfer(consts.DEV_OPS, toDevOps) + sentToDevOps = append(sentToDevOps, makeEventString(token, toDevOps)) + } + + if toGovStaker > 0 { + tokenTeller.Transfer(consts.GOV_STAKER_ADDR, toGovStaker) + sentToGovStaker = append(sentToGovStaker, makeEventString(token, toGovStaker)) + + toReturn[token] = toGovStaker + } + } + } + + clearTokenListWithAmount() + + prevAddr, prevRealm := getPrev() + std.Emit( + "TransferProtocolFee", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "toDevOps", strings.Join(sentToDevOps, ","), + "toGovStaker", strings.Join(sentToGovStaker, ","), + ) + + return toReturn +} + +// clearTokenListWithAmount clears the tokenListWithAmount. +func clearTokenListWithAmount() { + tokenListWithAmount = map[string]uint64{} +} + +// assertOnlyPoolRouterStaker panics if the caller is not the pool, router, or staker contract. +func assertOnlyPoolRouterStaker() { + caller := std.PrevRealm().Addr() + + poolOnlyErr := common.PoolOnly(caller) + routerOnlyErr := common.RouterOnly(caller) + stakerOnlyErr := common.StakerOnly(caller) + + if poolOnlyErr != nil && routerOnlyErr != nil && stakerOnlyErr != nil { + panic(errNoPermission) + } +} + +// assertOnlyGovStaker panics if the caller is not the gov/staker contract. +func assertOnlyGovStaker() { + caller := std.PrevRealm().Addr() + if err := common.GovStakerOnly(caller); err != nil { + panic(err.Error()) + } +} + +func makeEventString(tokenPath string, amount uint64) string { + return tokenPath + "*FEE*" + ufmt.Sprintf("%d", amount) +} diff --git a/contract/r/gnoswap/protocol_fee/token_list_with_amount_test.gno b/contract/r/gnoswap/protocol_fee/token_list_with_amount_test.gno new file mode 100644 index 000000000..500361da7 --- /dev/null +++ b/contract/r/gnoswap/protocol_fee/token_list_with_amount_test.gno @@ -0,0 +1,237 @@ +package protocol_fee + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + _ "gno.land/r/onbloc/baz" + _ "gno.land/r/onbloc/foo" +) + +var ( + dummyRealm = std.NewCodeRealm("gno.land/r/dummy") + + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) +) + +func TestTokenList(t *testing.T) { + tokenListWithAmount = map[string]uint64{ + "gno.land/r/foo": 100, + "gno.land/r/bar": 0, + "gno.land/r/baz": 200, + } + + uassert.Equal(t, len(TokenList(true)), 2) + uassert.Equal(t, len(TokenList(false)), 3) +} + +func TestAddToProtocolFee(t *testing.T) { + tokenListWithAmount = map[string]uint64{} + + tests := []struct { + name string + tokenPath string + amount uint64 + want uint64 + }{ + { + name: "add foo to protocol fee", + tokenPath: "gno.land/r/foo", + amount: 100, + want: 100, + }, + { + name: "add baz to protocol fee", + tokenPath: "gno.land/r/baz", + amount: 50, + want: 50, + }, + { + name: "add more baz to protocol fee", + tokenPath: "gno.land/r/baz", + amount: 10, + want: 60, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + std.TestSetRealm(std.NewCodeRealm(consts.POOL_PATH)) + AddToProtocolFee(test.tokenPath, test.amount) + + uassert.Equal(t, tokenListWithAmount[test.tokenPath], test.want) + }) + } +} + +func TestClearTokenListWithAmount(t *testing.T) { + tokenListWithAmount = map[string]uint64{ + "gno.land/r/foo": 100, + "gno.land/r/baz": 200, + } + + tests := []struct { + name string + prevRealm std.Realm + want map[string]uint64 + shouldPanic bool + panicMsg string + }{ + { + name: "no permission to clear", + prevRealm: dummyRealm, + shouldPanic: true, + panicMsg: "caller(g1lvx5ssxvuz5tttx6uza3myv8xy6w36a46fv7sy) has no permission", + }, + { + name: "clear protocol fee", + prevRealm: std.NewCodeRealm(consts.GOV_STAKER_PATH), + want: map[string]uint64{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + std.TestSetRealm(test.prevRealm) + + if test.shouldPanic { + uassert.PanicsWithMessage(t, test.panicMsg, func() { + ClearTokenListWithAmount() + }) + } else { + ClearTokenListWithAmount() + uassert.Equal(t, len(tokenListWithAmount), len(test.want)) + } + }) + } +} + +func TestTransferProtocolFee(t *testing.T) { + // change devOpsPct to 49% (51% is for gov/staker) + devOpsPct = 4900 + + // send(and add) to protocol fee + transferToProtocolFee(t, "gno.land/r/onbloc/foo", 100) + transferToProtocolFee(t, "gno.land/r/onbloc/baz", 201) + + // call TransferProtocolFee + std.TestSetRealm(std.NewCodeRealm(consts.GOV_STAKER_PATH)) + + devOpsOldFoo := common.BalanceOf("gno.land/r/onbloc/foo", consts.DEV_OPS) + devOpsOldBaz := common.BalanceOf("gno.land/r/onbloc/baz", consts.DEV_OPS) + govStkaerOldFoo := common.BalanceOf("gno.land/r/onbloc/foo", consts.GOV_STAKER_ADDR) + govStkaerOldBaz := common.BalanceOf("gno.land/r/onbloc/baz", consts.GOV_STAKER_ADDR) + uassert.Equal(t, devOpsOldFoo, uint64(0)) + uassert.Equal(t, devOpsOldBaz, uint64(0)) + uassert.Equal(t, govStkaerOldFoo, uint64(0)) + uassert.Equal(t, govStkaerOldBaz, uint64(0)) + + sentToGovStaker := TransferProtocolFee() + // foo 100 + // -> devOps 49% => 49 + // -> gov/staker 51% => 51 + + // baz 201 + // -> devOps 49% => 98 + // -> gov/staker 51% => 103 + + // emitted event + // EVENTS: [{"type":"TransferProtocolFee","attrs":[{"key":"prevAddr","value":"g17e3ykyqk9jmqe2y9wxe9zhep3p7cw56davjqwa"},{"key":"prevRealm","value":"gno.land/r/gnoswap/v1/gov/staker"},{"key":"toDevOps","value":"gno.land/r/onbloc/foo*FEE*49,gno.land/r/onbloc/baz*FEE*98"},{"key":"toGovStaker","value":"gno.land/r/onbloc/foo*FEE*51,gno.land/r/onbloc/baz*FEE*103"}],"pkg_path":"gno.land/r/gnoswap/v1/protocol_fee","func":"TransferProtocolFee"}] + + devOpsNewFoo := common.BalanceOf("gno.land/r/onbloc/foo", consts.DEV_OPS) + devOpsNewBaz := common.BalanceOf("gno.land/r/onbloc/baz", consts.DEV_OPS) + uassert.Equal(t, devOpsNewFoo, uint64(49)) + uassert.Equal(t, devOpsNewBaz, uint64(98)) + + govStkaerNewFoo := common.BalanceOf("gno.land/r/onbloc/foo", consts.GOV_STAKER_ADDR) + govStkaerNewBaz := common.BalanceOf("gno.land/r/onbloc/baz", consts.GOV_STAKER_ADDR) + uassert.Equal(t, govStkaerNewFoo, uint64(51)) + uassert.Equal(t, govStkaerNewBaz, uint64(103)) + + uassert.Equal(t, len(sentToGovStaker), 2) + uassert.Equal(t, sentToGovStaker["gno.land/r/onbloc/foo"], uint64(51)) + uassert.Equal(t, sentToGovStaker["gno.land/r/onbloc/baz"], uint64(103)) +} + +func TestAssertOnlyPoolRouterStaker(t *testing.T) { + tests := []struct { + name string + prevRealm std.Realm + shouldPanic bool + panicMsg string + }{ + { + name: "caller is pool contract", + prevRealm: std.NewCodeRealm(consts.POOL_PATH), + }, + { + name: "caller is router contract", + prevRealm: std.NewCodeRealm(consts.ROUTER_PATH), + }, + { + name: "caller is staker contract", + prevRealm: std.NewCodeRealm(consts.STAKER_PATH), + }, + { + name: "caller is stranger", + prevRealm: dummyRealm, + shouldPanic: true, + panicMsg: "[GNOSWAP-PROTOCOL_FEE-001] caller has no permission", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + std.TestSetRealm(test.prevRealm) + + if test.shouldPanic { + uassert.PanicsWithMessage(t, test.panicMsg, func() { + assertOnlyPoolRouterStaker() + }) + } else { + uassert.NotPanics(t, func() { + assertOnlyPoolRouterStaker() + }) + } + }) + } +} + +func TestAssertOnlyGovStaker(t *testing.T) { + tests := []struct { + name string + prevRealm std.Realm + shouldPanic bool + panicMsg string + }{ + { + name: "caller is gov/staker contract", + prevRealm: std.NewCodeRealm(consts.GOV_STAKER_PATH), + }, + { + name: "caller is stranger", + prevRealm: dummyRealm, + shouldPanic: true, + panicMsg: "[GNOSWAP-PROTOCOL_FEE-001] caller has no permission", + }, + } +} + +// helper +func transferToProtocolFee(t *testing.T, tokenPath string, amount uint64) { + t.Helper() + + std.TestSetRealm(adminRealm) + + tokenTeller := common.GetTokenTeller(tokenPath) + tokenTeller.Transfer(consts.PROTOCOL_FEE_ADDR, amount) + + tokenListWithAmount[tokenPath] += amount +} diff --git a/contract/r/gnoswap/protocol_fee/utils.gno b/contract/r/gnoswap/protocol_fee/utils.gno new file mode 100644 index 000000000..65a8b670a --- /dev/null +++ b/contract/r/gnoswap/protocol_fee/utils.gno @@ -0,0 +1,17 @@ +package protocol_fee + +import ( + "std" + + "gno.land/r/gnoswap/v1/common" +) + +func getPrev() (string, string) { + prev := std.PrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} diff --git a/contract/r/gnoswap/referral/doc.gno b/contract/r/gnoswap/referral/doc.gno new file mode 100644 index 000000000..61d5188fb --- /dev/null +++ b/contract/r/gnoswap/referral/doc.gno @@ -0,0 +1,109 @@ +// Package referral implements a referral system on Gno. It allows +// any authorized caller to register, update, or remove referral +// information. A referral link is defined as a mapping from one +// address (the "user") to another address (the "referrer"). +// +// ## Overview +// +// The referral package is composed of the following components: +// +// 1. **errors.gno**: Provides custom error types (ReferralError) with +// specific error codes and messages. +// 2. **utils.gno**: Contains utility functions for permission checks, +// especially isValidCaller, which ensures only specific, pre-authorized +// callers (e.g., governance or router addresses) can invoke the core +// functions. +// 3. **types.gno**: Defines core constants for event types, attributes, +// and the ReferralKeeper interface, which outlines the fundamental +// methods of the referral system (Register, Update, Remove, etc.). +// 4. **keeper.gno**: Implements the actual business logic behind the +// ReferralKeeper interface. It uses an AVL Tree (avl.Tree) to store +// referral data (address -> referrer). The keeper methods emit events +// when a new referral is registered, updated, or removed. +// 5. **referral.gno**: Exposes a public API (the Referral struct) +// that delegates to the keeper, providing external contracts or +// applications a straightforward way to interact with the system. +// +// ## Workflow +// +// Typical usage of this contract follows these steps: +// +// 1. A caller with valid permissions invokes Register, Update, or Remove +// through the Referral struct. +// 2. The Referral struct forwards the request to the internal keeper +// methods. +// 3. The keeper checks caller permission (via isValidCaller), validates +// addresses, and stores or removes data in the AVL Tree. +// 4. An event is emitted for off-chain or cross-module notifications. +// +// ## Integration with Other Contracts +// +// Other contracts can leverage the referral system in two major ways: +// +// 1. **Direct Calls**: If you wish to directly call this contract, +// instantiate the Referral object (via NewReferral) and invoke its +// methods, assuming you meet the authorized-caller criteria. +// +// 2. **Embedded or Extended**: If you have a complex module that includes +// referral features, import this package and embed a Referral instance +// in your own keeper. This way, you can handle additional validations +// or custom logic before delegating to the existing referral functions. +// +// ## Error Handling +// +// The package defines several error types through ReferralError: +// - `ErrInvalidAddress`: Returned when an address format is invalid +// - `ErrUnauthorized`: Returned when the caller lacks permission +// - `ErrNotFound`: Returned when attempting to get a non-existent referral +// - `ErrZeroAddress`: Returned when attempting operations with zero address +// +// ## Example: Integration with a Staking Contract +// +// Suppose you have a staking contract that wants to reward referrers +// when a new user stakes tokens: +// +// ```go +// +// import ( +// "std" +// "gno.land/r/gnoswap/v1/referral" +// "gno.land/p/demo/mystaking" // example staking contract +// ) +// +// func rewardReferrerOnStake(user std.Address, amount int) { +// // 1) Access the referral system +// r := referral.NewReferral() +// +// // 2) Get the user's referrer +// refAddr, err := r.GetReferral(user) +// if err != nil { +// // handle error or skip if not found +// return +// } +// +// // 3) Reward the referrer +// mystaking.AddReward(refAddr, calculateReward(amount)) +// } +// +// ``` +// +// In this simple example, the staking contract checks if the user has +// a referrer by calling `GetReferral`. If a referrer is found, it then +// calculates a reward based on the staked amount. +// +// ## Limitations and Constraints +// +// - A user can have only one referrer at a time +// - Once a referral is removed, it cannot be automatically restored +// - Only authorized contracts can modify referral relationships +// - Address validation is strict and requires proper Bech32 format +// +// # Notes +// +// - The contract strictly enforces caller restrictions via isValidCaller. +// Make sure to configure it to permit only the addresses or roles that +// should be able to register or update referrals. +// - Zero addresses are treated as a trigger for removing a referral record. +// - The system emits events (register_referral, update_referral, remove_referral) +// which can be consumed by other on-chain or off-chain services. +package referral diff --git a/contract/r/gnoswap/referral/errors.gno b/contract/r/gnoswap/referral/errors.gno new file mode 100644 index 000000000..f0f1fbb48 --- /dev/null +++ b/contract/r/gnoswap/referral/errors.gno @@ -0,0 +1,63 @@ +package referral + +import ( + "gno.land/p/demo/ufmt" +) + +const ( + ErrNone = iota + ErrInvalidAddress + ErrZeroAddress + ErrSelfReferral + ErrUnauthorized + ErrInvalidCaller + ErrCyclicReference + ErrTooManyRequests + ErrNotFound +) + +var errorMessages = map[int]string{ + ErrInvalidAddress: "invalid address format", + ErrZeroAddress: "zero address is not allowed", + ErrSelfReferral: "self referral is not allowed", + ErrUnauthorized: "unauthorized caller", + ErrInvalidCaller: "invalid caller", + ErrCyclicReference: "cyclic reference is not allowed", + ErrTooManyRequests: "too many requests: operations allowed once per 24 hours for each address", + ErrNotFound: "referral not found", +} + +type ReferralError struct { + Code int + Message string + Err error +} + +func (e *ReferralError) Error() string { + // TODO: format error message to follow previous error message format + if e.Err != nil { + return ufmt.Sprintf("code: %d, message: %s, error: %v", e.Code, e.Message, e.Err) + } + return ufmt.Sprintf("code: %d, message: %s", e.Code, e.Message) +} + +func (e *ReferralError) Unwrap() error { + return e.Err +} + +func NewError(code int, args ...interface{}) *ReferralError { + msg := errorMessages[code] + var err error + + if len(args) > 0 { + if lastArg, ok := args[len(args)-1].(error); ok { + err = lastArg + } + } + + return &ReferralError{ + Code: code, + Message: msg, + Err: err, + } +} diff --git a/contract/r/gnoswap/referral/global_keeper.gno b/contract/r/gnoswap/referral/global_keeper.gno new file mode 100644 index 000000000..ac25c237f --- /dev/null +++ b/contract/r/gnoswap/referral/global_keeper.gno @@ -0,0 +1,41 @@ +package referral + +import "std" + +// gReferralKeeper is the global instance of the referral keeper +var gReferralKeeper ReferralKeeper + +func init() { + gReferralKeeper = NewKeeper() +} + +// GetKeeper returns the global instance of the referral keeper +// +// Example: +// +// // In other packages: +// keeper := referral.GetKeeper() +// keeper.register(addr, refAddr) +func GetKeeper() ReferralKeeper { + return gReferralKeeper +} + +func GetReferral(addr string) string { + referral, err := gReferralKeeper.get(std.Address(addr)) + if err != nil { + panic(err) + } + return referral.String() +} + +func HasReferral(addr string) bool { + referral, err := gReferralKeeper.get(std.Address(addr)) + if err != nil { + return false + } + return referral != zeroAddress +} + +func IsEmpty() bool { + return gReferralKeeper.isEmpty() +} diff --git a/contract/r/gnoswap/referral/gno.mod b/contract/r/gnoswap/referral/gno.mod new file mode 100644 index 000000000..3951be4db --- /dev/null +++ b/contract/r/gnoswap/referral/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/referral diff --git a/contract/r/gnoswap/referral/keeper.gno b/contract/r/gnoswap/referral/keeper.gno new file mode 100644 index 000000000..911dd2e8f --- /dev/null +++ b/contract/r/gnoswap/referral/keeper.gno @@ -0,0 +1,177 @@ +package referral + +import ( + "std" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" +) + +const ( + // MinTimeBetweenUpdates represents minimum duration between operations + MinTimeBetweenUpdates int64 = 24 * 60 * 60 +) + +// keeper implements the `ReferralKeeper` interface using an AVL tree for storage. +type keeper struct { + store *avl.Tree + lastOps map[string]int64 +} + +// check interface implementation at compile time +var _ ReferralKeeper = &keeper{} + +// NewKeeper creates and returns a new instance of ReferralKeeper. +// The keeper is initialized with an empty AVL tree for storing referral relationships. +func NewKeeper() ReferralKeeper { + return &keeper{ + store: avl.NewTree(), + lastOps: make(map[string]int64), + } +} + +// register implements the `register` method of the `ReferralKeeper` interface. +// It sets a new referral relationship between the given address and referral address. +func (k *keeper) register(addr, refAddr std.Address) error { + return k.setReferral(addr, refAddr, EventTypeRegister) +} + +// update implements the `update` method of the `ReferralKeeper` interface. +// It updates the referral address for a given address. +func (k *keeper) update(addr, newRefAddr std.Address) error { + return k.setReferral(addr, newRefAddr, EventTypeUpdate) +} + +// setReferral handles the common logic for registering and updating referrals. +// It validates the addresses and caller, then stores the referral relationship. +// +// Note: The current implementation allows circular references, but since it only manages +// simple reference relationships, cycles are not a significant issue. However, when introducing +// a referral-based reward system in the future or adding business logic where cycles could cause problems, +// it will be necessary to implement validation checks. +// +// TODO: need to discuss what values to emit as event +func (k *keeper) setReferral(addr, refAddr std.Address, eventType string) error { + if err := isValidCaller(std.PrevRealm().Addr()); err != nil { + return err + } + + if err := k.validateAddresses(addr, refAddr); err != nil { + return err + } + + if err := k.checkRateLimit(addr.String()); err != nil { + // XXX: because of the weird type-related errors, this is the only + // part where we need to use `ufmt.Errorf` to build error message. + return ufmt.Errorf("too many requests") + } + + if refAddr == zeroAddress { + std.Emit( + EventTypeRemove, + "removedAddress", addr.String(), + ) + return k.remove(addr) + } + + k.store.Set(addr.String(), refAddr.String()) + std.Emit( + eventType, + "myAddress", addr.String(), + "refAddress", refAddr.String(), + ) + + return nil +} + +// validateAddresses checks if the given addresses are valid for referral +func (k *keeper) validateAddresses(addr, refAddr std.Address) error { + if !addr.IsValid() || (refAddr != zeroAddress && !refAddr.IsValid()) { + return NewError(ErrInvalidAddress) + } + if addr == refAddr { + return NewError(ErrSelfReferral) + } + return nil +} + +// remove implements the `remove` method of the `ReferralKeeper` interface. +// It validates the caller and address before removing the referral relationship. +// +// TODO: need to discuss what values to emit as event +func (k *keeper) remove(target std.Address) error { + if err := isValidCaller(std.PrevRealm().Addr()); err != nil { + return err + } + + if !target.IsValid() { + return NewError(ErrInvalidAddress) + } + + if err := k.checkRateLimit(target.String()); err != nil { + return err + } + + k.store.Remove(target.String()) + + // TODO: update event + std.Emit( + EventTypeRemove, + "removedAddress", target.String(), + ) + + return nil +} + +// has implements the `has` method of the `ReferralKeeper` interface. +// It checks if a referral relationship exists for a given address. +func (k *keeper) has(addr std.Address) bool { + _, exists := k.store.Get(addr.String()) + return exists +} + +// get implements the `get` method of the `ReferralKeeper` interface. +// It retrieves the referral address for a given address. +func (k *keeper) get(addr std.Address) (std.Address, error) { + if !addr.IsValid() { + return zeroAddress, NewError(ErrInvalidAddress) + } + + val, ok := k.store.Get(addr.String()) + if !ok { + return zeroAddress, NewError(ErrNotFound) + } + + refAddr, ok := val.(string) + if !ok { + return zeroAddress, NewError(ErrInvalidAddress) + } + + return std.Address(refAddr), nil +} + +func (k *keeper) isEmpty() bool { + empty := true + k.store.Iterate("", "", func(key string, value interface{}) bool { + empty = false + return true // stop iteration on first item + }) + return empty +} + +// checkRateLimit verifies if enough time has passed since the last operation +func (k *keeper) checkRateLimit(addr string) error { + now := time.Now().Unix() + + if lastOpTime, exists := k.lastOps[addr]; exists { + timeSinceLastOp := now - lastOpTime + + if timeSinceLastOp < MinTimeBetweenUpdates { + return NewError(ErrTooManyRequests) + } + } + + k.lastOps[addr] = now + return nil +} diff --git a/contract/r/gnoswap/referral/keeper_test.gno b/contract/r/gnoswap/referral/keeper_test.gno new file mode 100644 index 000000000..fdc9d6cb8 --- /dev/null +++ b/contract/r/gnoswap/referral/keeper_test.gno @@ -0,0 +1,638 @@ +package referral + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" +) + +const STRESS_TEST_NUM = 10000 // arbitrary number + +var ( + validAddr1 = testutils.TestAddress("valid1") + validAddr2 = testutils.TestAddress("valid2") + invalidAddr = testutils.TestAddress("invalid") +) + +// time mocking +var currentTime int64 = time.Now().Unix() + +func mockTimeNow() time.Time { + return time.Unix(currentTime, 0) +} + +func setupKeeper() *keeper { return NewKeeper().(*keeper) } + +func mockValidCaller() func() { + origCaller := std.GetOrigCaller() + std.TestSetOrigCaller(consts.ROUTER_ADDR) + return func() { + std.TestSetOrigCaller(origCaller) + } +} + +func TestRegister(t *testing.T) { + tests := []struct { + name string + addr std.Address + refAddr std.Address + setupCaller func() func() + wantErr bool + errCode int + }{ + { + name: "valid registration", + addr: validAddr1, + refAddr: validAddr2, + setupCaller: mockValidCaller, + wantErr: false, + }, + { + name: "unauthorized caller", + addr: validAddr1, + refAddr: validAddr2, + setupCaller: func() func() { + origCaller := std.GetOrigCaller() + std.TestSetOrigCaller(std.Address("unauthorized")) + return func() { + std.TestSetOrigCaller(origCaller) + } + }, + wantErr: true, + errCode: ErrUnauthorized, + }, + { + name: "self referral", + addr: validAddr1, + refAddr: validAddr1, + setupCaller: mockValidCaller, + wantErr: true, + errCode: ErrSelfReferral, + }, + { + name: "zero address referral", + addr: validAddr1, + refAddr: zeroAddress, + setupCaller: mockValidCaller, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + cleanup := tt.setupCaller() + defer cleanup() + + err := k.register(tt.addr, tt.refAddr) + + if tt.wantErr { + if err == nil { + t.Errorf("register() error = nil, wantErr %v", tt.wantErr) + return + } + referralErr, ok := err.(*ReferralError) + if !ok { + t.Errorf("register() error is not ReferralError type") + return + } + if referralErr.Code != tt.errCode { + t.Errorf("register() error code = %v, want %v", referralErr.Code, tt.errCode) + } + } else if err != nil { + t.Errorf("register() unexpected error = %v", err) + } + }) + } +} + +func TestUpdate(t *testing.T) { + tests := []struct { + name string + addr std.Address + refAddr std.Address + setupState func(*keeper) + setupCaller func() func() + wantErr bool + errCode int + }{ + { + name: "valid update", + addr: validAddr1, + refAddr: validAddr2, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), "old_ref_addr") + }, + setupCaller: mockValidCaller, + wantErr: false, + }, + { + name: "update non-existent referral", + addr: validAddr1, + refAddr: validAddr2, + setupState: func(k *keeper) {}, + setupCaller: mockValidCaller, + wantErr: false, + }, + { + name: "update to self referral", + addr: validAddr1, + refAddr: validAddr1, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + setupCaller: mockValidCaller, + wantErr: true, + errCode: ErrSelfReferral, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + cleanup := tt.setupCaller() + defer cleanup() + + err := k.update(tt.addr, tt.refAddr) + + if tt.wantErr { + if err == nil { + t.Errorf("update() error = nil, wantErr %v", tt.wantErr) + return + } + referralErr, ok := err.(*ReferralError) + if !ok { + t.Errorf("update() error is not ReferralError type") + return + } + if referralErr.Code != tt.errCode { + t.Errorf("update() error code = %v, want %v", referralErr.Code, tt.errCode) + } + } else if err != nil { + t.Errorf("update() unexpected error = %v", err) + } + }) + } +} + +func TestGet(t *testing.T) { + tests := []struct { + name string + addr std.Address + setupState func(*keeper) + wantAddr std.Address + wantErr bool + errCode int + }{ + { + name: "get existing referral", + addr: validAddr1, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + wantAddr: validAddr2, + wantErr: false, + }, + { + name: "get non-existent referral", + addr: validAddr1, + setupState: func(k *keeper) {}, + wantAddr: zeroAddress, + wantErr: true, + errCode: ErrNotFound, + }, + { + name: "get with invalid address", + addr: invalidAddr, + setupState: func(k *keeper) {}, + wantAddr: zeroAddress, + wantErr: true, + errCode: ErrNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + + gotAddr, err := k.get(tt.addr) + + if tt.wantErr { + if err == nil { + t.Errorf("get() error = nil, wantErr %v", tt.wantErr) + return + } + referralErr, ok := err.(*ReferralError) + if !ok { + t.Errorf("get() error is not ReferralError type") + return + } + if referralErr.Code != tt.errCode { + t.Errorf("get() error code = %v, want %v", referralErr.Code, tt.errCode) + } + } else { + if err != nil { + t.Errorf("get() unexpected error") + return + } + if gotAddr != tt.wantAddr { + t.Errorf("get() gotAddr = %v, want %v", gotAddr, tt.wantAddr) + } + } + }) + } +} + +func TestHas(t *testing.T) { + tests := []struct { + name string + addr std.Address + setupState func(*keeper) + want bool + }{ + { + name: "has existing referral", + addr: validAddr1, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + want: true, + }, + { + name: "does not have referral", + addr: validAddr1, + setupState: func(k *keeper) {}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + + if got := k.has(tt.addr); got != tt.want { + t.Errorf("has() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRemove(t *testing.T) { + tests := []struct { + name string + addr std.Address + setupState func(*keeper) + setupCaller func() func() + wantErr bool + errCode int + }{ + { + name: "remove existing referral", + addr: validAddr1, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + setupCaller: mockValidCaller, + wantErr: false, + }, + { + name: "remove non-existent referral", + addr: validAddr1, + setupState: func(k *keeper) {}, + setupCaller: mockValidCaller, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + cleanup := tt.setupCaller() + defer cleanup() + + err := k.remove(tt.addr) + + if tt.wantErr { + if err == nil { + t.Errorf("remove() error = nil, wantErr %v", tt.wantErr) + return + } + referralErr, ok := err.(*ReferralError) + if !ok { + t.Errorf("remove() error is not ReferralError type") + return + } + if referralErr.Code != tt.errCode { + t.Errorf("remove() error code = %v, want %v", referralErr.Code, tt.errCode) + } + } else if err != nil { + t.Errorf("remove() unexpected error = %v", err) + } + + if k.has(tt.addr) { + t.Errorf("remove() referral still exists after removal") + } + }) + } +} + +func TestUpdateNonExistentReferral(t *testing.T) { + k := setupKeeper() + cleanup := mockValidCaller() + defer cleanup() + + err := k.update(validAddr1, validAddr2) + if err != nil { + t.Errorf("update() for non-existent referral failed: %v", err) + } + + refAddr, err := k.get(validAddr1) + if err != nil { + t.Errorf("get() after update failed: %v", err) + } + if refAddr != validAddr2 { + t.Errorf("got refAddr = %v, want %v", refAddr, validAddr2) + } +} + +func TestReferralCycles(t *testing.T) { + k := setupKeeper() + cleanup := mockValidCaller() + defer cleanup() + + addr1 := testutils.TestAddress("cycle1") + addr2 := testutils.TestAddress("cycle2") + addr3 := testutils.TestAddress("cycle3") + + // A -> B -> C + err := k.register(addr1, addr2) + if err != nil { + t.Fatalf("Failed to register addr1->addr2: %v", err) + } + + err = k.register(addr2, addr3) + if err != nil { + t.Fatalf("Failed to register addr2->addr3: %v", err) + } + + // reference cycle: C -> A + err = k.register(addr3, addr1) + if err != nil { + t.Fatalf("Failed to register addr3->addr1: %v", err) + } + + refAddr, _ := k.get(addr1) + if refAddr != addr2 { + t.Error("addr1's referral should be addr2") + } + + refAddr, _ = k.get(addr2) + if refAddr != addr3 { + t.Error("addr2's referral should be addr3") + } + + refAddr, _ = k.get(addr3) + if refAddr != addr1 { + t.Error("addr3's referral should be addr1") + } +} + +func TestStress(t *testing.T) { + t.Skip("Skipping stress test") + + k := setupKeeper() + cleanup := mockValidCaller() + defer cleanup() + + addresses := make([]std.Address, STRESS_TEST_NUM) + + for i := 0; i < STRESS_TEST_NUM; i++ { + addresses[i] = testutils.TestAddress(ufmt.Sprintf("addr%d", i)) + } + + for i := 0; i < STRESS_TEST_NUM; i++ { + err := k.register(addresses[i], addresses[(i+1)%STRESS_TEST_NUM]) + if err != nil { + t.Fatalf("Registration failed at index %d: %v", i, err) + } + + err = k.update(addresses[i], addresses[(i+2)%STRESS_TEST_NUM]) + if err != nil { + t.Fatalf("Update failed at index %d: %v", i, err) + } + + // remove some addresses + if i%3 == 0 { + err = k.remove(addresses[i]) + if err != nil { + t.Fatalf("Remove failed at index %d: %v", i, err) + } + } + + // check data consistency + if i%1000 == 0 { + for j := 0; j <= i; j++ { + if j%3 == 0 { + // check removed address + if k.has(addresses[j]) { + t.Errorf("Removed address still exists at index %d", j) + } + } else { + // check registered address + refAddr, err := k.get(addresses[j]) + if err != nil { + t.Errorf("Failed to get referral at index %d: %v", j, err) + } + expectedAddr := addresses[(j+2)%STRESS_TEST_NUM] + if refAddr != expectedAddr { + t.Errorf("Incorrect referral at index %d", j) + } + } + } + } + } +} + +func TestIsEmpty(t *testing.T) { + tests := []struct { + name string + setupState func(*keeper) + want bool + }{ + { + name: "new keeper must empty", + setupState: func(k *keeper) {}, + want: true, + }, + { + name: "keeper with data must not be empty", + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + want: false, + }, + { + name: "keeper with all data removed must be empty", + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + k.store.Remove(validAddr1.String()) + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + + if got := k.isEmpty(); got != tt.want { + t.Errorf("isEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + +/* Getter Tests */ + +func TestGetReferral(t *testing.T) { + tests := []struct { + name string + addr string + setupState func() + want string + shouldPanic bool + }{ + { + name: "retrieve existing referral", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + cleanup := mockValidCaller() + defer cleanup() + gReferralKeeper.register(validAddr1, validAddr2) + }, + want: validAddr2.String(), + shouldPanic: false, + }, + { + name: "retrieve non-existent referral", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + }, + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Error("GetReferral() function did not panic") + } + }() + } + + got := GetReferral(tt.addr) + if !tt.shouldPanic && got != tt.want { + t.Errorf("GetReferral() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHasReferral(t *testing.T) { + tests := []struct { + name string + addr string + setupState func() + want bool + }{ + { + name: "referral exists", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + cleanup := mockValidCaller() + defer cleanup() + gReferralKeeper.register(validAddr1, validAddr2) + }, + want: true, + }, + { + name: "referral does not exist", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + }, + want: false, + }, + { + name: "zero address referral", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + cleanup := mockValidCaller() + defer cleanup() + gReferralKeeper.register(validAddr1, zeroAddress) + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + if got := HasReferral(tt.addr); got != tt.want { + t.Errorf("HasReferral() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGlobalIsEmpty(t *testing.T) { + tests := []struct { + name string + setupState func() + want bool + }{ + { + name: "new global keeper must be empty", + setupState: func() { + gReferralKeeper = NewKeeper() + }, + want: true, + }, + { + name: "keeper with data must not be empty", + setupState: func() { + gReferralKeeper = NewKeeper() + cleanup := mockValidCaller() + defer cleanup() + gReferralKeeper.register(validAddr1, validAddr2) + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + if got := IsEmpty(); got != tt.want { + t.Errorf("IsEmpty() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/contract/r/gnoswap/referral/referral.gno b/contract/r/gnoswap/referral/referral.gno new file mode 100644 index 000000000..4b30d6a8e --- /dev/null +++ b/contract/r/gnoswap/referral/referral.gno @@ -0,0 +1,54 @@ +package referral + +import ( + "std" +) + +// Referral provides the main interface for managing referral relationships. +// It encapsulates a `ReferralKeeper` instance which handles the actual storage +// and action of referral functionality. +type Referral struct { + keeper ReferralKeeper +} + +// NewReferral creates and returns a new instance of `Referral`. +// it initializes the underlying keeper with a new storage instance. +func NewReferral() *Referral { + return &Referral{ + keeper: NewKeeper(), + } +} + +// Register creates a new referral relationship between addr and refAddr. +// It validates both addresses and ensures the caller has the required permissions. +// Returns an error if the addresses are invalid or if the caller is unauthorized. +func (r *Referral) Register(addr, refAddr std.Address) error { + return r.keeper.register(addr, refAddr) +} + +// Update modifies an existing referral relationship for the given address. +// The new referral address will replace any existing referral. +// Returns an error if the addresses are invalid or if the caller is unauthorized. +func (r *Referral) Update(addr, newAddr std.Address) error { + return r.keeper.update(addr, newAddr) +} + +// Remove deletes the referral relationship for the given address. +// If no referral exists for the address, the operation is a no-op. +// Returns an error if the address is invalid or if the caller is unauthorized. +func (r *Referral) Remove(addr std.Address) error { + return r.keeper.remove(addr) +} + +// Has checks if a referral relationship exists for the given address. +// Returns true if a referral exists, false otherwise. +func (r *Referral) Has(addr std.Address) bool { + return r.keeper.has(addr) +} + +// Get retrieves the referral address for the given address. +// Returns the referral address and nil error if found. +// Returns zeroAddress and an error if the address is invalid or no referral exists. +func (r *Referral) Get(addr std.Address) (std.Address, error) { + return r.keeper.get(addr) +} diff --git a/contract/r/gnoswap/referral/type.gno b/contract/r/gnoswap/referral/type.gno new file mode 100644 index 000000000..6f64651dc --- /dev/null +++ b/contract/r/gnoswap/referral/type.gno @@ -0,0 +1,38 @@ +package referral + +import "std" + +// zeroAddress represents an empty address used for validation and comparison. +var zeroAddress = std.Address("") + +// Event types for each referral actions. +const ( + EventTypeRegister = "RegisterReferral" + EventTypeUpdate = "UpdateReferral" + EventTypeRemove = "RemoveReferral" +) + +// ReferralKeeper defines the interface for managing referral relationships. +type ReferralKeeper interface { + // register creates a new refferal relationship betwwen address and referral address. + // returns an error if the addresses are invalid or if the caller is not authorized. + register(addr, refAddr std.Address) error + + // update updates the referral address for a given address. + // returns an error if the addresses are invalid or if the caller is not authorized. + update(addr, newRefAddr std.Address) error + + // remove removes the referral relationship for a given address. + // returns an error if the address is invalid or if the caller is not authorized. + remove(addr std.Address) error + + // has checks if a referral relationship exists for a given address. + has(addr std.Address) bool + + // get retrieves the referral address for a given address. + // returns an error if the address is invalid or if the referral relationship does not exist. + get(addr std.Address) (std.Address, error) + + // isEmpty checks if the referral relationship is empty. + isEmpty() bool +} diff --git a/contract/r/gnoswap/referral/utils.gno b/contract/r/gnoswap/referral/utils.gno new file mode 100644 index 000000000..ce7c2d2ef --- /dev/null +++ b/contract/r/gnoswap/referral/utils.gno @@ -0,0 +1,28 @@ +package referral + +import ( + "std" + + "gno.land/p/gnoswap/consts" +) + +// validCallers is a lookup table of addresses that are authorized to modify referral data. +// This includes governance contracts, router, position manager, and staker contracts. +var validCallers = map[std.Address]bool{ + consts.GOV_GOVERNANCE_ADDR: true, + consts.GOV_STAKER_ADDR: true, + consts.ROUTER_ADDR: true, + consts.POSITION_ADDR: true, + consts.STAKER_ADDR: true, + consts.LAUNCHPAD_ADDR: true, +} + +// isValidCaller checks if the given address has permission to modify referral data. +// Only specific pre-authorized addresses defined in validCallers map are allowed to +// register, update, or remove referrals. +func isValidCaller(caller std.Address) error { + if validCallers[caller] { + return nil + } + return NewError(ErrUnauthorized, caller) +} diff --git a/contract/r/gnoswap/router/_helper_test.gno b/contract/r/gnoswap/router/_helper_test.gno new file mode 100644 index 000000000..93355d5e8 --- /dev/null +++ b/contract/r/gnoswap/router/_helper_test.gno @@ -0,0 +1,487 @@ +package router + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/gnoswap/consts" + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/gns" + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + "gno.land/r/onbloc/foo" + "gno.land/r/onbloc/obl" + "gno.land/r/onbloc/qux" +) + +const ( + ugnotDenom string = "ugnot" + ugnotPath string = "ugnot" + wugnotPath string = "gno.land/r/demo/wugnot" + gnsPath string = "gno.land/r/gnoswap/v1/gns" + barPath string = "gno.land/r/onbloc/bar" + bazPath string = "gno.land/r/onbloc/baz" + fooPath string = "gno.land/r/onbloc/foo" + oblPath string = "gno.land/r/onbloc/obl" + quxPath string = "gno.land/r/onbloc/qux" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + maxApprove uint64 = 18446744073709551615 + max_timeout int64 = 9999999999 + + TIER_1 uint64 = 1 + TIER_2 uint64 = 2 + TIER_3 uint64 = 3 + + poolCreationFee = 100_000_000 +) + +const ( + FEE_LOW uint32 = 500 + FEE_MEDIUM uint32 = 3000 + FEE_HIGH uint32 = 10000 +) + +const ( + // define addresses to use in tests + addr01 = testutils.TestAddress("addr01") + addr02 = testutils.TestAddress("addr02") +) + +var ( + user1Addr std.Address = "g1ecely4gjy0yl6s9kt409ll330q9hk2lj9ls3ec" + singlePoolPath = "gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:3000" + singlePoolPath2 = "gno.land/r/onbloc/baz:gno.land/r/onbloc/bar:3000" +) + +var ( + minTick int32 = -887220 + maxTick int32 = 887220 +) + +var ( + admin = consts.ADMIN + adminAddr = admin + alice = testutils.TestAddress("alice") + bob = testutils.TestAddress("bob") + pool = consts.POOL_ADDR + position = consts.POSITION_ADDR + router = consts.ROUTER_ADDR + protocolFee = consts.PROTOCOL_FEE_ADDR + adminRealm = std.NewUserRealm(admin) + posRealm = std.NewCodeRealm(consts.POSITION_PATH) + + // addresses used in tests + addrUsedInTest = []std.Address{addr01, addr02} +) + +func InitialisePoolTest(t *testing.T) { + t.Helper() + + ugnotFaucet(t, admin, 100_000_000_000_000) + ugnotDeposit(t, admin, 100_000_000_000_000) + + std.TestSetOrigCaller(admin) + TokenApprove(t, gnsPath, admin, pool, maxApprove) + CreatePool(t, wugnotPath, gnsPath, fee3000, "79228162514264337593543950336", admin) + + // 2. create position + std.TestSetOrigCaller(alice) + TokenFaucet(t, wugnotPath, alice) + TokenFaucet(t, gnsPath, alice) + TokenApprove(t, wugnotPath, alice, pool, uint64(1000)) + TokenApprove(t, gnsPath, alice, pool, uint64(1000)) +} + +func TokenFaucet(t *testing.T, tokenPath string, to std.Address) { + t.Helper() + std.TestSetOrigCaller(admin) + defaultAmount := uint64(5_000_000_000) + + switch tokenPath { + case wugnotPath: + wugnotTransfer(t, to, defaultAmount) + case gnsPath: + gnsTransfer(t, to, defaultAmount) + case barPath: + barTransfer(t, to, defaultAmount) + case bazPath: + bazTransfer(t, to, defaultAmount) + case fooPath: + fooTransfer(t, to, defaultAmount) + case oblPath: + oblTransfer(t, to, defaultAmount) + case quxPath: + quxTransfer(t, to, defaultAmount) + default: + panic("token not found") + } +} + +func TokenBalance(t *testing.T, tokenPath string, owner std.Address) uint64 { + t.Helper() + switch tokenPath { + case wugnotPath: + return wugnot.BalanceOf(common.AddrToUser(owner)) + case gnsPath: + return gns.BalanceOf(owner) + case barPath: + return bar.BalanceOf(owner) + case bazPath: + return baz.BalanceOf(owner) + case fooPath: + return foo.BalanceOf(owner) + case oblPath: + return obl.BalanceOf(owner) + case quxPath: + return qux.BalanceOf(owner) + default: + panic("token not found") + } +} + +func TokenAllowance(t *testing.T, tokenPath string, owner, spender std.Address) uint64 { + t.Helper() + switch tokenPath { + case wugnotPath: + return wugnot.Allowance(common.AddrToUser(owner), common.AddrToUser(spender)) + case gnsPath: + return gns.Allowance(owner, spender) + case barPath: + return bar.Allowance(owner, spender) + case bazPath: + return baz.Allowance(owner, spender) + case fooPath: + return foo.Allowance(owner, spender) + case oblPath: + return obl.Allowance(owner, spender) + case quxPath: + return qux.Allowance(owner, spender) + default: + panic("token not found") + } +} + +func TokenApprove(t *testing.T, tokenPath string, owner, spender std.Address, amount uint64) { + t.Helper() + switch tokenPath { + case wugnotPath: + wugnotApprove(t, owner, spender, amount) + case gnsPath: + gnsApprove(t, owner, spender, amount) + case barPath: + barApprove(t, owner, spender, amount) + case bazPath: + bazApprove(t, owner, spender, amount) + case fooPath: + fooApprove(t, owner, spender, amount) + case oblPath: + oblApprove(t, owner, spender, amount) + case quxPath: + quxApprove(t, owner, spender, amount) + default: + panic("token not found") + } +} + +func CreatePool(t *testing.T, + token0 string, + token1 string, + fee uint32, + sqrtPriceX96 string, + caller std.Address, +) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(caller)) + poolPath := pl.GetPoolPath(token0, token1, fee) + if !pl.DoesPoolPathExist(poolPath) { + pl.CreatePool(token0, token1, fee, sqrtPriceX96) + sr.SetPoolTierByAdmin(poolPath, TIER_1) + } +} + +func LPTokenStake(t *testing.T, owner std.Address, tokenId uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + sr.StakeToken(tokenId) +} + +func LPTokenUnStake(t *testing.T, owner std.Address, tokenId uint64, unwrap bool) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + sr.UnstakeToken(tokenId, unwrap) +} + +func wugnotApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + wugnot.Approve(common.AddrToUser(spender), amount) +} + +func gnsApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + gns.Approve(spender, amount) +} + +func barApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + bar.Approve(spender, amount) +} + +func bazApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + baz.Approve(spender, amount) +} + +func fooApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + foo.Approve(spender, amount) +} + +func oblApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + obl.Approve(spender, amount) +} + +func quxApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + qux.Approve(spender, amount) +} + +func wugnotTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + wugnot.Transfer(common.AddrToUser(to), amount) +} + +func gnsTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + gns.Transfer(to, amount) +} + +func barTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + bar.Transfer(to, amount) +} + +func bazTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + baz.Transfer(to, amount) +} + +func fooTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + foo.Transfer(to, amount) +} + +func oblTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + obl.Transfer(to, amount) +} + +func quxTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + qux.Transfer(to, amount) +} + +// ---------------------------------------------------------------------------- +// ugnot + +func ugnotTransfer(t *testing.T, from, to std.Address, amount uint64) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(from)) + std.TestSetOrigSend(std.Coins{{ugnotDenom, int64(amount)}}, nil) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(from, to, std.Coins{{ugnotDenom, int64(amount)}}) +} + +func ugnotBalanceOf(t *testing.T, addr std.Address) uint64 { + t.Helper() + + banker := std.GetBanker(std.BankerTypeRealmIssue) + coins := banker.GetCoins(addr) + if len(coins) == 0 { + return 0 + } + + return uint64(coins.AmountOf(ugnotDenom)) +} + +func ugnotMint(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.IssueCoin(addr, denom, amount) + std.TestIssueCoins(addr, std.Coins{{denom, int64(amount)}}) +} + +func ugnotBurn(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.RemoveCoin(addr, denom, amount) +} + +func ugnotFaucet(t *testing.T, to std.Address, amount uint64) { + t.Helper() + faucetAddress := admin + std.TestSetOrigCaller(faucetAddress) + + if ugnotBalanceOf(t, faucetAddress) < amount { + newCoins := std.Coins{{ugnotDenom, int64(amount)}} + ugnotMint(t, faucetAddress, newCoins[0].Denom, newCoins[0].Amount) + std.TestSetOrigSend(newCoins, nil) + } + ugnotTransfer(t, faucetAddress, to, amount) +} + +func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(addr)) + wugnotAddr := consts.WUGNOT_ADDR + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(addr, wugnotAddr, std.Coins{{ugnotDenom, int64(amount)}}) + wugnot.Deposit() +} + +func CreatePoolWithoutFee(t *testing.T) { + std.TestSetRealm(adminRealm) + // set pool create fee to 0 for testing + pl.SetPoolCreationFeeByAdmin(0) + CreatePool(t, barPath, fooPath, fee500, common.TickMathGetSqrtRatioAtTick(0).ToString(), admin) + CreatePool(t, bazPath, fooPath, fee3000, common.TickMathGetSqrtRatioAtTick(0).ToString(), admin) + CreatePool(t, barPath, bazPath, fee3000, common.TickMathGetSqrtRatioAtTick(0).ToString(), admin) + CreatePool(t, barPath, bazPath, fee500, common.TickMathGetSqrtRatioAtTick(0).ToString(), admin) +} + +func CreateSecondPoolWithoutFee(t *testing.T) { + std.TestSetRealm(adminRealm) + pl.SetPoolCreationFeeByAdmin(0) + + CreatePool(t, + bazPath, + quxPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), + admin, + ) +} + +func MakeMintPositionWithoutFee(t *testing.T) (uint64, string, string, string) { + t.Helper() + + // make actual data to test resetting not only position's state but also pool's state + std.TestSetRealm(adminRealm) + + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, bazPath, admin, pool, consts.UINT64_MAX) + + // mint position + return pn.Mint( + barPath, + bazPath, + fee3000, + -887220, + 887220, + "50000", + "50000", + "0", + "0", + max_timeout, + admin, + admin, + ) +} + +func MakeSecondMintPositionWithoutFee(t *testing.T) (uint64, string, string, string) { + t.Helper() + + std.TestSetRealm(adminRealm) + + TokenApprove(t, bazPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, quxPath, admin, pool, consts.UINT64_MAX) + + return pn.Mint( + bazPath, + quxPath, + fee3000, + -887220, + 887220, + "50000", + "50000", + "0", + "0", + max_timeout, + admin, + admin, + ) +} + +func MakeThirdMintPositionWithoutFee(t *testing.T) (uint64, string, string, string) { + t.Helper() + + std.TestSetRealm(adminRealm) + + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, fooPath, admin, pool, consts.UINT64_MAX) + + return pn.Mint( + barPath, + fooPath, + fee500, + -887220, + 887220, + "50000", + "50000", + "0", + "0", + max_timeout, + admin, + admin, + ) +} + +func MakeForthMintPositionWithoutFee(t *testing.T) (uint64, string, string, string) { + t.Helper() + + // make actual data to test resetting not only position's state but also pool's state + std.TestSetRealm(adminRealm) + + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, bazPath, admin, pool, consts.UINT64_MAX) + + // mint position + return pn.Mint( + barPath, + bazPath, + fee500, + -887220, + 887220, + "50000", + "50000", + "0", + "0", + max_timeout, + admin, + admin, + ) +} diff --git a/contract/r/gnoswap/router/base.gno b/contract/r/gnoswap/router/base.gno new file mode 100644 index 000000000..b95f662d5 --- /dev/null +++ b/contract/r/gnoswap/router/base.gno @@ -0,0 +1,156 @@ +package router + +import ( + "std" + "strconv" + "strings" + + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/wugnot" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +const ( + SINGLE_HOP_ROUTE int = 1 + + INITIAL_WUGNOT_BALANCE uint64 = 0 +) + +const ( + POOL_SEPARATOR = "*POOL*" +) + +type RouterOperation interface { + Validate() error + Process() (*SwapResult, error) +} + +func executeSwapOperation(op RouterOperation) (*SwapResult, error) { + if err := op.Validate(); err != nil { + return nil, err + } + + result, err := op.Process() + if err != nil { + return nil, err + } + + return result, nil +} + +type BaseSwapParams struct { + InputToken string + OutputToken string + RouteArr string + QuoteArr string + Deadline int64 +} + +// common swap operation +type baseSwapOperation struct { + routes []string + quotes []string + amountSpecified *i256.Int + userBeforeWugnotBalance uint64 + userWrappedWugnot uint64 +} + +func (op *baseSwapOperation) handleNativeTokenWrapping( + inputToken string, + outputToken string, + swapType SwapType, + specifiedAmount *i256.Int, +) error { + // no native token + if inputToken == consts.GNOT || outputToken == consts.GNOT { + return nil + } + + // save current user's WGNOT amount + op.userBeforeWugnotBalance = wugnot.BalanceOf(common.AddrToUser(std.PrevRealm().Addr())) + + if swapType == ExactIn && inputToken == consts.GNOT { + sent := std.GetOrigSend() + + ugnotSentByUser := uint64(sent.AmountOf("ugnot")) + amountSpecified := specifiedAmount.Uint64() + + if ugnotSentByUser != amountSpecified { + return ufmt.Errorf("ugnot sent by user(%d) is not equal to amountSpecified(%d)", ugnotSentByUser, amountSpecified) + } + + // wrap user's WUGNOT + if ugnotSentByUser > 0 { + wrap(ugnotSentByUser) + } + + op.userWrappedWugnot = ugnotSentByUser + } + + return nil +} + +func (op *baseSwapOperation) validateRouteQuote(quote string, i int) (*i256.Int, error) { + qt, err := strconv.Atoi(quote) + if err != nil { + return nil, ufmt.Errorf("invalid quote(%s) at index(%d)", quote, i) + } + + // calculate amount to swap for this route + toSwap := i256.Zero().Mul(op.amountSpecified, i256.NewInt(int64(qt))) + toSwap = i256.Zero().Div(toSwap, i256.NewInt(100)) + + return toSwap, nil +} + +func (op *baseSwapOperation) processRoutes(swapType SwapType) (*u256.Uint, *u256.Uint, error) { + resultAmountIn := u256.Zero() + resultAmountOut := u256.Zero() + + for i, route := range op.routes { + toSwap, err := op.validateRouteQuote(op.quotes[i], i) + if err != nil { + return nil, nil, err + } + + amountIn, amountOut, err := op.processRoute(route, toSwap, swapType) + if err != nil { + return nil, nil, err + } + + resultAmountIn = new(u256.Uint).Add(resultAmountIn, amountIn) + resultAmountOut = new(u256.Uint).Add(resultAmountOut, amountOut) + } + + return resultAmountIn, resultAmountOut, nil +} + +func (op *baseSwapOperation) processRoute( + route string, + toSwap *i256.Int, + swapType SwapType, +) (*u256.Uint, *u256.Uint, error) { + numHops := strings.Count(route, POOL_SEPARATOR) + 1 + assertHopsInRange(numHops) + + var amountIn, amountOut *u256.Uint + + switch numHops { + case SINGLE_HOP_ROUTE: + amountIn, amountOut = handleSingleSwap(route, toSwap) + default: + amountIn, amountOut = handleMultiSwap(swapType, route, numHops, toSwap) + } + + if amountIn == nil || amountOut == nil { + return nil, nil, ufmt.Errorf("swap failed to process route(%s)", route) + } + + return amountIn, amountOut, nil +} diff --git a/contract/r/gnoswap/router/base_test.gno b/contract/r/gnoswap/router/base_test.gno new file mode 100644 index 000000000..5bcb36e78 --- /dev/null +++ b/contract/r/gnoswap/router/base_test.gno @@ -0,0 +1,211 @@ +package router + +import ( + "errors" + "std" + "testing" + + "gno.land/p/gnoswap/consts" + i256 "gno.land/p/gnoswap/int256" + + "gno.land/p/demo/uassert" +) + +var errDummy = errors.New("dummy error") + +type mockOperation struct { + ValidateErr error + ProcessErr error + Result *SwapResult +} + +func (m *mockOperation) Validate() error { + return m.ValidateErr +} + +func (m *mockOperation) Process() (*SwapResult, error) { + return m.Result, m.ProcessErr +} + +func TestExecuteSwapOperation(t *testing.T) { + tests := []struct { + name string + operation RouterOperation + expectError bool + }{ + { + name: "success case", + operation: &mockOperation{ + ValidateErr: nil, + ProcessErr: nil, + Result: &SwapResult{}, + }, + expectError: false, + }, + { + name: "validate error", + operation: &mockOperation{ + ValidateErr: errDummy, + ProcessErr: nil, + Result: &SwapResult{}, + }, + expectError: true, + }, + { + name: "process error", + operation: &mockOperation{ + ValidateErr: nil, + ProcessErr: errDummy, + Result: nil, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := executeSwapOperation(tt.operation) + if tt.expectError && err == nil { + t.Errorf("expected an error but got nil (test case: %s)", tt.name) + } + if !tt.expectError && err != nil { + t.Errorf("unexpected error: %v (test case: %s)", err, tt.name) + } + if !tt.expectError && result == nil { + t.Errorf("expected non-nil result but got nil (test case: %s)", tt.name) + } + }) + } +} + +func TestHandleNativeTokenWrapping(t *testing.T) { + tests := []struct { + name string + inputToken string + outputToken string + swapType SwapType + specifiedAmount *i256.Int + sentAmount int64 + expectError bool + }{ + { + name: "Pass: non-GNOT token swap", + inputToken: "token1", + outputToken: "token2", + swapType: ExactIn, + specifiedAmount: i256.NewInt(100), + sentAmount: 0, + expectError: false, + }, + { + name: "Pass: GNOT -> WGNOT exact amount", + inputToken: consts.GNOT, + outputToken: "token2", + swapType: ExactIn, + specifiedAmount: i256.NewInt(1000), + sentAmount: 1000, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op := &baseSwapOperation{} + + testCoins := std.Coins{{"ugnot", tt.sentAmount}} + std.TestSetOrigSend(testCoins, std.Coins{}) + + err := op.handleNativeTokenWrapping( + tt.inputToken, + tt.outputToken, + tt.swapType, + tt.specifiedAmount, + ) + + if tt.expectError && err == nil { + t.Errorf("expected an error but got nil") + } + if !tt.expectError && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestValidateRouteQuote(t *testing.T) { + op := &baseSwapOperation{ + amountSpecified: i256.NewInt(1000), + } + + tests := []struct { + name string + quote string + index int + expectError bool + expected *i256.Int + }{ + { + name: "Pass: valid quote - 100%", + quote: "100", + index: 0, + expectError: false, + expected: i256.NewInt(1000), // 1000 * 100 / 100 = 1000 + }, + { + name: "Pass: valid quote - 50%", + quote: "50", + index: 0, + expectError: false, + expected: i256.NewInt(500), // 1000 * 50 / 100 = 500 + }, + { + name: "Fail: invalid quote - string", + quote: "invalid", + index: 0, + expectError: true, + expected: nil, + }, + { + name: "Fail: invalid quote - empty string", + quote: "", + index: 0, + expectError: true, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := op.validateRouteQuote(tt.quote, tt.index) + if tt.expectError { + uassert.Error(t, err) + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result.Cmp(tt.expected) != 0 { + t.Errorf("expected %v but got %v", tt.expected, result) + } + } + }) + } +} + +func TestProcessRoute(t *testing.T) { + std.TestSetRealm(adminRealm) + op := &baseSwapOperation{} + + t.Run("Single hop route", func(t *testing.T) { + CreatePoolWithoutFee(t) + MakeThirdMintPositionWithoutFee(t) + route := "gno.land/r/onbloc/foo:gno.land/r/onbloc/bar:500" + toSwap := i256.NewInt(1000) + swapType := ExactIn + + amountIn, amountOut, err := op.processRoute(route, toSwap, swapType) + + uassert.Equal(t, err, nil) + uassert.Equal(t, amountIn.ToString(), "1000") + uassert.Equal(t, amountOut.ToString(), "979") + }) +} diff --git a/contract/r/gnoswap/router/doc.gno b/contract/r/gnoswap/router/doc.gno new file mode 100644 index 000000000..3095e5e60 --- /dev/null +++ b/contract/r/gnoswap/router/doc.gno @@ -0,0 +1,2 @@ +// Package router connects different pools for seamless token swaps. +package router diff --git a/contract/r/gnoswap/router/errors.gno b/contract/r/gnoswap/router/errors.gno new file mode 100644 index 000000000..48cdccf33 --- /dev/null +++ b/contract/r/gnoswap/router/errors.gno @@ -0,0 +1,28 @@ +package router + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-ROUTER-001] caller has no permission") + errSlippage = errors.New("[GNOSWAP-ROUTER-002] slippage check failed") + errInvalidRoutesAndQuotes = errors.New("[GNOSWAP-ROUTER-003] invalid routes and quotes") + errExpired = errors.New("[GNOSWAP-ROUTER-004] transaction expired") + errInvalidInput = errors.New("[GNOSWAP-ROUTER-005] invalid input data") + errInvalidPoolFeeTier = errors.New("[GNOSWAP-ROUTER-006] invalid pool fee tier") + errInvalidSwapFee = errors.New("[GNOSWAP-ROUTER-007] invalid swap fee") + errInvalidSwapType = errors.New("[GNOSWAP-ROUTER-008] invalid swap type") + errInvalidPoolPath = errors.New("[GNOSWAP-ROUTER-009] invalid pool path") + errWrapUnwrap = errors.New("[GNOSWAP-ROUTER-010] wrap, unwrap failed") + errWugnotMinimum = errors.New("[GNOSWAP-ROUTER-011] less than minimum amount ") + errQuoteParser = errors.New("[GNOSWAP-ROUTER-012] quote parse failed") + errHopsOutOfRange = errors.New("[GNOSWAP-ROUTER-013] number of hops must be 1~3") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/router/exact_in.gno b/contract/r/gnoswap/router/exact_in.gno new file mode 100644 index 000000000..f28c98c12 --- /dev/null +++ b/contract/r/gnoswap/router/exact_in.gno @@ -0,0 +1,138 @@ +package router + +import ( + "std" + + "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +type ExactInSwapOperation struct { + baseSwapOperation + params ExactInParams +} + +func NewExactInSwapOperation(pp ExactInParams) *ExactInSwapOperation { + return &ExactInSwapOperation{ + params: pp, + baseSwapOperation: baseSwapOperation{ + userWrappedWugnot: INITIAL_WUGNOT_BALANCE, + }, + } +} + +func ExactInSwapRoute( + inputToken string, + outputToken string, + amountIn string, + RouteArr string, + quoteArr string, + amountOutMin string, + deadline int64, +) (string, string) { + checkDeadline(deadline) + commonSwapSetup() + + baseParams := BaseSwapParams{ + InputToken: inputToken, + OutputToken: outputToken, + RouteArr: RouteArr, + QuoteArr: quoteArr, + } + + pp := NewExactInParams( + baseParams, + amountIn, + amountOutMin, + ) + + op := NewExactInSwapOperation(pp) + + result, err := executeSwapOperation(op) + if err != nil { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("invalid ExactInSwapOperation: %s", err.Error()), + )) + } + + inputAmount, outputAmount := finalizeSwap( + pp.InputToken, + pp.OutputToken, + result.AmountIn, + result.AmountOut, + ExactIn, + u256.MustFromDecimal(pp.AmountOutMin), + op.userBeforeWugnotBalance, + op.userWrappedWugnot, + result.AmountSpecified.Abs(), + ) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "ExactInSwap", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "input", pp.InputToken, + "output", pp.OutputToken, + "exactAmount", amountIn, + "route", pp.RouteArr, + "quote", pp.QuoteArr, + "resultInputAmount", inputAmount, + "resultOutputAmount", outputAmount, + ) + + return inputAmount, outputAmount +} + +func (op *ExactInSwapOperation) Validate() error { + amountIn := i256.MustFromDecimal(op.params.AmountIn) + if amountIn.IsZero() || amountIn.IsNeg() { + return ufmt.Errorf("invalid amountIn(%s), must be positive", amountIn.ToString()) + } + + // when `SwapType` is `ExactIn`, assign `amountSpecified` the `amountIn` + // obtained from above. + op.amountSpecified = amountIn + + routes, quotes, err := tryParseRoutes(op.params.RouteArr, op.params.QuoteArr) + if err != nil { + return err + } + + op.routes = routes + op.quotes = quotes + + return nil +} + +func (op *ExactInSwapOperation) Process() (*SwapResult, error) { + if err := op.handleNativeTokenWrapping(); err != nil { + return nil, err + } + + resultAmountIn, resultAmountOut, err := op.processRoutes(ExactIn) + if err != nil { + return nil, err + } + + return &SwapResult{ + AmountIn: resultAmountIn, + AmountOut: resultAmountOut, + Routes: op.routes, + Quotes: op.quotes, + AmountSpecified: op.amountSpecified, + }, nil +} + +func (op *ExactInSwapOperation) handleNativeTokenWrapping() error { + return op.baseSwapOperation.handleNativeTokenWrapping( + op.params.InputToken, + op.params.OutputToken, + ExactIn, + op.amountSpecified, + ) +} diff --git a/contract/r/gnoswap/router/exact_in_test.gno b/contract/r/gnoswap/router/exact_in_test.gno new file mode 100644 index 000000000..f1396ff11 --- /dev/null +++ b/contract/r/gnoswap/router/exact_in_test.gno @@ -0,0 +1,167 @@ +package router + +import ( + "std" + "testing" + "time" + + "gno.land/p/gnoswap/consts" + + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" +) + +func TestExactInSwapRouteOperation_Validate(t *testing.T) { + tests := []struct { + name string + inputToken string + outputToken string + amountIn string + amountOutMin string + routeArr string + quoteArr string + wantErr bool + errMsg string + }{ + { + name: "Pass: single pool path", + inputToken: barPath, + outputToken: bazPath, + amountIn: "100", + amountOutMin: "90", + routeArr: singlePoolPath, + quoteArr: "100", + wantErr: false, + }, + { + name: "Fail: amountIn is 0", + inputToken: barPath, + outputToken: bazPath, + amountIn: "0", + amountOutMin: "100", + routeArr: singlePoolPath, + quoteArr: "100", + wantErr: true, + errMsg: "invalid amountIn(0), must be positive", + }, + { + name: "Fail: amountIn is negative", + inputToken: barPath, + outputToken: bazPath, + amountIn: "-100", + amountOutMin: "10", + routeArr: singlePoolPath, + quoteArr: "100", + wantErr: true, + errMsg: "invalid amountIn(-100), must be positive", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + baseParams := BaseSwapParams{ + InputToken: tt.inputToken, + OutputToken: tt.outputToken, + RouteArr: tt.routeArr, + QuoteArr: tt.quoteArr, + } + + pp := NewExactInParams( + baseParams, + tt.amountIn, + tt.amountOutMin, + ) + + op := NewExactInSwapOperation(pp) + err := op.Validate() + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + return + } + if err.Error() != tt.errMsg { + t.Errorf("expected error message %q but got %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func TestExactInSwapRoute(t *testing.T) { + CreatePoolWithoutFee(t) + MakeMintPositionWithoutFee(t) + + std.TestSkipHeights(100) + user1Realm := std.NewUserRealm(user1Addr) + std.TestSetRealm(user1Realm) + + tests := []struct { + name string + setup func() + inputToken string + outputToken string + amountIn string + routeArr string + quoteArr string + amountOutMin string + wantErr bool + }{ + { + name: "BAR -> BAZ", + setup: func() { + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + TokenFaucet(t, barPath, user1Addr) + }, + inputToken: barPath, + outputToken: bazPath, + amountIn: "100", + routeArr: singlePoolPath, + quoteArr: "100", + amountOutMin: "85", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + bar.Approve(consts.POOL_ADDR, maxApprove) + baz.Approve(consts.POOL_ADDR, maxApprove) + if tt.setup != nil { + tt.setup() + } + + defer func() { + if r := recover(); r != nil { + if !tt.wantErr { + t.Errorf("ExactInSwapRoute() panic = %v", r) + } + } + }() + + amountIn, amountOut := ExactInSwapRoute( + tt.inputToken, + tt.outputToken, + tt.amountIn, + tt.routeArr, + tt.quoteArr, + tt.amountOutMin, + time.Now().Add(time.Hour).Unix(), + ) + + if !tt.wantErr { + if amountIn == "" || amountOut == "" { + t.Errorf("ExactInSwapRoute() returned empty values") + } + } + }) + } +} diff --git a/contract/r/gnoswap/router/exact_out.gno b/contract/r/gnoswap/router/exact_out.gno new file mode 100644 index 000000000..61ef61578 --- /dev/null +++ b/contract/r/gnoswap/router/exact_out.gno @@ -0,0 +1,130 @@ +package router + +import ( + "std" + + "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +type ExactOutSwapOperation struct { + baseSwapOperation + params ExactOutParams +} + +func NewExactOutSwapOperation(pp ExactOutParams) *ExactOutSwapOperation { + return &ExactOutSwapOperation{ + params: pp, + baseSwapOperation: baseSwapOperation{ + userWrappedWugnot: INITIAL_WUGNOT_BALANCE, + }, + } +} + +func ExactOutSwapRoute( + inputToken string, + outputToken string, + amountOut string, + routeArr string, + quoteArr string, + amountInMax string, + deadline int64, +) (string, string) { + checkDeadline(deadline) + commonSwapSetup() + + baseParams := BaseSwapParams{ + InputToken: inputToken, + OutputToken: outputToken, + RouteArr: routeArr, + QuoteArr: quoteArr, + } + + pp := NewExactOutParams(baseParams, amountOut, amountInMax) + op := NewExactOutSwapOperation(pp) + + result, err := executeSwapOperation(op) + if err != nil { + panic(addDetailToError(errInvalidInput, err.Error())) + } + + inputAmount, outputAmount := finalizeSwap( + pp.InputToken, + pp.OutputToken, + result.AmountIn, + result.AmountOut, + ExactOut, + u256.MustFromDecimal(pp.AmountInMax), + op.userBeforeWugnotBalance, + op.userWrappedWugnot, + result.AmountSpecified.Abs(), + ) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "ExactOutSwap", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "input", pp.InputToken, + "output", pp.OutputToken, + "exactAmount", amountOut, + "route", pp.RouteArr, + "quote", pp.QuoteArr, + "resultInputAmount", inputAmount, + "resultOutputAmount", outputAmount, + ) + + return inputAmount, outputAmount +} + +func (op *ExactOutSwapOperation) Validate() error { + amountOut := i256.MustFromDecimal(op.params.AmountOut) + if amountOut.IsZero() || amountOut.IsNeg() { + return ufmt.Errorf("invalid amountOut(%s), must be positive", amountOut.ToString()) + } + + // assign a signed reversed `amountOut` to `amountSpecified` + // when it's an ExactOut + op.amountSpecified = i256.Zero().Neg(amountOut) + + routes, quotes, err := tryParseRoutes(op.params.RouteArr, op.params.QuoteArr) + if err != nil { + return err + } + + op.routes = routes + op.quotes = quotes + + return nil +} + +func (op *ExactOutSwapOperation) Process() (*SwapResult, error) { + if err := op.handleNativeTokenWrapping(); err != nil { + return nil, err + } + + resultAmountIn, resultAmountOut, err := op.processRoutes(ExactOut) + if err != nil { + return nil, err + } + + return &SwapResult{ + AmountIn: resultAmountIn, + AmountOut: resultAmountOut, + Routes: op.routes, + Quotes: op.quotes, + AmountSpecified: op.amountSpecified, + }, nil +} + +func (op *ExactOutSwapOperation) handleNativeTokenWrapping() error { + return op.baseSwapOperation.handleNativeTokenWrapping( + op.params.InputToken, + op.params.OutputToken, + ExactOut, + op.amountSpecified, + ) +} diff --git a/contract/r/gnoswap/router/gno.mod b/contract/r/gnoswap/router/gno.mod new file mode 100644 index 000000000..e3a08288b --- /dev/null +++ b/contract/r/gnoswap/router/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/router diff --git a/contract/r/gnoswap/router/protocol_fee_swap.gno b/contract/r/gnoswap/router/protocol_fee_swap.gno new file mode 100644 index 000000000..f4eb4e888 --- /dev/null +++ b/contract/r/gnoswap/router/protocol_fee_swap.gno @@ -0,0 +1,122 @@ +package router + +import ( + "std" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/p/demo/ufmt" + + u256 "gno.land/p/gnoswap/uint256" + + pf "gno.land/r/gnoswap/v1/protocol_fee" +) + +const ( + defaultSwapFeeBPS = uint64(15) // 0.15% +) + +var ( + swapFee = defaultSwapFeeBPS +) + +// GetSwapFee returns current rate of swap fee +// ref: https://docs.gnoswap.io/contracts/router/protocol_fee_swap.gno#getswapfee +func GetSwapFee() uint64 { + return swapFee +} + +// SetSwapFeeByAdmin modifies the swap fee +// Only admin can execute this function +func SetSwapFeeByAdmin(fee uint64) { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err) + } + + prevSwapFee := swapFee + if err := setSwapFee(fee); err != nil { + panic(err) + } + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "SetSwapFeeByAdmin", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "newFee", formatUint(fee), + "prevFee", formatUint(prevSwapFee), + ) +} + +// SetSwapFee modifies the swap fee +// Only governance contract can execute this function via proposal +// ref: https://docs.gnoswap.io/contracts/router/protocol_fee_swap.gno#setswapfee +func SetSwapFee(fee uint64) { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err) + } + + prevSwapFee := swapFee + if err := setSwapFee(fee); err != nil { + panic(err) + } + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "SetSwapFee", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "newFee", formatUint(fee), + "prevFee", formatUint(prevSwapFee), + ) +} + +func setSwapFee(fee uint64) error { + common.IsHalted() + + // 10000 (bps) = 100% + if fee > 10000 { + return ufmt.Errorf( + "%s: fee must be in range 0 to 10000. got %d", + errInvalidSwapFee.Error(), fee, + ) + } + + swapFee = fee + return nil +} + +func handleSwapFee( + outputToken string, + amount *u256.Uint, +) *u256.Uint { + if swapFee <= 0 { + return amount + } + + feeAmount := new(u256.Uint).Mul(amount, u256.NewUint(swapFee)) + feeAmount.Div(feeAmount, u256.NewUint(10000)) + feeAmountUint64 := feeAmount.Uint64() + + outputTeller := common.GetTokenTeller(outputToken) + outputTeller.TransferFrom(std.PrevRealm().Addr(), consts.PROTOCOL_FEE_ADDR, feeAmountUint64) + pf.AddToProtocolFee(outputToken, feeAmountUint64) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "SwapRouteFee", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "tokenPath", outputToken, + "amount", formatUint(feeAmountUint64), + ) + + toUserAfterProtocol := new(u256.Uint).Sub(amount, feeAmount) + return toUserAfterProtocol +} diff --git a/contract/r/gnoswap/router/protocol_fee_swap_test.gno b/contract/r/gnoswap/router/protocol_fee_swap_test.gno new file mode 100644 index 000000000..372123e63 --- /dev/null +++ b/contract/r/gnoswap/router/protocol_fee_swap_test.gno @@ -0,0 +1,57 @@ +package router + +import ( + "testing" + + u256 "gno.land/p/gnoswap/uint256" +) + +func TestHandleSwapFee(t *testing.T) { + tests := []struct { + name string + amount *u256.Uint + swapFeeValue uint64 + expectedAmount *u256.Uint + }{ + { + name: "zero swap fee", + amount: u256.NewUint(1000), + swapFeeValue: 0, + expectedAmount: u256.NewUint(1000), + }, + { + name: "normal swap fee calculation (0.15%)", + amount: u256.NewUint(10000), + swapFeeValue: 15, + expectedAmount: u256.NewUint(9985), // 10000 - (10000 * 0.15%) + }, + { + name: "Dry Run test", + amount: u256.NewUint(10000), + swapFeeValue: 15, + expectedAmount: u256.NewUint(9985), + }, + { + name: "large amount swap fee calculation", + amount: u256.NewUint(1000000), + swapFeeValue: 15, + expectedAmount: u256.NewUint(998500), // 1000000 - (1000000 * 0.15%) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + originalSwapFee := swapFee + swapFee = tt.swapFeeValue + defer func() { + swapFee = originalSwapFee + }() + + result := handleSwapFee(barPath, tt.amount) + + if !result.Eq(tt.expectedAmount) { + t.Errorf("handleSwapFee() = %v, want %v", result, tt.expectedAmount) + } + }) + } +} diff --git a/contract/r/gnoswap/router/router.gno b/contract/r/gnoswap/router/router.gno new file mode 100644 index 000000000..44510a417 --- /dev/null +++ b/contract/r/gnoswap/router/router.gno @@ -0,0 +1,199 @@ +package router + +import ( + "std" + "strconv" + + "gno.land/p/demo/ufmt" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/common" + + en "gno.land/r/gnoswap/v1/emission" +) + +// commonSwapSetup Common validation and setup logic extracted from SwapRoute +func commonSwapSetup() { + assertOnlyNotHalted() + assertDirectCallOnly() + + en.MintAndDistributeGns() +} + +// handleSingleSwap handles a single swap operation. +func handleSingleSwap(route string, amountSpecified *i256.Int) (*u256.Uint, *u256.Uint) { + input, output, fee := getDataForSinglePath(route) + singleParams := SingleSwapParams{ + tokenIn: input, + tokenOut: output, + fee: fee, + amountSpecified: amountSpecified, + } + + return singleSwap(singleParams) +} + +func handleMultiSwap( + swapType SwapType, + route string, + numHops int, + amountSpecified *i256.Int, +) (*u256.Uint, *u256.Uint) { + switch swapType { + case ExactIn: + input, output, fee := getDataForMultiPath(route, 0) // first data + swapParams := SwapParams{ + tokenIn: input, + tokenOut: output, + fee: fee, + recipient: getPrevAddr(), + amountSpecified: amountSpecified, + } + return multiSwap(swapParams, 0, numHops, route) + case ExactOut: + input, output, fee := getDataForMultiPath(route, numHops-1) // last data + swapParams := SwapParams{ + tokenIn: input, + tokenOut: output, + fee: fee, + recipient: getPrevAddr(), + amountSpecified: amountSpecified, + } + return multiSwapNegative(swapParams, numHops-1, route) + default: + // Any invalid `SwapType` is caught in the `SwapRoute` function, + // so no invalid values can get in here. + panic(addDetailToError( + errInvalidSwapType, + ufmt.Sprintf("unknown swapType(%s)", swapType), + )) + } +} + +func finalizeSwap( + inputToken, outputToken string, + resultAmountIn, resultAmountOut *u256.Uint, + swapType SwapType, + tokenAmountLimit *u256.Uint, + userBeforeWugnotBalance, userWrappedWugnot uint64, + amountSpecified *u256.Uint, +) (string, string) { + + if swapType == ExactOut { + // If the pool's raw output is less than user wants, fail fast. + // (Optional: some designs skip this, and only check afterFee.) + if resultAmountOut.Lt(amountSpecified) { + panic(addDetailToError( + errSlippage, + ufmt.Sprintf("Received more than requested in [EXACT_OUT] requested=%s, actual=%s", amountSpecified.ToString(), resultAmountOut.ToString()), + )) + } + } + + afterFee := handleSwapFee(outputToken, resultAmountOut) + + userNewWugnotBalance := wugnot.BalanceOf(common.AddrToUser(std.PrevRealm().Addr())) + if inputToken == consts.GNOT { + totalBefore := userBeforeWugnotBalance + userWrappedWugnot + spend := totalBefore - userNewWugnotBalance + + if spend > userWrappedWugnot { + // used existing wugnot + panic(addDetailToError( + errSlippage, + ufmt.Sprintf("too much wugnot spent (wrapped: %d, spend: %d)", userWrappedWugnot, spend), + )) + } + + // unwrap left amount + toUnwrap := userWrappedWugnot - spend + unwrap(toUnwrap) + } else if outputToken == consts.GNOT { + userRecvWugnot := uint64(userNewWugnotBalance - userBeforeWugnotBalance - userWrappedWugnot) + unwrap(userRecvWugnot) + } + + if swapType == ExactIn { + // The user gave a fixed input => we must ensure final out >= tokenAmountLimit + // afterFee is the actual tokens user receives + if afterFee.Lt(tokenAmountLimit) { + panic(addDetailToError( + errSlippage, + ufmt.Sprintf("ExactIn: too few received (min:%s, got:%s)", tokenAmountLimit.ToString(), afterFee.ToString()), + )) + } + } else { + // swapType == ExactOut + // The user wants to get at least "amountSpecified" final tokens, + if resultAmountIn.Gt(tokenAmountLimit) { + panic(addDetailToError( + errSlippage, + ufmt.Sprintf("ExactOut: too much spent (max:%s, used:%s)", tokenAmountLimit.ToString(), resultAmountIn.ToString()), + )) + } + } + + intAmountOut := i256.FromUint256(afterFee) // final user out + negativeOut := i256.Zero().Neg(intAmountOut).ToString() + + return resultAmountIn.ToString(), negativeOut +} + +// validateRoutesAndQuotes validates the routes and quotes slices based on specific criteria. +// +// Parameters: +// - routes: A slice of strings representing route identifiers. +// - quotes: A slice of strings representing quote percentages corresponding to each route. +// +// Returns: +// - error: An error if the validation fails; nil if all checks pass. +func validateRoutesAndQuotes(routes, quotes []string) error { + if len(routes) < 1 || len(routes) > 7 { + return ufmt.Errorf("route length(%d) must be 1~7", len(routes)) + } + + if len(routes) != len(quotes) { + return ufmt.Errorf("mismatch between routes(%d) and quotes(%d) length", len(routes), len(quotes)) + } + + var quotesSum int + + for i, quote := range quotes { + intQuote, err := strconv.Atoi(quote) + if err != nil { + return ufmt.Errorf("invalid quote(%s) at index(%d)", quote, i) + } + + quotesSum += intQuote + } + + if quotesSum != 100 { + return ufmt.Errorf("quote sum(%d) must be 100", quotesSum) + } + + return nil +} + +// tryParseRoutes parses and validates routes and quotes strings, returning them as slices. +// +// Parameters: +// - routes: A string containing route identifiers separated by commas (e.g., "route1,route2,route3"). +// - quotes: A string containing quote identifiers separated by commas (e.g., "quote1,quote2,quote3"). +// +// Returns: +// - []string: A slice of route identifiers parsed from the `routes` string. +// - []string: A slice of quote identifiers parsed from the `quotes` string. +// - error: An error if the validation between routes and quotes fails. +func tryParseRoutes(routes, quotes string) ([]string, []string, error) { + routesArr := splitSingleChar(routes, ',') + quotesArr := splitSingleChar(quotes, ',') + + if err := validateRoutesAndQuotes(routesArr, quotesArr); err != nil { + return nil, nil, err + } + + return routesArr, quotesArr, nil +} diff --git a/contract/r/gnoswap/router/router_dry.gno b/contract/r/gnoswap/router/router_dry.gno new file mode 100644 index 000000000..f7c319cd8 --- /dev/null +++ b/contract/r/gnoswap/router/router_dry.gno @@ -0,0 +1,170 @@ +package router + +import ( + "std" + "strconv" + "strings" + + "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" +) + +// DrySwapRoute simulates a token swap route without actually executing the swap. +// It calculates the expected outcome based on the current state of liquidity pools. +// Returns the expected amount in or out +func DrySwapRoute( + inputToken string, + outputToken string, + specifiedAmount string, + swapKind string, + strRouteArr string, + quoteArr string, + tokenAmountLimit string, +) string { + common.MustRegistered(inputToken) + common.MustRegistered(outputToken) + + swapType, err := trySwapTypeFromStr(swapKind) + if err != nil { + panic(addDetailToError( + errInvalidSwapType, + ufmt.Sprintf("unknown swapType(%s)", swapKind), + )) + } + + amountSpecified := i256.MustFromDecimal(specifiedAmount) + if amountSpecified.IsZero() || amountSpecified.IsNeg() { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("invalid amountSpecified(%s), must be positive", specifiedAmount), + )) + } + + amountLimit := i256.MustFromDecimal(tokenAmountLimit) + if amountLimit.IsZero() { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("invalid amountLimit(%s), should not be zero", tokenAmountLimit), + )) + } + + routes, quotes, err := tryParseRoutes(strRouteArr, quoteArr) + if err != nil { + panic(addDetailToError( + errInvalidRoutesAndQuotes, + err.Error()), + ) + } + + if swapType == ExactOut { + amountSpecified = i256.Zero().Neg(amountSpecified) + } + + resultAmountIn := u256.Zero() + resultAmountOut := u256.Zero() + + for i, route := range routes { + numHops := strings.Count(route, POOL_SEPARATOR) + 1 + assertHopsInRange(numHops) + // don't need to check error here + quote, _ := strconv.Atoi(quotes[i]) + if quote < 0 || quote > 100 { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("quote(%d) must be 0~100", quote), + )) + } + + toSwap := i256.Zero().Mul(amountSpecified, i256.NewInt(int64(quote))) + toSwap = new(i256.Int).Div(toSwap, i256.NewInt(100)) + + if numHops == 1 { + amountIn, amountOut := handleSingleDrySwap(route, toSwap) + resultAmountIn = new(u256.Uint).Add(resultAmountIn, amountIn) + resultAmountOut = new(u256.Uint).Add(resultAmountOut, amountOut) + } else { + amountIn, amountOut := handleMultiDrySwap(swapType, route, numHops, toSwap) + resultAmountIn = new(u256.Uint).Add(resultAmountIn, amountIn) + resultAmountOut = new(u256.Uint).Add(resultAmountOut, amountOut) + } + } + + return processResult(swapType, resultAmountIn, resultAmountOut, amountSpecified, amountLimit) +} + +func handleSingleDrySwap(route string, amountSpecified *i256.Int) (*u256.Uint, *u256.Uint) { + input, output, fee := getDataForSinglePath(route) + singleParams := SingleSwapParams{ + tokenIn: input, + tokenOut: output, + fee: fee, + amountSpecified: amountSpecified, + } + + return singleDrySwap(singleParams) +} + +func handleMultiDrySwap( + swapType SwapType, + route string, + numHops int, + amountSpecified *i256.Int, +) (*u256.Uint, *u256.Uint) { + switch swapType { + case ExactIn: + input, output, fee := getDataForMultiPath(route, 0) // first data + swapParams := SwapParams{ + tokenIn: input, + tokenOut: output, + fee: fee, + recipient: std.PrevRealm().Addr(), + amountSpecified: amountSpecified, + } + return multiDrySwap(swapParams, 0, numHops, route) + case ExactOut: + input, output, fee := getDataForMultiPath(route, numHops-1) // last data + swapParams := SwapParams{ + tokenIn: input, + tokenOut: output, + fee: fee, + recipient: std.PrevRealm().Addr(), + amountSpecified: amountSpecified, + } + return multiDrySwapNegative(swapParams, numHops-1, route) + default: + panic(addDetailToError( + errInvalidSwapType, + ufmt.Sprintf("unknown swapType(%s)", swapType), + )) + } +} + +func processResult(swapType SwapType, resultAmountIn, resultAmountOut *u256.Uint, amountSpecified, amountLimit *i256.Int) string { + switch swapType { + case ExactIn: + if i256.FromUint256(resultAmountIn).Gt(amountSpecified) { + return "-1" + } + if i256.FromUint256(resultAmountOut).Lt(amountLimit) { + return "-1" + } + return resultAmountOut.ToString() + case ExactOut: + if i256.FromUint256(resultAmountOut).Lt(amountSpecified) { + return "-1" + } + if i256.FromUint256(resultAmountIn).Gt(amountLimit) { + return "-1" + } + return resultAmountIn.ToString() + default: + panic(addDetailToError( + errInvalidSwapType, + ufmt.Sprintf("unknown swapType(%s)", swapType), + )) + } +} diff --git a/contract/r/gnoswap/router/router_dry_test.gno b/contract/r/gnoswap/router/router_dry_test.gno new file mode 100644 index 000000000..4ef062637 --- /dev/null +++ b/contract/r/gnoswap/router/router_dry_test.gno @@ -0,0 +1,69 @@ +package router + +import ( + "testing" + + "gno.land/p/demo/uassert" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestProcessResult(t *testing.T) { + tests := []struct { + name string + swapType SwapType + resultAmountIn string + resultAmountOut string + amountSpecified string + expected string + }{ + { + name: "ExactIn - Normal", + swapType: ExactIn, + resultAmountIn: "100", + resultAmountOut: "95", + amountSpecified: "100", + expected: "95", + }, + { + name: "ExactIn - Input Mismatch", + swapType: ExactIn, + resultAmountIn: "99", + resultAmountOut: "5", + amountSpecified: "100", + expected: "-1", + }, + { + name: "ExactOut - Normal", + swapType: ExactOut, + resultAmountIn: "105", + resultAmountOut: "100", + amountSpecified: "100", + expected: "105", + }, + { + name: "ExactOut - Output Mismatch", + swapType: ExactOut, + resultAmountIn: "105", + resultAmountOut: "95", + amountSpecified: "100", + expected: "-1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resultAmountIn, _ := u256.FromDecimal(tt.resultAmountIn) + resultAmountOut, _ := u256.FromDecimal(tt.resultAmountOut) + amountSpecified, _ := i256.FromDecimal(tt.amountSpecified) + amountLimit, _ := i256.FromDecimal("10") + if tt.swapType == ExactOut { + amountLimit, _ = i256.FromDecimal("500") + } + + result := processResult(tt.swapType, resultAmountIn, resultAmountOut, amountSpecified, amountLimit) + uassert.Equal(t, result, tt.expected) + }) + } +} diff --git a/contract/r/gnoswap/router/router_test.gno b/contract/r/gnoswap/router/router_test.gno new file mode 100644 index 000000000..e2c25fa77 --- /dev/null +++ b/contract/r/gnoswap/router/router_test.gno @@ -0,0 +1,132 @@ +package router + +import ( + "strings" + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/gnoswap/consts" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestFinalizeSwap(t *testing.T) { + gnot := consts.GNOT + + newUint256 := func(val string) *u256.Uint { + return u256.MustFromDecimal(val) + } + + tests := []struct { + name string + inputToken string + outputToken string + resultAmountIn *u256.Uint + resultAmountOut *u256.Uint + swapType SwapType + tokenAmountLimit *u256.Uint + userBeforeWugnotBalance uint64 + userWrappedWugnot uint64 + amountSpecified *u256.Uint + expectError bool + errorMessage string + }{ + { + name: "Pass: ExactIn", + inputToken: barPath, + outputToken: bazPath, + resultAmountIn: newUint256("100"), + resultAmountOut: newUint256("90"), + swapType: ExactIn, + tokenAmountLimit: newUint256("85"), + userBeforeWugnotBalance: 0, + userWrappedWugnot: 0, + amountSpecified: newUint256("100"), + expectError: false, + }, + { + name: "Pass: ExactOut", + inputToken: barPath, + outputToken: bazPath, + resultAmountIn: newUint256("110"), + resultAmountOut: newUint256("100"), + swapType: ExactOut, + tokenAmountLimit: newUint256("120"), + userBeforeWugnotBalance: 0, + userWrappedWugnot: 0, + amountSpecified: newUint256("100"), + expectError: false, + }, + { + name: "ExactOut: Slippage error", + inputToken: barPath, + outputToken: bazPath, + resultAmountIn: newUint256("100"), + resultAmountOut: newUint256("90"), + swapType: ExactOut, + tokenAmountLimit: newUint256("100"), + userBeforeWugnotBalance: 0, + userWrappedWugnot: 0, + amountSpecified: newUint256("100"), + expectError: true, + errorMessage: "[GNOSWAP-ROUTER-002] slippage check failed || Received more than requested in [EXACT_OUT] requested=100, actual=90", + }, + { + name: "GNOT: Slippage error", + inputToken: gnot, + outputToken: barPath, + resultAmountIn: newUint256("300"), + resultAmountOut: newUint256("90"), + swapType: ExactIn, + tokenAmountLimit: newUint256("85"), + userBeforeWugnotBalance: 1000, + userWrappedWugnot: 200, + expectError: true, + errorMessage: "too much wugnot spent", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.expectError { + defer func() { + r := recover() + if r == nil { + t.Errorf("Error expected but not occurred") + return + } + errorStr, ok := r.(string) + if !ok { + t.Errorf("Unexpected error type: %v", r) + return + } + if tt.errorMessage != "" && !strings.Contains(errorStr, tt.errorMessage) { + t.Errorf("Expected error message not included. got: %v, want: %v", errorStr, tt.errorMessage) + } + }() + } + + amountIn, amountOut := finalizeSwap( + tt.inputToken, + tt.outputToken, + tt.resultAmountIn, + tt.resultAmountOut, + tt.swapType, + tt.tokenAmountLimit, + tt.userBeforeWugnotBalance, + tt.userWrappedWugnot, + tt.amountSpecified, + ) + + if !tt.expectError { + uassert.NotEqual(t, amountIn, "") + uassert.NotEqual(t, amountOut, "") + + outVal := i256.MustFromDecimal(amountOut) + if !outVal.IsNeg() { + t.Error("amountOut is not negative") + } + } + }) + } +} diff --git a/contract/r/gnoswap/router/swap_inner.gno b/contract/r/gnoswap/router/swap_inner.gno new file mode 100644 index 000000000..f1967a33e --- /dev/null +++ b/contract/r/gnoswap/router/swap_inner.gno @@ -0,0 +1,309 @@ +package router + +import ( + "std" + + "gno.land/p/demo/ufmt" + + pl "gno.land/r/gnoswap/v1/pool" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +// swapInner executes the core swap logic by interacting with the pool contract. +// This is the main implementation of token swapping that handles both exact input and output swaps. +// +// Expected behavior: +// - Forexact input swaps: First return value is the exact input amount +// - For exact output swaps: Second return value is the exact output amount +// - Both return values are always positive, regardless of swap direction +// +// Parameters: +// - amountSpecified: Amount specified for the swap (positive for exact input, negative for exact output) +// - recipient: Address that will receive the output tokens +// - sqrtPriceLimitX96: Optional price limit for the swap operation +// - data: SwapCallbackData containing additional swap information +// +// Returns: +// - *u256.Uint: Total amount of input tokens used +// - *u256.Uint: Total amount of output tokens received +func swapInner( + amountSpecified *i256.Int, + recipient std.Address, + sqrtPriceLimitX96 *u256.Uint, + data SwapCallbackData, +) (*u256.Uint, *u256.Uint) { // poolRecv, poolOut + zeroForOne := data.tokenIn < data.tokenOut + + sqrtPriceLimitX96 = calculateSqrtPriceLimitForSwap(zeroForOne, data.fee, sqrtPriceLimitX96) + + // ROUTER approves POOL as spender + tokenIn := common.GetTokenTeller(data.tokenIn) + tokenOut := common.GetTokenTeller(data.tokenOut) + tokenIn.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + tokenOut.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + amount0Str, amount1Str := pl.Swap( // int256, int256 + data.tokenIn, + data.tokenOut, + data.fee, + recipient, + zeroForOne, + amountSpecified.ToString(), + sqrtPriceLimitX96.ToString(), + data.payer, + ) + + amount0 := i256.MustFromDecimal(amount0Str) + amount1 := i256.MustFromDecimal(amount1Str) + + poolRecv := i256Max(amount0, amount1) + poolOut := i256Min(amount0, amount1) + + return poolRecv.Abs(), poolOut.Abs() +} + +// swapDryInner performs a dry-run of a swap operation, calculating the potential +// received and output amounts without actually executing the swap. +// +// Parameters: +// - amountSpecified: The amount specified for the swap operation. It is an `i256.Int` +// representing the input token amount or output token amount (depending on the swap direction). +// - sqrtPriceLimitX96: The price limit of the swap expressed as a square root price in X96 format. +// This can be a user-defined limit or set to a default min or max price if not provided. +// - data: A `SwapCallbackData` structure containing details about the tokens being swapped, +// the fee structure, and other context. +// +// Returns: +// - poolRecv: The absolute value of the maximum amount received by the pool during the swap. +// - poolOut: The absolute value of the minimum amount output from the pool during the swap. +func swapDryInner( + amountSpecified *i256.Int, + sqrtPriceLimitX96 *u256.Uint, + data SwapCallbackData, +) (*u256.Uint, *u256.Uint) { // poolRecv, poolOut + zeroForOne := data.tokenIn < data.tokenOut + + if sqrtPriceLimitX96.IsZero() { + if zeroForOne { + sqrtPriceLimitX96 = u256.MustFromDecimal(consts.MIN_PRICE) + } else { + sqrtPriceLimitX96 = u256.MustFromDecimal(consts.MAX_PRICE) + } + } + + // check possible + amount0Str, amount1Str, ok := pl.DrySwap( + data.tokenIn, + data.tokenOut, + data.fee, + zeroForOne, + amountSpecified.ToString(), + sqrtPriceLimitX96.ToString(), + ) + if !ok { + return u256.Zero(), u256.Zero() + } + + amount0 := i256.MustFromDecimal(amount0Str) + amount1 := i256.MustFromDecimal(amount1Str) + + poolRecv := i256Max(amount0, amount1) + poolOut := i256Min(amount0, amount1) + + return poolRecv.Abs(), poolOut.Abs() +} + +// calculateSqrtPriceLimitForSwap calculates the price limit for a swap operation. +// This function uses the tick ranges defined by `getMinTick` and `getMaxTick` to set price boundaries. +// +// Price Boundary Visualization: +// ``` +// +// MIN_TICK MAX_TICK +// v v +// <--|---------------------------|--> +// ^ ^ +// zeroForOne oneForZero +// limit + 1 limit - 1 +// +// ``` +// +// Implementation details: +// - If a non-zero sqrtPriceLimitX96 is provided, it's used as-is +// - For zeroForOne swaps (tokenIn < tokenOut): +// - Uses the minimum tick for the fee tier +// - Adds 1 to avoid hitting the exact boundary +// - For oneForZero swaps (tokenIn > tokenOut): +// - Uses the maximum tick for the fee tier +// - Subtracts 1 to avoid hitting the exact boundary +// +// Parameters: +// - zeroForOne: Boolean indicating the swap direction (true for zeroForOne, false for oneForZero) +// - fee: Fee tier of the pool in basis points +// - sqrtPriceLimitX96: Optional price limit for the swap operation +// +// Returns: +// - *u256.Uint: Calculated price limit for the swap operation +func calculateSqrtPriceLimitForSwap(zeroForOne bool, fee uint32, sqrtPriceLimitX96 *u256.Uint) *u256.Uint { + if !sqrtPriceLimitX96.IsZero() { + return sqrtPriceLimitX96 + } + + if zeroForOne { + minTick := getMinTick(fee) + sqrtPriceLimitX96 = common.TickMathGetSqrtRatioAtTick(minTick) + return new(u256.Uint).Add(sqrtPriceLimitX96, u256.One()) + } + + maxTick := getMaxTick(fee) + sqrtPriceLimitX96 = common.TickMathGetSqrtRatioAtTick(maxTick) + return new(u256.Uint).Sub(sqrtPriceLimitX96, u256.One()) +} + +// getMinTick returns the minimum tick value for a given fee tier. +// The implementation follows Uniswap V3's tick spacing rules where +// lower fee tiers allows for finer price granularity. +// +// Fee tier to min tick mapping demonstrates varying levels of price granularity: +// +// ## How these values are calculated? +// +// The Tick bounds in Uniswap V3 are derived from the desired price range and precisions: +// 1. Price Range: Uniswap V3 uses the formula price = 1.0001^tick +// 2. The minimum tick is calculated to represent a very small but non-zero price: +// - Let min_tick = log(minimum_price) / log(1.0001) +// - The minimum price is chosen to be 2^-128 ≈ 2.9387e-39 +// - Therefor, min_tick = log(2^-128) / log(1.0001) ≈ -887272 +// +// ### Tick Spacing Adjustment +// +// - Each fee tier has different tick spacing for efficiency +// - The actual minimum tick is rounded to the nearest tick spacing: +// - 0.01% fee -> spacing of 1 -> -887272 +// - 0.05% fee -> spacing of 10 -> -887270 +// - 0.30% fee -> spacing of 60 -> -887220 +// - 1.00% fee -> spacing of 200 -> -887200 +// +// ## Tick Range Visualization: +// +// ``` +// +// 0 +// Fee Tier Min Tick | Max Tick Tick Spacing +// +// 0.01% (100) -887272 | 887272 1 finest +// +// | +// +// 0.05% (500) -887270 | 887270 10 +// +// | +// +// 0.3% (3000) -887220 | 887220 60 +// +// | +// +// 1% (10000) -887200 | 887200 200 coarsest +// +// | +// +// Price Range: | +// +// ``` +// +// Tick spacing determines the granularity of price points: +// +// - Smaller tick spacing (1) = More precise price points +// Example for 0.01% fee tier: +// ``` +// Tick: -887272 [...] -2, -1, 0, 1, 2 [...] 887272 +// Steps: 1 1 1 1 1 1 1 +// ``` +// +// - Larger tick spacing (200) = Fewer, more spread out price points +// Example for 1% fee tier: +// ``` +// Tick: -887200 [...] -400, -200, 0, 200, 400 [...] 887200 +// Steps: 200 200 200 200 200 200 200 +// ``` +// +// This function returns the minimum tick value for a given fee tier. +// +// Parameters: +// - fee: Fee tier in basis points +// +// Returns: +// - int32: Minimum tick value for the given fee tier +// +// Panic: +// - If the fee tier is not supported +// +// Reference: +// - https://blog.uniswap.org/uniswap-v3-math-primer +func getMinTick(fee uint32) int32 { + switch fee { + case 100: + return -887272 + case 500: + return -887270 + case 3000: + return -887220 + case 10000: + return -887200 + default: + panic(addDetailToError( + errInvalidPoolFeeTier, + ufmt.Sprintf("unknown fee(%d)", fee), + )) + } +} + +// getMaxTick returns the maximum tick value for a given fee tier. +// +// ## How these values are calculated? +// +// The max tick values are the exact negatives of min tick values because: +// 1. Price symmetry: If min_price = 2^-128, then max_price = 2^128 +// 2. Using the same formula: max_tick = log(2^128) / log(1.0001) ≈ 887272 +// +// ### Tick Spacing Relationship: +// +// The max ticks follow the same spacing rules as min ticks: +// - 0.01% fee -> +887272 (finest granularity) +// - 0.05% fee -> +887270 (10-tick spacing) +// - 0.30% fee -> +887220 (60-tick spacing) +// - 1.00% fee -> +887200 (coarsest granularity) +// +// Parameters: +// - fee: Fee tier in basis points +// +// Returns: +// - int32: Maximum tick value for the given fee tier +// +// Panic: +// - If the fee tier is not supported +// +// Reference: +// - https://blog.uniswap.org/uniswap-v3-math-primer +func getMaxTick(fee uint32) int32 { + switch fee { + case 100: + return 887272 + case 500: + return 887270 + case 3000: + return 887220 + case 10000: + return 887200 + default: + panic(addDetailToError( + errInvalidPoolFeeTier, + ufmt.Sprintf("unknown fee(%d)", fee), + )) + } +} diff --git a/contract/r/gnoswap/router/swap_inner_test.gno b/contract/r/gnoswap/router/swap_inner_test.gno new file mode 100644 index 000000000..c11b457c0 --- /dev/null +++ b/contract/r/gnoswap/router/swap_inner_test.gno @@ -0,0 +1,131 @@ +package router + +import ( + "std" + "testing" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + + "gno.land/p/demo/uassert" +) + +func TestCalculateSqrtPriceLimitForSwap(t *testing.T) { + tests := []struct { + name string + zeroForOne bool + fee uint32 + sqrtPriceLimitX96 *u256.Uint + expected *u256.Uint + }{ + { + name: "already set sqrtPriceLimit", + zeroForOne: true, + fee: 500, + sqrtPriceLimitX96: u256.NewUint(1000), + expected: u256.NewUint(1000), + }, + { + name: "when zeroForOne is true, calculate min tick", + zeroForOne: true, + fee: 500, + sqrtPriceLimitX96: u256.Zero(), + expected: common.TickMathGetSqrtRatioAtTick(getMinTick(500)).Add( + common.TickMathGetSqrtRatioAtTick(getMinTick(500)), + u256.One(), + ), + }, + { + name: "when zeroForOne is false, calculate max tick", + zeroForOne: false, + fee: 500, + sqrtPriceLimitX96: u256.Zero(), + expected: common.TickMathGetSqrtRatioAtTick(getMaxTick(500)).Sub( + common.TickMathGetSqrtRatioAtTick(getMaxTick(500)), + u256.One(), + ), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := calculateSqrtPriceLimitForSwap( + tt.zeroForOne, + tt.fee, + tt.sqrtPriceLimitX96, + ) + uassert.Equal(t, result.ToString(), tt.expected.ToString()) + }) + } +} + +func TestSwapInner(t *testing.T) { + user1Realm := std.NewUserRealm(user1Addr) + + tests := []struct { + name string + setupFn func(t *testing.T) + amountSpecified *i256.Int + recipient std.Address + sqrtPriceLimitX96 *u256.Uint + data SwapCallbackData + expectedRecv string + expectedOut string + expectError bool + }{ + { + name: "normal swap - exact input", + setupFn: func(t *testing.T) { + CreatePoolWithoutFee(t) + MakeMintPositionWithoutFee(t) + + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + TokenFaucet(t, barPath, user1Addr) + TokenFaucet(t, bazPath, user1Addr) + }, + amountSpecified: i256.MustFromDecimal("100"), + recipient: alice, + sqrtPriceLimitX96: u256.NewUint(4295128740), + data: SwapCallbackData{ + tokenIn: barPath, + tokenOut: bazPath, + fee: 3000, + payer: consts.ROUTER_ADDR, + }, + expectedRecv: "100", + expectedOut: "98", + expectError: false, + }, + } + + for _, tt := range tests { + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + bar.Approve(consts.POOL_ADDR, maxApprove) + baz.Approve(consts.POOL_ADDR, maxApprove) + TokenFaucet(t, barPath, consts.ROUTER_ADDR) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + poolRecv, poolOut := swapInner( + tt.amountSpecified, + tt.recipient, + tt.sqrtPriceLimitX96, + tt.data, + ) + + uassert.Equal(t, poolRecv.ToString(), tt.expectedRecv) + uassert.Equal(t, poolOut.ToString(), tt.expectedOut) + } +} diff --git a/contract/r/gnoswap/router/swap_multi.gno b/contract/r/gnoswap/router/swap_multi.gno new file mode 100644 index 000000000..2dcca2ae6 --- /dev/null +++ b/contract/r/gnoswap/router/swap_multi.gno @@ -0,0 +1,208 @@ +package router + +import ( + "std" + + "gno.land/p/gnoswap/consts" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func multiSwap(params SwapParams, currentPoolIndex, numPools int, swapPath string) (*u256.Uint, *u256.Uint) { // firstAmountIn, lastAmountOut + firstAmountIn := u256.Zero() + + payer := getPrevAddr() // user + + for { + var recipient std.Address + currentPoolIndex++ + + if currentPoolIndex < numPools { + recipient = consts.ROUTER_ADDR + } else { + recipient = params.recipient + } + + amountIn, amountOut := swapInner( + params.amountSpecified, + recipient, + u256.Zero(), + SwapCallbackData{ + params.tokenIn, + params.tokenOut, + params.fee, + payer, + }, + ) + + if currentPoolIndex == 1 { + firstAmountIn = amountIn + } + + if currentPoolIndex >= numPools { + return firstAmountIn, amountOut + } + + payer = consts.ROUTER_ADDR + + nextInput, nextOutput, nextFee := getDataForMultiPath(swapPath, currentPoolIndex) + params.tokenIn = nextInput + params.tokenOut = nextOutput + params.fee = nextFee + params.amountSpecified = i256.FromUint256(amountOut) + } +} + +func multiSwapNegative(params SwapParams, numPools int, swapPath string) (*u256.Uint, *u256.Uint) { // firstAmountIn, lastAmountOut + firstAmountIn := u256.Zero() + + swapInfo := []SingleSwapParams{} + + // CALCULATE BACKWARD INFO + currentPoolIndex := numPools + for currentPoolIndex >= 0 { + // Dry-run + amountIn, _ := singleDrySwap( + SingleSwapParams{ + tokenIn: params.tokenIn, + tokenOut: params.tokenOut, + fee: params.fee, + amountSpecified: params.amountSpecified, + }, + ) + + thisSwap := SingleSwapParams{ + tokenIn: params.tokenIn, + tokenOut: params.tokenOut, + fee: params.fee, + amountSpecified: params.amountSpecified, + } + swapInfo = append(swapInfo, thisSwap) + + if currentPoolIndex == 0 { + break + } + currentPoolIndex-- + + nextInput, nextOutput, nextFee := getDataForMultiPath(swapPath, currentPoolIndex) + + intAmountIn := i256.FromUint256(amountIn) + params.tokenIn = nextInput + params.tokenOut = nextOutput + params.fee = nextFee + params.amountSpecified = i256.Zero().Neg(intAmountIn) + } + + // PROCESS FORWARD INFO + currentPoolIndex = len(swapInfo) - 1 + payer := getPrevAddr() // first payer ≈ user + for currentPoolIndex >= 0 { + var recipient std.Address + if currentPoolIndex == 0 { + recipient = params.recipient // params.recipient ≈ user + } else { + recipient = consts.ROUTER_ADDR + } + + amountIn, amountOut := swapInner( + swapInfo[currentPoolIndex].amountSpecified, + recipient, + u256.Zero(), + SwapCallbackData{ + swapInfo[currentPoolIndex].tokenIn, + swapInfo[currentPoolIndex].tokenOut, + swapInfo[currentPoolIndex].fee, + payer, + }, + ) + + // save route's first hop's amountIn to check whether crossed limit or not + if currentPoolIndex == len(swapInfo)-1 { + firstAmountIn = amountIn + } + + if currentPoolIndex == 0 { + return firstAmountIn, amountOut + } + + swapInfo[currentPoolIndex-1].amountSpecified = i256.FromUint256(amountOut) + payer = consts.ROUTER_ADDR + currentPoolIndex-- + } + return firstAmountIn, u256.Zero() +} + +func multiDrySwap(params SwapParams, currentPoolIndex, numPool int, swapPath string) (*u256.Uint, *u256.Uint) { // firstAmountIn, lastAmountOut + firstAmountIn := u256.Zero() + + payer := getPrevAddr() // user + + for { + currentPoolIndex++ + + amountIn, amountOut := swapDryInner( + params.amountSpecified, + u256.Zero(), + SwapCallbackData{ + params.tokenIn, + params.tokenOut, + params.fee, + payer, + }, + ) + + if currentPoolIndex == 1 { + firstAmountIn = amountIn + } + + if currentPoolIndex >= numPool { + return firstAmountIn, amountOut + } + + payer = consts.ROUTER_ADDR + + nextInput, nextOutput, nextFee := getDataForMultiPath(swapPath, currentPoolIndex) + params.tokenIn = nextInput + params.tokenOut = nextOutput + params.fee = nextFee + params.amountSpecified = i256.FromUint256(amountOut) + } + +} + +func multiDrySwapNegative(params SwapParams, currentPoolIndex int, swapPath string) (*u256.Uint, *u256.Uint) { + firstAmountIn := u256.Zero() + payer := consts.ROUTER_ADDR + + for { + amountIn, amountOut := swapDryInner( + params.amountSpecified, + u256.Zero(), + SwapCallbackData{ + params.tokenIn, + params.tokenOut, + params.fee, + payer, + }, + ) + + if currentPoolIndex == 0 { + firstAmountIn = amountIn + } + + currentPoolIndex-- + + if currentPoolIndex == -1 { + return firstAmountIn, amountOut + } + + nextInput, nextOutput, nextFee := getDataForMultiPath(swapPath, currentPoolIndex) + intAmountIn := i256.FromUint256(amountIn) + + params.amountSpecified = i256.Zero().Neg(intAmountIn) + params.tokenIn = nextInput + params.tokenOut = nextOutput + params.fee = nextFee + } +} diff --git a/contract/r/gnoswap/router/swap_multi_test.gno b/contract/r/gnoswap/router/swap_multi_test.gno new file mode 100644 index 000000000..c252db27e --- /dev/null +++ b/contract/r/gnoswap/router/swap_multi_test.gno @@ -0,0 +1,136 @@ +package router + +import ( + "std" + "testing" + + i256 "gno.land/p/gnoswap/int256" + + "gno.land/p/gnoswap/consts" + + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + "gno.land/r/onbloc/qux" + + "gno.land/p/demo/uassert" +) + +func TestMultiSwap(t *testing.T) { + user1Realm := std.NewUserRealm(user1Addr) + + tests := []struct { + name string + setupFn func(t *testing.T) + params SwapParams + currentPoolIndex int + numPools int + swapPath string + expectedFirstIn string + expectedLastOut string + expectError bool + }{ + { + name: "single hop swap BAR -> BAZ", + setupFn: func(t *testing.T) { + CreatePoolWithoutFee(t) + MakeMintPositionWithoutFee(t) + + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + TokenFaucet(t, barPath, user1Addr) + }, + params: SwapParams{ + tokenIn: barPath, + tokenOut: bazPath, + fee: 3000, + recipient: alice, + amountSpecified: i256.MustFromDecimal("100"), + }, + currentPoolIndex: 0, + numPools: 1, + swapPath: "", + expectedFirstIn: "100", + expectedLastOut: "98", + expectError: false, + }, + { + name: "multi hop swap (BAR -> BAZ -> QUX)", + setupFn: func(t *testing.T) { + // BAR -> BAZ + CreatePoolWithoutFee(t) + MakeMintPositionWithoutFee(t) + + // BAZ -> QUX + CreateSecondPoolWithoutFee(t) + MakeSecondMintPositionWithoutFee(t) + + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + qux.Approve(consts.ROUTER_ADDR, maxApprove) + TokenFaucet(t, barPath, user1Addr) + }, + params: SwapParams{ + tokenIn: barPath, + tokenOut: bazPath, + fee: 3000, + recipient: alice, + amountSpecified: i256.MustFromDecimal("100"), + }, + currentPoolIndex: 0, + numPools: 2, + swapPath: "gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:3000*POOL*gno.land/r/onbloc/baz:gno.land/r/onbloc/qux:3000", + expectedFirstIn: "100", + expectedLastOut: "96", + expectError: false, + }, + { + name: "multi hop swap with exact output", + setupFn: func(t *testing.T) { + // BAR -> BAZ -> QUX + CreatePoolWithoutFee(t) + MakeMintPositionWithoutFee(t) + CreateSecondPoolWithoutFee(t) + MakeSecondMintPositionWithoutFee(t) + + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + qux.Approve(consts.ROUTER_ADDR, maxApprove) + TokenFaucet(t, barPath, user1Addr) + }, + params: SwapParams{ + tokenIn: barPath, + tokenOut: bazPath, + fee: 3000, + recipient: alice, + amountSpecified: i256.MustFromDecimal("-96"), + }, + currentPoolIndex: 0, + numPools: 2, + swapPath: "gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:3000*POOL*gno.land/r/onbloc/baz:gno.land/r/onbloc/qux:3000", + expectedFirstIn: "98", + expectedLastOut: "94", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupFn != nil { + tt.setupFn(t) + } + + firstAmountIn, lastAmountOut := multiSwap( + tt.params, + tt.currentPoolIndex, + tt.numPools, + tt.swapPath, + ) + + uassert.Equal(t, firstAmountIn.ToString(), tt.expectedFirstIn) + uassert.Equal(t, lastAmountOut.ToString(), tt.expectedLastOut) + }) + } +} diff --git a/contract/r/gnoswap/router/swap_single.gno b/contract/r/gnoswap/router/swap_single.gno new file mode 100644 index 000000000..87b743612 --- /dev/null +++ b/contract/r/gnoswap/router/swap_single.gno @@ -0,0 +1,77 @@ +package router + +import ( + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/common" +) + +// singleSwap execute a swap within a single pool using the provided parameters. +// It processes a token swap within two assets using a specific fee tier and +// automatically sets the recipient to the caller's address. +// +// Parameters: +// - params: `SingleSwapParams` containing the swap configuration +// - tokenIn: Address of the token being spent +// - tokenOut: Address of the token being received +// - fee: Fee tier of the pool in basis points +// - amountSpecified: Amount specified for the swap (positive for exact input, negative for exact output) +// +// Returns: +// - *u256.Uint: Total amount of input tokens used +// - *u256.Uint: Total amount of output tokens received +// +// The function uses swapInner for the core swap logic and sets the proce limit to 0, +// allowing the swap to execute at any price point within slippage bounds. +func singleSwap(params SingleSwapParams) (*u256.Uint, *u256.Uint) { // amountIn, amountOut + common.MustRegistered(params.tokenIn) + common.MustRegistered(params.tokenOut) + + amountIn, amountOut := swapInner( + params.amountSpecified, + getPrevAddr(), // if single swap => user will recieve + u256.Zero(), // sqrtPriceLimitX96 + SwapCallbackData{ + params.tokenIn, + params.tokenOut, + params.fee, + getPrevAddr(), // payer ==> msg.sender, + }, + ) + + return amountIn, amountOut +} + +// singleDrySwap simulates a single-token swap operation without executing it, +// returning the calculated input and output token amounts. +// +// Parameters: +// - params: A `SingleSwapParams` structure containing the following fields: +// - `tokenIn`: The address of the input token. +// - `tokenOut`: The address of the output token. +// - `amountSpecified`: The amount specified for the swap (input or output amount depending on the swap direction). +// - `fee`: The fee rate applied to the swap. +// +// Returns: +// - *u256.Uint: The calculated amount of the input token required for the swap (`amountIn`). +// - *u256.Uint: The calculated amount of the output token received from the swap (`amountOut`). +// +// Notes: +// - This function performs a simulation and does not alter the state or execute the actual swap. +func singleDrySwap(params SingleSwapParams) (*u256.Uint, *u256.Uint) { + common.MustRegistered(params.tokenIn) + common.MustRegistered(params.tokenOut) + + amountIn, amountOut := swapDryInner( + params.amountSpecified, + u256.Zero(), + SwapCallbackData{ + params.tokenIn, + params.tokenOut, + params.fee, + getPrevAddr(), + }, + ) + + return amountIn, amountOut +} diff --git a/contract/r/gnoswap/router/swap_single_test.gno b/contract/r/gnoswap/router/swap_single_test.gno new file mode 100644 index 000000000..2c3227391 --- /dev/null +++ b/contract/r/gnoswap/router/swap_single_test.gno @@ -0,0 +1,84 @@ +package router + +import ( + "std" + "testing" + + i256 "gno.land/p/gnoswap/int256" + + "gno.land/p/gnoswap/consts" + + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + + "gno.land/p/demo/uassert" +) + +func TestSingleSwap(t *testing.T) { + user1Realm := std.NewUserRealm(user1Addr) + + tests := []struct { + name string + setupFn func(t *testing.T) + params SingleSwapParams + expectedIn string + expectedOut string + expectError bool + }{ + { + name: "exact input swap BAR -> BAZ", + setupFn: func(t *testing.T) { + CreatePoolWithoutFee(t) + MakeMintPositionWithoutFee(t) + + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + TokenFaucet(t, barPath, user1Addr) + }, + params: SingleSwapParams{ + tokenIn: barPath, + tokenOut: bazPath, + fee: 3000, + amountSpecified: i256.MustFromDecimal("100"), + }, + expectedIn: "100", + expectedOut: "98", + expectError: false, + }, + { + name: "exact output swap BAR -> BAZ", + setupFn: func(t *testing.T) { + CreatePoolWithoutFee(t) + MakeMintPositionWithoutFee(t) + + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + TokenFaucet(t, barPath, user1Addr) + }, + params: SingleSwapParams{ + tokenIn: barPath, + tokenOut: bazPath, + fee: 3000, + amountSpecified: i256.MustFromDecimal("-98"), + }, + expectedIn: "100", + expectedOut: "98", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupFn != nil { + tt.setupFn(t) + } + + amountIn, amountOut := singleSwap(tt.params) + + uassert.Equal(t, amountIn.ToString(), tt.expectedIn) + uassert.Equal(t, amountOut.ToString(), tt.expectedOut) + }) + } +} diff --git a/contract/r/gnoswap/router/type.gno b/contract/r/gnoswap/router/type.gno new file mode 100644 index 000000000..ce1763e1b --- /dev/null +++ b/contract/r/gnoswap/router/type.gno @@ -0,0 +1,152 @@ +package router + +import ( + "std" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/demo/ufmt" +) + +type SwapType string + +const ( + // ExactIn represents a swap type where the input amount is exact and the output amount may vary. + // Used when a user wants to swap a specific amount of input tokens. + ExactIn SwapType = "EXACT_IN" + + // ExactOut represents a swap type where the output amount is exact and the input amount may vary. + // Used when a user wants to swap a specific amount of output tokens. + ExactOut SwapType = "EXACT_OUT" +) + +// trySwapTypeFromStr attempts to convert a string into a `SwapType`. +// +// This function validates and converts string representations of swap types +// into their corresponding `SwapType` enum values. +func trySwapTypeFromStr(swapType string) (SwapType, error) { + switch swapType { + case "EXACT_IN": + return ExactIn, nil + case "EXACT_OUT": + return ExactOut, nil + default: + return "", ufmt.Errorf("unknown swapType: expected ExactIn or ExactOut, got %s", swapType) + } +} + +func (s SwapType) String() string { + switch s { + case ExactIn: + return "EXACT_IN" + case ExactOut: + return "EXACT_OUT" + default: + return "" + } +} + +// SingleSwapParams contains parameters for executing a single pool swap. +// It represents the simplest form of swap that occurs within a single liquidity pool. +type SingleSwapParams struct { + tokenIn string // token to spend + tokenOut string // token to receive + fee uint32 // fee of the pool used to swap + + // Amount specified for the swap: + // - Positive: exact input amount (tokenIn) + // - Negative: exact output amount (tokenOut) + amountSpecified *i256.Int +} + +// SwapParams contains parameters for executing a multi-hop swap opration. +// It extends the `SingleSwapParams` with recipient information for more complex swaps +// that involve multiple pools. +type SwapParams struct { + tokenIn string // token to being spent + tokenOut string // token to being received + fee uint32 // fee of the pool used to swap + recipient std.Address // address to receive the token + + // Amount specified for the swap: + // - Positive: exact input amount (tokenIn) + // - Negative: exact output amount (tokenOut) + amountSpecified *i256.Int +} + +// newSwapParams creates a new `SwapParams` instance with the provided parameters. +// +// Parameters: +// - tokenIn: Address of the token being spent +// - tokenOut: Address of the token being received +// - fee: Fee tier of the pool in basis points +// - recipient: Address that will receive the output tokens +// - amountSpecified: Amount specified for the swap (positive for exact input, negative for exact output) +// +// Returns: +// - *SwapParams: new `SwapParams` instance +func newSwapParams(tokenIn, tokenOut string, fee uint32, recipient std.Address, amountSpecified *i256.Int) *SwapParams { + return &SwapParams{ + tokenIn: tokenIn, + tokenOut: tokenOut, + fee: fee, + recipient: recipient, + amountSpecified: amountSpecified, + } +} + +// SwapResult encapsulates the outcome of a swap operation +type SwapResult struct { + AmountIn *u256.Uint + AmountOut *u256.Uint + Routes []string + Quotes []string + AmountSpecified *i256.Int +} + +// SwapCallbackData contains the callback data required for swap execution. +// This type is used to pass necessary information during the swap callback process, +// ensuring proper token transfers and pool data updates. +type SwapCallbackData struct { + tokenIn string // token to spend + tokenOut string // token to receive + fee uint32 // fee of the pool used to swap + payer std.Address // address to spend the token +} + +type ExactInParams struct { + BaseSwapParams + AmountIn string + AmountOutMin string +} + +func NewExactInParams( + baseParams BaseSwapParams, + amountIn string, + amountOutMin string, +) ExactInParams { + return ExactInParams{ + BaseSwapParams: baseParams, + AmountIn: amountIn, + AmountOutMin: amountOutMin, + } +} + +type ExactOutParams struct { + BaseSwapParams + AmountOut string + AmountInMax string +} + +func NewExactOutParams( + baseParams BaseSwapParams, + amountOut string, + amountInMax string, +) ExactOutParams { + return ExactOutParams{ + BaseSwapParams: baseParams, + AmountOut: amountOut, + AmountInMax: amountInMax, + } +} diff --git a/contract/r/gnoswap/router/type_test.gno b/contract/r/gnoswap/router/type_test.gno new file mode 100644 index 000000000..43ff3b470 --- /dev/null +++ b/contract/r/gnoswap/router/type_test.gno @@ -0,0 +1,55 @@ +package router + +import ( + "testing" + + "gno.land/p/demo/uassert" +) + +func TestTrySwapTypeFromStr(t *testing.T) { + tests := []struct { + name string + input string + want SwapType + wantErr bool + }{ + { + name: "valid EXACT_IN", + input: "EXACT_IN", + want: ExactIn, + wantErr: false, + }, + { + name: "valid EXACT_OUT", + input: "EXACT_OUT", + want: ExactOut, + wantErr: false, + }, + { + name: "invalid empty string", + input: "", + want: "", + wantErr: true, + }, + { + name: "invalid swap type", + input: "INVALID_TYPE", + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := trySwapTypeFromStr(tt.input) + + if !tt.wantErr { + uassert.NoError(t, err) + } + + if got != tt.want { + t.Errorf("trySwapTypeFromStr() = %v, want %v", got, tt.want) + } + }) + } +} \ No newline at end of file diff --git a/contract/r/gnoswap/router/utils.gno b/contract/r/gnoswap/router/utils.gno new file mode 100644 index 000000000..45782bcb8 --- /dev/null +++ b/contract/r/gnoswap/router/utils.gno @@ -0,0 +1,217 @@ +package router + +import ( + "bytes" + "std" + "strconv" + "strings" + "time" + + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/common" + + i256 "gno.land/p/gnoswap/int256" +) + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +// assertDirectCallOnly panics if the caller is not the user. +func assertDirectCallOnly() { + if common.GetLimitCaller() && std.PrevRealm().PkgPath() != "" { + panic(addDetailToError(errNoPermission, "only user can call this function")) + } +} + +// assertHopsInRange panics if the number of hops is not in the range 1~3. +func assertHopsInRange(hops int) { + if hops < 1 || hops > 3 { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("number of hops(%d) must be 1~3", hops), + )) + } +} + +// getDataForSinglePath extracts token0, token1, and fee from a single path. +func getDataForSinglePath(poolPath string) (string, string, uint32) { + poolPathSplit, err := common.Split(poolPath, ":", 3) + if err != nil { + panic(addDetailToError( + errInvalidPoolPath, + ufmt.Sprintf("len(poolPathSplit) != 3, poolPath: %s", poolPath), + )) + } + + token0 := poolPathSplit[0] + token1 := poolPathSplit[1] + fee, _ := strconv.Atoi(poolPathSplit[2]) + + return token0, token1, uint32(fee) +} + +// getDataForMultiPath extracts token0, token1, and fee from a multi path. +func getDataForMultiPath(possiblePath string, poolIdx int) (string, string, uint32) { + pools := strings.Split(possiblePath, "*POOL*") + + var token0, token1 string + var fee uint32 + + switch poolIdx { + case 0: + token0, token1, fee = getDataForSinglePath(pools[0]) + case 1: + token0, token1, fee = getDataForSinglePath(pools[1]) + case 2: + token0, token1, fee = getDataForSinglePath(pools[2]) + default: + return "", "", uint32(0) + } + + return token0, token1, fee +} + +// isStringInStringArr checks if a string is in a string array. +func isStringInStringArr(arr []string, str string) bool { + for _, a := range arr { + if a == str { + return true + } + } + return false +} + +// removeStringFromStringArr removes a string from a string array. +func removeStringFromStringArr(arr []string, str string) []string { + for i, a := range arr { + if a == str { + return append(arr[:i], arr[i+1:]...) + } + } + return arr +} + +// min returns the smaller of two integers. +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// i256Min returns the smaller of two i256.Int. +func i256Min(x, y *i256.Int) *i256.Int { + if x.Lt(y) { + return x + } + return y +} + +// i256Max returns the larger of two i256.Int. +func i256Max(x, y *i256.Int) *i256.Int { + if x.Gt(y) { + return x + } + return y +} + +// getPrevRealm returns object of the previous realm. +func prevRealm() string { + return std.PrevRealm().PkgPath() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// getPrev returns the address and package path of the previous realm. +func getPrev() (string, string) { + prev := std.PrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// checkDeadline checks if the deadline is expired. +// If the deadline is expired, it panics. +// The deadline is expired if the current time is greater than the deadline. +// Input: +// - deadline: the deadline to check +func checkDeadline(deadline int64) { + now := time.Now().Unix() + if now > deadline { + panic(addDetailToError( + errExpired, + ufmt.Sprintf("transaction too old, now(%d) > deadline(%d)", now, deadline), + )) + } +} + +// splitSingleChar splits a string by a single character separator. +// +// This function is optimized for splitting strings with a single-byte separator. +// Unlike `strings.Split`, it: +// 1. Performs direct byte comparison instead of substring matching +// 2. Avoids additional string allocations by using slicing +// 3. Makes only one allocation for the result slice +// +// The main differences from `strings.Split` are: +// - Only works with single-byte separators +// - More memory efficient as it doesn't need to handle multi-byte separators +// - Faster for small to medium strings due to simpler byte comparison +// +// Performance: +// - Up to 5x faster than `strings.Split` for small strings (in Go) +// - For gno (run test with `-print-runtime-metrics` option): +// | Function | Cycles | Allocations +// |------------------|------------------|--------------| +// | strings.Split | 1.1M | 808.1K | +// | splitSingleChar | 1.0M | 730.4K | +// - Uses zero allocations except for the initial result slice +// - Most effective for strings under 1KB with simple single-byte delimiters +// (* This test result was measured without the `uassert` package) +// +// Parameters: +// +// s (string): source string to split +// sep (byte): single byte separator to split on +// +// Returns: +// +// []string: slice containing the split string parts +func splitSingleChar(s string, sep byte) []string { + l := len(s) + if l == 0 { + return []string{""} + } + + result := make([]string, 0, bytes.Count([]byte(s), []byte{sep})+1) + start := 0 + for i := 0; i < l; i++ { + if s[i] == sep { + result = append(result, s[start:i]) + start = i + 1 + } + } + result = append(result, s[start:]) + return result +} + +func formatUint(v interface{}) string { + switch v := v.(type) { + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} diff --git a/contract/r/gnoswap/router/utils_test.gno b/contract/r/gnoswap/router/utils_test.gno new file mode 100644 index 000000000..f83a78c76 --- /dev/null +++ b/contract/r/gnoswap/router/utils_test.gno @@ -0,0 +1,205 @@ +package router + +import ( + "strings" + "testing" + + "gno.land/p/demo/uassert" +) + +func TestGetDataForSinglePath(t *testing.T) { + tests := []struct { + name string + input string + wantToken0 string + wantToken1 string + wantFee int + shouldPanic bool + }{ + { + name: "valid path", + input: "tokenA:tokenB:500", + wantToken0: "tokenA", + wantToken1: "tokenB", + wantFee: int(500), + shouldPanic: false, + }, + { + name: "invalid path format", + input: "tokenA:tokenB", + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if (r != nil) != tt.shouldPanic { + t.Errorf("getDataForSinglePath() panic = %v, shouldPanic = %v", r != nil, tt.shouldPanic) + } + }() + + token0, token1, fee := getDataForSinglePath(tt.input) + if !tt.shouldPanic { + if token0 != tt.wantToken0 { + t.Errorf("token0 = %v, want %v", token0, tt.wantToken0) + } + if token1 != tt.wantToken1 { + t.Errorf("token1 = %v, want %v", token1, tt.wantToken1) + } + if int(fee) != tt.wantFee { + t.Errorf("fee = %v, want %v", fee, tt.wantFee) + } + } + }) + } +} + +func TestGetDataForMultiPath(t *testing.T) { + tests := []struct { + name string + input string + poolIdx int + wantToken0 string + wantToken1 string + wantFee uint32 + }{ + { + name: "first pool", + input: "tokenA:tokenB:500*POOL*tokenB:tokenC:3000*POOL*tokenC:tokenD:10000", + poolIdx: 0, + wantToken0: "tokenA", + wantToken1: "tokenB", + wantFee: 500, + }, + { + name: "second pool", + input: "tokenA:tokenB:500*POOL*tokenB:tokenC:3000*POOL*tokenC:tokenD:10000", + poolIdx: 1, + wantToken0: "tokenB", + wantToken1: "tokenC", + wantFee: 3000, + }, + { + name: "third pool", + input: "tokenA:tokenB:500*POOL*tokenB:tokenC:3000*POOL*tokenC:tokenD:10000", + poolIdx: 2, + wantToken0: "tokenC", + wantToken1: "tokenD", + wantFee: 10000, + }, + { + name: "invalid pool index", + input: "tokenA:tokenB:500*POOL*tokenB:tokenC:3000", + poolIdx: 3, + wantToken0: "", + wantToken1: "", + wantFee: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token0, token1, fee := getDataForMultiPath(tt.input, tt.poolIdx) + + if token0 != tt.wantToken0 { + t.Errorf("token0 = %v, want %v", token0, tt.wantToken0) + } + if token1 != tt.wantToken1 { + t.Errorf("token1 = %v, want %v", token1, tt.wantToken1) + } + if fee != tt.wantFee { + t.Errorf("fee = %v, want %v", fee, tt.wantFee) + } + }) + } +} + +func TestSplitSingleChar(t *testing.T) { + testCases := []struct { + name string + input string + sep byte + expected []string + }{ + { + name: "plain split", + input: "a,b,c", + sep: ',', + expected: []string{"a", "b", "c"}, + }, + { + name: "empty string", + input: "", + sep: ',', + expected: []string{""}, + }, + { + name: "no separator", + input: "abc", + sep: ',', + expected: []string{"abc"}, + }, + { + name: "consecutive separators", + input: "a,,b,,c", + sep: ',', + expected: []string{"a", "", "b", "", "c"}, + }, + { + name: "separator at the beginning and end", + input: ",a,b,c,", + sep: ',', + expected: []string{"", "a", "b", "c", ""}, + }, + { + name: "space separator", + input: "a b c", + sep: ' ', + expected: []string{"a", "b", "c"}, + }, + { + name: "single character string", + input: "a", + sep: ',', + expected: []string{"a"}, + }, + { + name: "only separators", + input: ",,,,", + sep: ',', + expected: []string{"", "", "", "", ""}, + }, + { + name: "unicode characters", + input: "한글,English,日本語", + sep: ',', + expected: []string{"한글", "English", "日本語"}, + }, + { + name: "special characters", + input: "!@#$,%^&*,()_+", + sep: ',', + expected: []string{"!@#$", "%^&*", "()_+"}, + }, + { + name: "routes path", + input: "gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:500*POOL*gno.land/r/onbloc/baz:gno.land/r/onbloc/qux:500,gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:500*POOL*gno.land/r/onbloc/baz:gno.land/r/onbloc/qux:500", + sep: ',', + expected: []string{"gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:500*POOL*gno.land/r/onbloc/baz:gno.land/r/onbloc/qux:500", "gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:500*POOL*gno.land/r/onbloc/baz:gno.land/r/onbloc/qux:500"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := splitSingleChar(tc.input, tc.sep) + + uassert.Equal(t, len(result), len(tc.expected)) + + for i := 0; i < len(tc.expected); i++ { + uassert.Equal(t, result[i], tc.expected[i]) + } + }) + } +} diff --git a/contract/r/gnoswap/router/wrap_unwrap.gno b/contract/r/gnoswap/router/wrap_unwrap.gno new file mode 100644 index 000000000..77824d66d --- /dev/null +++ b/contract/r/gnoswap/router/wrap_unwrap.gno @@ -0,0 +1,52 @@ +package router + +import ( + "std" + + "gno.land/r/demo/wugnot" + + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +func wrap(ugnotAmount uint64) { + if ugnotAmount <= 0 { + panic(addDetailToError( + errWrapUnwrap, + ufmt.Sprintf("wrap.gno__wrap() || cannot wrap 0 ugnot"), + )) + } + + if ugnotAmount < consts.UGNOT_MIN_DEPOSIT_TO_WRAP { + panic(addDetailToError( + errWugnotMinimum, + ufmt.Sprintf("wrap.gno__wrap() || amount(%d) < minimum(%d)", ugnotAmount, consts.UGNOT_MIN_DEPOSIT_TO_WRAP), + )) + } + + // WRAP IT + wugnotAddr := std.DerivePkgAddr(consts.WRAPPED_WUGNOT) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(consts.ROUTER_ADDR, wugnotAddr, std.Coins{{"ugnot", int64(ugnotAmount)}}) + wugnot.Deposit() // ROUTER HAS WUGNOT + + // SEND WUGNOT: ROUTER -> USER + wugnot.Transfer(common.AddrToUser(std.PrevRealm().Addr()), ugnotAmount) +} + +func unwrap(wugnotAmount uint64) { + if wugnotAmount == 0 { + return + } + + // SEND WUGNOT: USER -> ROUTER + wugnot.TransferFrom(common.AddrToUser(std.PrevRealm().Addr()), common.AddrToUser(consts.ROUTER_ADDR), wugnotAmount) + + // UNWRAP IT + wugnot.Withdraw(wugnotAmount) + + // SEND GNOT: ROUTER -> USER + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(consts.ROUTER_ADDR, std.PrevRealm().Addr(), std.Coins{{"ugnot", int64(wugnotAmount)}}) +} diff --git a/contract/r/gnoswap/staker/_helper_test.gno b/contract/r/gnoswap/staker/_helper_test.gno new file mode 100644 index 000000000..30cf9bf50 --- /dev/null +++ b/contract/r/gnoswap/staker/_helper_test.gno @@ -0,0 +1,735 @@ +package staker + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/json" + "gno.land/p/demo/testutils" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + en "gno.land/r/gnoswap/v1/emission" + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gnft" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + "gno.land/r/onbloc/foo" + "gno.land/r/onbloc/obl" + "gno.land/r/onbloc/qux" +) + +const ( + ugnotDenom string = "ugnot" + ugnotPath string = "ugnot" + wugnotPath string = "gno.land/r/demo/wugnot" + gnsPath string = "gno.land/r/gnoswap/v1/gns" + barPath string = "gno.land/r/onbloc/bar" + bazPath string = "gno.land/r/onbloc/baz" + fooPath string = "gno.land/r/onbloc/foo" + oblPath string = "gno.land/r/onbloc/obl" + quxPath string = "gno.land/r/onbloc/qux" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + maxApprove uint64 = 18446744073709551615 + max_timeout int64 = 9999999999 +) + +const ( + // define addresses to use in tests + addr01 = testutils.TestAddress("addr01") + addr02 = testutils.TestAddress("addr02") +) + +var ( + adminAddr = consts.ADMIN + admin = consts.ADMIN + alice = testutils.TestAddress("alice") + pool = consts.POOL_ADDR + protocolFee = consts.PROTOCOL_FEE_ADDR + adminRealm = std.NewUserRealm(admin) + posRealm = std.NewCodeRealm(consts.POSITION_PATH) + rouRealm = std.NewCodeRealm(consts.ROUTER_PATH) + + // addresses used in tests + addrUsedInTest = []std.Address{addr01, addr02} +) + +func CreatePool(t *testing.T, + token0 string, + token1 string, + fee uint32, + sqrtPriceX96 string, + caller std.Address, +) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(caller)) + poolPath := pl.GetPoolPath(token0, token1, fee) + if !pl.DoesPoolPathExist(poolPath) { + pl.CreatePool(token0, token1, fee, sqrtPriceX96) + } +} + +func LPTokenStake(t *testing.T, owner std.Address, tokenId uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) +} + +func LPTokenUnStake(t *testing.T, owner std.Address, tokenId uint64, unwrap bool) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) +} + +func InitialisePoolTest(t *testing.T) { + t.Helper() + + ugnotFaucet(t, admin, 100_000_000_000_000) + ugnotDeposit(t, admin, 100_000_000_000_000) + TokenFaucet(t, gnsPath, admin) + + std.TestSetOrigCaller(admin) + TokenApprove(t, gnsPath, admin, pool, maxApprove) + poolPath := pl.GetPoolPath(wugnotPath, gnsPath, fee3000) + if !pl.DoesPoolPathExist(poolPath) { + pl.CreatePool(wugnotPath, gnsPath, fee3000, "79228162514264337593543950336") + } + + std.TestSetOrigCaller(alice) + TokenFaucet(t, wugnotPath, alice) + TokenFaucet(t, gnsPath, alice) + TokenApprove(t, wugnotPath, alice, pool, uint64(1000)) + TokenApprove(t, gnsPath, alice, pool, uint64(1000)) + MintPosition(t, + wugnotPath, + gnsPath, + fee3000, + int32(1020), + int32(5040), + "1000", + "1000", + "0", + "0", + max_timeout, + alice, + alice, + ) +} + +func CreateSecondPoolWithoutFee(t *testing.T) { + std.TestSetRealm(adminRealm) + pl.SetPoolCreationFeeByAdmin(0) + + CreatePool(t, + barPath, + bazPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), + admin, + ) +} + +func MakeMintPositionWithoutFee(t *testing.T) (uint64, string, string, string) { + t.Helper() + + // make actual data to test resetting not only position's state but also pool's state + std.TestSetRealm(adminRealm) + + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, bazPath, admin, pool, consts.UINT64_MAX) + + // mint position + return pn.Mint( + barPath, + bazPath, + fee3000, + -887220, + 887220, + "50000", + "50000", + "0", + "0", + max_timeout, + admin, + admin, + ) +} + +func TokenFaucet(t *testing.T, tokenPath string, to std.Address) { + t.Helper() + std.TestSetOrigCaller(admin) + defaultAmount := uint64(5_000_000_000) + + switch tokenPath { + case wugnotPath: + wugnotTransfer(t, to, defaultAmount) + case gnsPath: + gnsTransfer(t, to, defaultAmount) + case barPath: + barTransfer(t, to, defaultAmount) + case bazPath: + bazTransfer(t, to, defaultAmount) + case fooPath: + fooTransfer(t, to, defaultAmount) + case oblPath: + oblTransfer(t, to, defaultAmount) + case quxPath: + quxTransfer(t, to, defaultAmount) + default: + panic("token not found") + } +} + +func TokenBalance(t *testing.T, tokenPath string, owner std.Address) uint64 { + t.Helper() + switch tokenPath { + case wugnotPath: + return wugnot.BalanceOf(common.AddrToUser(owner)) + case gnsPath: + return gns.BalanceOf(owner) + case barPath: + return bar.BalanceOf(owner) + case bazPath: + return baz.BalanceOf(owner) + case fooPath: + return foo.BalanceOf(owner) + case oblPath: + return obl.BalanceOf(owner) + case quxPath: + return qux.BalanceOf(owner) + default: + panic("token not found") + } +} + +func TokenAllowance(t *testing.T, tokenPath string, owner, spender std.Address) uint64 { + t.Helper() + switch tokenPath { + case wugnotPath: + return wugnot.Allowance(common.AddrToUser(owner), common.AddrToUser(spender)) + case gnsPath: + return gns.Allowance(owner, spender) + case barPath: + return bar.Allowance(owner, spender) + case bazPath: + return baz.Allowance(owner, spender) + case fooPath: + return foo.Allowance(owner, spender) + case oblPath: + return obl.Allowance(owner, spender) + case quxPath: + return qux.Allowance(owner, spender) + default: + panic("token not found") + } +} + +func TokenApprove(t *testing.T, tokenPath string, owner, spender std.Address, amount uint64) { + t.Helper() + switch tokenPath { + case wugnotPath: + wugnotApprove(t, owner, spender, amount) + case gnsPath: + gnsApprove(t, owner, spender, amount) + case barPath: + barApprove(t, owner, spender, amount) + case bazPath: + bazApprove(t, owner, spender, amount) + case fooPath: + fooApprove(t, owner, spender, amount) + case oblPath: + oblApprove(t, owner, spender, amount) + case quxPath: + quxApprove(t, owner, spender, amount) + default: + panic("token not found") + } +} + +func MintPosition(t *testing.T, + token0 string, + token1 string, + fee uint32, + tickLower int32, + tickUpper int32, + amount0Desired string, // *u256.Uint + amount1Desired string, // *u256.Uint + amount0Min string, // *u256.Uint + amount1Min string, // *u256.Uint + deadline int64, + mintTo std.Address, + caller std.Address, +) (uint64, string, string, string) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(caller)) + + return pn.Mint( + token0, + token1, + fee, + tickLower, + tickUpper, + amount0Desired, + amount1Desired, + amount0Min, + amount1Min, + deadline, + mintTo, + caller) +} + +func wugnotApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + wugnot.Approve(common.AddrToUser(spender), amount) +} + +func gnsApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + gns.Approve(spender, amount) +} + +func barApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + bar.Approve(spender, amount) +} + +func bazApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + baz.Approve(spender, amount) +} + +func fooApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + foo.Approve(spender, amount) +} + +func oblApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + obl.Approve(spender, amount) +} + +func quxApprove(t *testing.T, owner, spender std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(owner)) + qux.Approve(spender, amount) +} + +func wugnotTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + wugnot.Transfer(common.AddrToUser(to), amount) +} + +func gnsTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + gns.Transfer(to, amount) +} + +func barTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + bar.Transfer(to, amount) +} + +func bazTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + baz.Transfer(to, amount) +} + +func fooTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + foo.Transfer(to, amount) +} + +func oblTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + obl.Transfer(to, amount) +} + +func quxTransfer(t *testing.T, to std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(admin)) + qux.Transfer(to, amount) +} + +// ---------------------------------------------------------------------------- +// ugnot + +func ugnotTransfer(t *testing.T, from, to std.Address, amount uint64) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(from)) + std.TestSetOrigSend(std.Coins{{ugnotDenom, int64(amount)}}, nil) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(from, to, std.Coins{{ugnotDenom, int64(amount)}}) +} + +func ugnotBalanceOf(t *testing.T, addr std.Address) uint64 { + t.Helper() + + banker := std.GetBanker(std.BankerTypeRealmIssue) + coins := banker.GetCoins(addr) + if len(coins) == 0 { + return 0 + } + + return uint64(coins.AmountOf(ugnotDenom)) +} + +func ugnotMint(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.IssueCoin(addr, denom, amount) + std.TestIssueCoins(addr, std.Coins{{denom, int64(amount)}}) +} + +func ugnotBurn(t *testing.T, addr std.Address, denom string, amount int64) { + t.Helper() + banker := std.GetBanker(std.BankerTypeRealmIssue) + banker.RemoveCoin(addr, denom, amount) +} + +func ugnotFaucet(t *testing.T, to std.Address, amount uint64) { + t.Helper() + faucetAddress := admin + std.TestSetOrigCaller(faucetAddress) + + if ugnotBalanceOf(t, faucetAddress) < amount { + newCoins := std.Coins{{ugnotDenom, int64(amount)}} + ugnotMint(t, faucetAddress, newCoins[0].Denom, newCoins[0].Amount) + std.TestSetOrigSend(newCoins, nil) + } + ugnotTransfer(t, faucetAddress, to, amount) +} + +func ugnotDeposit(t *testing.T, addr std.Address, amount uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(addr)) + wugnotAddr := consts.WUGNOT_ADDR + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(addr, wugnotAddr, std.Coins{{ugnotDenom, int64(amount)}}) + wugnot.Deposit() +} + +// burnAllNFT burns all NFTs +func burnAllNFT(t *testing.T) { + t.Helper() + + std.TestSetRealm(std.NewCodeRealm(consts.POSITION_PATH)) + for i := uint64(1); i <= gnft.TotalSupply(); i++ { + gnft.Burn(tid(i)) + } +} + +func deletePoolTier(t *testing.T, poolPath string) { + t.Helper() + if poolTier != nil { + poolTier.changeTier(std.GetHeight(), pools, poolPath, 0) + } else { + panic("poolTier is nil") + } +} + +func addPoolTier(t *testing.T, poolPath string, tier uint64) { + t.Helper() + if poolTier != nil { + if pools != nil { + pools.GetOrCreate(poolPath) + } else { + panic(addDetailToError( + errPoolNotFound, ufmt.Sprintf("unknown error - pools is nil"))) + } + poolTier.changeTier(std.GetHeight(), pools, poolPath, tier) + } else { + poolTier = NewPoolTier(pools, std.GetHeight(), poolPath, en.GetEmission, en.GetHalvingBlocksInRange) + pools.GetOrCreate(poolPath) // must update pools tree + poolTier.changeTier(std.GetHeight(), pools, poolPath, tier) + } +} + +func changeWarmup(t *testing.T, index int, blockDuration int64) { + modifyWarmup(index, blockDuration) +} + +func getNumPoolTiers(t *testing.T) (uint64, uint64, uint64) { + counts := poolTier.CurrentAllTierCounts() + tier1Num := counts[1] + tier2Num := counts[2] + tier3Num := counts[3] + + return uint64(tier1Num), uint64(tier2Num), uint64(tier3Num) +} + +type gnsBalanceTracker struct { + height int64 + stakerBalance uint64 + devOpsBalance uint64 + communityPoolBalance uint64 + govStakerBalance uint64 + protocolFeeBalance uint64 + callerBalance uint64 +} + +func gnsBalanceCheck(t *testing.T, beforeBalance gnsBalanceTracker, printChange bool) gnsBalanceTracker { + t.Helper() + + caller := std.PrevRealm().Addr() + height := std.GetHeight() + stakerBalance := gns.BalanceOf(consts.STAKER_ADDR) + devOpsBalance := gns.BalanceOf(consts.DEV_OPS) + communityPoolBalance := gns.BalanceOf(consts.COMMUNITY_POOL_ADDR) + govStakerBalance := gns.BalanceOf(consts.GOV_STAKER_ADDR) + protocolFeeBalance := gns.BalanceOf(consts.PROTOCOL_FEE_ADDR) + callerBalance := gns.BalanceOf(caller) + + return gnsBalanceTracker{ + height: height, + stakerBalance: stakerBalance, + devOpsBalance: devOpsBalance, + communityPoolBalance: communityPoolBalance, + govStakerBalance: govStakerBalance, + protocolFeeBalance: protocolFeeBalance, + callerBalance: callerBalance, + } +} + +// returns true if actual is within 0.0001% of expected +func isInErrorRange(expected uint64, actual uint64) bool { + lowerBound := expected * 999999 / 1000000 + upperBound := expected * 1000001 / 1000000 + return actual >= lowerBound && actual <= upperBound +} + +func getPrintInfo(t *testing.T) string { + en.MintAndDistributeGns() + + emissionDebug := ApiEmissionDebugInfo{} + emissionDebug.Height = std.GetHeight() + emissionDebug.Time = time.Now().Unix() + emissionDebug.GnsStaker = gns.BalanceOf(consts.STAKER_ADDR) + emissionDebug.GnsDevOps = gns.BalanceOf(consts.DEV_OPS) + emissionDebug.GnsCommunityPool = gns.BalanceOf(consts.COMMUNITY_POOL_ADDR) + emissionDebug.GnsGovStaker = gns.BalanceOf(consts.GOV_STAKER_ADDR) + emissionDebug.GnsProtocolFee = gns.BalanceOf(consts.PROTOCOL_FEE_ADDR) + emissionDebug.GnsADMIN = gns.BalanceOf(consts.ADMIN) + + poolTiers := make(map[string]uint64) + pools.tree.Iterate("", "", func(poolPath string, iPool interface{}) bool { + poolTier := poolTier.CurrentTier(poolPath) + poolTiers[poolPath] = poolTier + return false + }) + + for poolPath, poolTier := range poolTiers { + pool := ApiEmissionDebugPool{} + pool.PoolPath = poolPath + pool.Tier = poolTier + + numTier1, numTier2, numTier3 := getNumPoolTiers(t) + if poolTier == 1 { + pool.NumPoolInSameTier = numTier1 + } else if poolTier == 2 { + pool.NumPoolInSameTier = numTier2 + } else if poolTier == 3 { + pool.NumPoolInSameTier = numTier3 + } + + deposits.tree.Iterate("", "", func(tokenIdStr string, value interface{}) bool { + tokenId := DecodeUint(tokenIdStr) + deposit := value.(*Deposit) + + if deposit.targetPoolPath == poolPath { + position := ApiEmissionDebugPosition{} + position.LpTokenId = tokenId + position.StakedHeight = deposit.stakeHeight + position.StakedTimestamp = deposit.stakeTimestamp + position.StakedDuration = emissionDebug.Height - deposit.stakeHeight + + position.Ratio = getRewardRatio(t, position.StakedDuration) + pool.Position = append(pool.Position, position) + } + return false + }) + emissionDebug.Pool = append(emissionDebug.Pool, pool) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "height": json.StringNode("", ufmt.Sprintf("%d", emissionDebug.Height)), + "time": json.StringNode("", ufmt.Sprintf("%d", emissionDebug.Time)), + "gns": json.ObjectNode("", map[string]*json.Node{ + "staker": json.StringNode("", ufmt.Sprintf("%d", emissionDebug.GnsStaker)), + "devOps": json.StringNode("", ufmt.Sprintf("%d", emissionDebug.GnsDevOps)), + "communityPool": json.StringNode("", ufmt.Sprintf("%d", emissionDebug.GnsCommunityPool)), + "govStaker": json.StringNode("", ufmt.Sprintf("%d", emissionDebug.GnsGovStaker)), + "protocolFee": json.StringNode("", ufmt.Sprintf("%d", emissionDebug.GnsProtocolFee)), + "GnoswapAdmin": json.StringNode("", ufmt.Sprintf("%d", emissionDebug.GnsADMIN)), + }), + "pool": json.ArrayNode("", makePoolsNode(t, emissionDebug.Pool)), + }) + + b, err := json.Marshal(node) + if err != nil { + panic("JSON MARSHAL ERROR") + } + + return string(b) +} + +type ApiEmissionDebugInfo struct { + Height int64 `json:"height"` + Time int64 `json:"time"` + GnsStaker uint64 `json:"gnsStaker"` + GnsDevOps uint64 `json:"gnsDevOps"` + GnsCommunityPool uint64 `json:"gnsCommunityPool"` + GnsGovStaker uint64 `json:"gnsGovStaker"` + GnsProtocolFee uint64 `json:"gnsProtocolFee"` + GnsADMIN uint64 `json:"gnsADMIN"` + Pool []ApiEmissionDebugPool `json:"pool"` +} + +type ApiEmissionDebugPool struct { + PoolPath string `json:"poolPath"` + Tier uint64 `json:"tier"` + NumPoolInSameTier uint64 `json:"numPoolInSameTier"` + PoolReward uint64 `json:"poolReward"` + Position []ApiEmissionDebugPosition `json:"position"` +} + +type ApiEmissionDebugPosition struct { + LpTokenId uint64 `json:"lpTokenId"` + StakedHeight int64 `json:"stakedHeight"` + StakedTimestamp int64 `json:"stakedTimestamp"` + StakedDuration int64 `json:"stakedDuration"` + FullAmount uint64 `json:"fullAmount"` + Ratio uint64 `json:"ratio"` + RatioAmount uint64 `json:"ratioAmount"` +} + +func getRewardRatio(t *testing.T, height int64) uint64 { + t.Helper() + warmups := InstantiateWarmup(height) + + for _, warmup := range warmups { + if height < warmup.NextWarmupHeight { + return warmup.WarmupRatio + } + } + + // passed all warmup-periods + return 100 +} + +func makePoolsNode(t *testing.T, emissionPool []ApiEmissionDebugPool) []*json.Node { + poolNodes := make([]*json.Node, 0) + + poolTiers := make(map[string]uint64) + pools.tree.Iterate("", "", func(poolPath string, iPool interface{}) bool { + poolTier := poolTier.CurrentTier(poolPath) + poolTiers[poolPath] = poolTier + return false + }) + + for poolPath, poolTier := range poolTiers { + numTier1, numTier2, numTier3 := getNumPoolTiers(t) + numPoolSameTier := uint64(0) + tier := poolTier + if tier == 1 { + numPoolSameTier = numTier1 + } else if tier == 2 { + numPoolSameTier = numTier2 + } else if tier == 3 { + numPoolSameTier = numTier3 + } + + poolNodes = append(poolNodes, json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", poolPath), + "tier": json.StringNode("tier", ufmt.Sprintf("%d", tier)), + "numPoolSameTier": json.StringNode("numPoolSameTier", ufmt.Sprintf("%d", numPoolSameTier)), + "position": json.ArrayNode("", makePositionsNode(t, poolPath)), + })) + } + + return poolNodes +} + +func makePositionsNode(t *testing.T, poolPath string) []*json.Node { + positions := make([]*json.Node, 0) + + deposits.tree.Iterate("", "", func(tokenIdStr string, value interface{}) bool { + tokenId := DecodeUint(tokenIdStr) + deposit := value.(*Deposit) + + if deposit.targetPoolPath == poolPath { + stakedDuration := std.GetHeight() - deposit.stakeHeight + ratio := getRewardRatio(t, stakedDuration) + + rewardByWarmup := calcPositionRewardByWarmups(std.GetHeight(), tokenId) + if len(rewardByWarmup) != 4 { + panic("len(rewardByWarmup) != 4") + } + + reward30 := rewardByWarmup[0].Internal + penalty30 := rewardByWarmup[0].InternalPenalty + full30 := reward30 + penalty30 + + reward50 := rewardByWarmup[1].Internal + penalty50 := rewardByWarmup[1].InternalPenalty + full50 := reward50 + penalty50 + + reward70 := rewardByWarmup[2].Internal + penalty70 := rewardByWarmup[2].InternalPenalty + full70 := reward70 + penalty70 + + reward100 := rewardByWarmup[3].Internal + penalty100 := rewardByWarmup[3].InternalPenalty + full100 := reward100 + penalty100 + + fullAmount := full30 + full50 + full70 + full100 + warmUpAmount := reward30 + reward50 + reward70 + reward100 + + positions = append(positions, json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.StringNode("lpTokenId", ufmt.Sprintf("%d", tokenId)), + "stakedHeight": json.StringNode("stakedHeight", ufmt.Sprintf("%d", deposit.stakeHeight)), + "stakedTimestamp": json.StringNode("stakedTimestamp", ufmt.Sprintf("%d", deposit.stakeTimestamp)), + "stakedDuration": json.StringNode("stakedDuration", ufmt.Sprintf("%d", stakedDuration)), + "fullAmount": json.StringNode("fullAmount", ufmt.Sprintf("%d", fullAmount)), + "ratio": json.StringNode("ratio", ufmt.Sprintf("%d", ratio)), + "warmUpAmount": json.StringNode("warmUpAmount", ufmt.Sprintf("%d", warmUpAmount)), + "full30": json.StringNode("full30", ufmt.Sprintf("%d", full30)), + "give30": json.StringNode("give30", ufmt.Sprintf("%d", reward30)), + "penalty30": json.StringNode("penalty30", ufmt.Sprintf("%d", penalty30)), + "full50": json.StringNode("full50", ufmt.Sprintf("%d", full50)), + "give50": json.StringNode("give50", ufmt.Sprintf("%d", reward50)), + "penalty50": json.StringNode("penalty50", ufmt.Sprintf("%d", penalty50)), + "full70": json.StringNode("full70", ufmt.Sprintf("%d", full70)), + "give70": json.StringNode("give70", ufmt.Sprintf("%d", reward70)), + "penalty70": json.StringNode("penalty70", ufmt.Sprintf("%d", penalty70)), + "full100": json.StringNode("full100", ufmt.Sprintf("%d", full100)), + "give100": json.StringNode("give100", ufmt.Sprintf("%d", reward100)), + "penalty100": json.StringNode("penalty100", ufmt.Sprintf("%d", penalty100)), + })) + } + + return false + }) + + return positions +} diff --git a/contract/r/gnoswap/staker/api.gno b/contract/r/gnoswap/staker/api.gno new file mode 100644 index 000000000..15446fada --- /dev/null +++ b/contract/r/gnoswap/staker/api.gno @@ -0,0 +1,1046 @@ +package staker + +import ( + "math" + "std" + "strconv" + "time" + + "gno.land/p/demo/json" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + en "gno.land/r/gnoswap/v1/emission" +) + +type RewardToken struct { + PoolPath string `json:"poolPath"` + RewardsTokenList []string `json:"rewardsTokenList"` +} + +type ApiExternalIncentive struct { + IncentiveId string `json:"incentiveId"` + PoolPath string `json:"poolPath"` + RewardToken string `json:"rewardToken"` + RewardAmount uint64 `json:"rewardAmount"` + RewardLeft uint64 `json:"rewardLeft"` + StartTimestamp int64 `json:"startTimestamp"` + EndTimestamp int64 `json:"endTimestamp"` + Active bool `json:"active"` + Refundee string `json:"refundee"` + CreatedHeight int64 `json:"createdHeight"` + DepositGnsAmount uint64 `json:"depositGnsAmount"` +} + +type ApiInternalIncentive struct { + PoolPath string `json:"poolPath"` + Tier uint64 `json:"tier"` + StartTimestamp int64 `json:"startTimestamp"` + RewardPerBlock string `json:"rewardPerBlock"` +} + +func ApiGetRewardTokens() string { + en.MintAndDistributeGns() + + rewardTokens := []RewardToken{} + + pools.IterateAll(func(key string, pool *Pool) bool { + thisPoolRewardTokens := []string{} + + // HANDLE INTERNAL + if poolTier.IsInternallyIncentivizedPool(pool.poolPath) { + thisPoolRewardTokens = append(thisPoolRewardTokens, consts.GNS_PATH) + } + + // HANDLE EXTERNAL + if pool.IsExternallyIncentivizedPool() { + pool.incentives.byTime.Iterate("", "", func(key string, value interface{}) bool { + ictv := value.(*ExternalIncentive) + if ictv.RewardToken() == "" { + return false + } + thisPoolRewardTokens = append(thisPoolRewardTokens, ictv.RewardToken()) + return false + }) + } + + if len(thisPoolRewardTokens) == 0 { + return false + } + + rewardTokens = append(rewardTokens, RewardToken{ + PoolPath: pool.poolPath, + RewardsTokenList: thisPoolRewardTokens, + }) + return false + }) + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, rewardToken := range rewardTokens { + _rewardTokenNode := json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", rewardToken.PoolPath), + "tokens": json.ArrayNode("tokens", makeRewardTokensArray(rewardToken.RewardsTokenList)), + }) + responses.AppendArray(_rewardTokenNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetRewardTokensByPoolPath(targetPoolPath string) string { + en.MintAndDistributeGns() + + rewardTokens := []RewardToken{} + + pool, ok := pools.Get(targetPoolPath) + if !ok { + return "" + } + + thisPoolRewardTokens := []string{} + + // HANDLE INTERNAL + if poolTier.IsInternallyIncentivizedPool(pool.poolPath) { + thisPoolRewardTokens = append(thisPoolRewardTokens, consts.GNS_PATH) + } + + // HANDLE EXTERNAL + if pool.IsExternallyIncentivizedPool() { + pool.incentives.byTime.Iterate("", "", func(key string, value interface{}) bool { + ictv := value.(*ExternalIncentive) + if ictv.RewardToken() == "" { + return false + } + thisPoolRewardTokens = append(thisPoolRewardTokens, ictv.RewardToken()) + return false + }) + } + + rewardTokens = append(rewardTokens, RewardToken{ + PoolPath: pool.poolPath, + RewardsTokenList: thisPoolRewardTokens, + }) + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, rewardToken := range rewardTokens { + _rewardTokenNode := json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", rewardToken.PoolPath), + "tokens": json.ArrayNode("tokens", makeRewardTokensArray(rewardToken.RewardsTokenList)), + }) + responses.AppendArray(_rewardTokenNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetExternalIncentives() string { + en.MintAndDistributeGns() + + apiExternalIncentives := []ApiExternalIncentive{} + + pools.tree.Iterate("", "", func(key string, value interface{}) bool { + pool := value.(*Pool) + pool.incentives.byTime.Iterate("", "", func(key string, value interface{}) bool { + ictv := value.(*ExternalIncentive) + apiExternalIncentives = append(apiExternalIncentives, ApiExternalIncentive{ + IncentiveId: ictv.incentiveId, + PoolPath: ictv.targetPoolPath, + RewardToken: ictv.rewardToken, + RewardAmount: ictv.rewardAmount, + RewardLeft: ictv.rewardLeft, + StartTimestamp: ictv.startTimestamp, + EndTimestamp: ictv.endTimestamp, + Refundee: ictv.refundee.String(), + CreatedHeight: ictv.createdHeight, + DepositGnsAmount: ictv.depositGnsAmount, + }) + return false + }) + return false + }) + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, incentive := range apiExternalIncentives { + active := false + if time.Now().Unix() >= incentive.StartTimestamp && time.Now().Unix() <= incentive.EndTimestamp { + active = true + } + + _incentiveNode := json.ObjectNode("", map[string]*json.Node{ + "incentiveId": json.StringNode("incentiveId", incentive.IncentiveId), + "poolPath": json.StringNode("poolPath", incentive.PoolPath), + "rewardToken": json.StringNode("rewardToken", incentive.RewardToken), + "rewardAmount": json.StringNode("rewardAmount", strconv.FormatUint(incentive.RewardAmount, 10)), + "rewardLeft": json.StringNode("rewardLeft", strconv.FormatUint(incentive.RewardLeft, 10)), + "startTimestamp": json.NumberNode("startTimestamp", float64(incentive.StartTimestamp)), + "endTimestamp": json.NumberNode("endTimestamp", float64(incentive.EndTimestamp)), + "active": json.BoolNode("active", active), + "refundee": json.StringNode("refundee", incentive.Refundee), + "createdHeight": json.NumberNode("createdHeight", float64(incentive.CreatedHeight)), + "depositGnsAmount": json.NumberNode("depositGnsAmount", float64(incentive.DepositGnsAmount)), + }) + responses.AppendArray(_incentiveNode) + } + + // RETURN + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetExternalIncentiveById(poolPath, incentiveId string) string { + en.MintAndDistributeGns() + + apiExternalIncentives := []ApiExternalIncentive{} + + pool, ok := pools.Get(poolPath) + if !ok { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("pool(%s) not found", poolPath), + )) + } + + incentive, exist := pool.incentives.GetByIncentiveId(incentiveId) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("incentive(%s) not found", incentiveId), + )) + } + + apiExternalIncentives = append(apiExternalIncentives, ApiExternalIncentive{ + IncentiveId: incentiveId, + PoolPath: poolPath, + RewardToken: incentive.rewardToken, + RewardAmount: incentive.rewardAmount, + RewardLeft: incentive.rewardLeft, + StartTimestamp: incentive.startTimestamp, + EndTimestamp: incentive.endTimestamp, + Refundee: incentive.refundee.String(), + CreatedHeight: incentive.createdHeight, + DepositGnsAmount: incentive.depositGnsAmount, + }) + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, incentive := range apiExternalIncentives { + active := false + if time.Now().Unix() >= incentive.StartTimestamp && time.Now().Unix() <= incentive.EndTimestamp { + active = true + } + + _incentiveNode := json.ObjectNode("", map[string]*json.Node{ + "incentiveId": json.StringNode("incentiveId", incentive.IncentiveId), + "poolPath": json.StringNode("poolPath", incentive.PoolPath), + "rewardToken": json.StringNode("rewardToken", incentive.RewardToken), + "rewardAmount": json.StringNode("rewardAmount", strconv.FormatUint(incentive.RewardAmount, 10)), + "rewardLeft": json.StringNode("rewardLeft", strconv.FormatUint(incentive.RewardLeft, 10)), + "startTimestamp": json.NumberNode("startTimestamp", float64(incentive.StartTimestamp)), + "endTimestamp": json.NumberNode("endTimestamp", float64(incentive.EndTimestamp)), + "active": json.BoolNode("active", active), + "refundee": json.StringNode("refundee", incentive.Refundee), + "createdHeight": json.NumberNode("createdHeight", float64(incentive.CreatedHeight)), + "depositGnsAmount": json.NumberNode("depositGnsAmount", float64(incentive.DepositGnsAmount)), + }) + responses.AppendArray(_incentiveNode) + } + + // RETURN + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetExternalIncentivesByPoolPath(targetPoolPath string) string { + en.MintAndDistributeGns() + + apiExternalIncentives := []ApiExternalIncentive{} + + pool, ok := pools.Get(targetPoolPath) + if !ok { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("pool(%s) not found", targetPoolPath), + )) + } + + pool.incentives.byTime.Iterate("", "", func(key string, value interface{}) bool { + incentive := value.(*ExternalIncentive) + if incentive.targetPoolPath != targetPoolPath { + return false + } + + apiExternalIncentives = append(apiExternalIncentives, ApiExternalIncentive{ + IncentiveId: incentive.incentiveId, + PoolPath: incentive.targetPoolPath, + RewardToken: incentive.rewardToken, + RewardAmount: incentive.rewardAmount, + RewardLeft: incentive.rewardLeft, + StartTimestamp: incentive.startTimestamp, + EndTimestamp: incentive.endTimestamp, + Refundee: incentive.refundee.String(), + CreatedHeight: incentive.createdHeight, + DepositGnsAmount: incentive.depositGnsAmount, + }) + return false + }) + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, incentive := range apiExternalIncentives { + active := false + if time.Now().Unix() >= incentive.StartTimestamp && time.Now().Unix() <= incentive.EndTimestamp { + active = true + } + + _incentiveNode := json.ObjectNode("", map[string]*json.Node{ + "incentiveId": json.StringNode("incentiveId", incentive.IncentiveId), + "poolPath": json.StringNode("poolPath", incentive.PoolPath), + "rewardToken": json.StringNode("rewardToken", incentive.RewardToken), + "rewardAmount": json.StringNode("rewardAmount", strconv.FormatUint(incentive.RewardAmount, 10)), + "rewardLeft": json.StringNode("rewardLeft", strconv.FormatUint(incentive.RewardLeft, 10)), + "startTimestamp": json.NumberNode("startTimestamp", float64(incentive.StartTimestamp)), + "endTimestamp": json.NumberNode("endTimestamp", float64(incentive.EndTimestamp)), + "active": json.BoolNode("active", active), + "refundee": json.StringNode("refundee", incentive.Refundee), + "createdHeight": json.NumberNode("createdHeight", float64(incentive.CreatedHeight)), + "depositGnsAmount": json.NumberNode("depositGnsAmount", float64(incentive.DepositGnsAmount)), + }) + responses.AppendArray(_incentiveNode) + } + + // RETURN + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetExternalIncentivesByRewardTokenPath(rewardTokenPath string) string { + en.MintAndDistributeGns() + + apiExternalIncentives := []ApiExternalIncentive{} + + pools.tree.Iterate("", "", func(key string, value interface{}) bool { + pool := value.(*Pool) + pool.incentives.byTime.Iterate("", "", func(key string, value interface{}) bool { + incentive := value.(*ExternalIncentive) + if incentive.rewardToken != rewardTokenPath { + return false + } + + apiExternalIncentives = append(apiExternalIncentives, ApiExternalIncentive{ + IncentiveId: incentive.incentiveId, + PoolPath: pool.poolPath, + RewardToken: incentive.rewardToken, + RewardAmount: incentive.rewardAmount, + RewardLeft: incentive.rewardLeft, + StartTimestamp: incentive.startTimestamp, + EndTimestamp: incentive.endTimestamp, + Refundee: incentive.refundee.String(), + CreatedHeight: incentive.createdHeight, + DepositGnsAmount: incentive.depositGnsAmount, + }) + return false + }) + return false + }) + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, incentive := range apiExternalIncentives { + active := false + if time.Now().Unix() >= incentive.StartTimestamp && time.Now().Unix() <= incentive.EndTimestamp { + active = true + } + + _incentiveNode := json.ObjectNode("", map[string]*json.Node{ + "incentiveId": json.StringNode("incentiveId", incentive.IncentiveId), + "poolPath": json.StringNode("poolPath", incentive.PoolPath), + "rewardToken": json.StringNode("rewardToken", incentive.RewardToken), + "rewardAmount": json.StringNode("rewardAmount", strconv.FormatUint(incentive.RewardAmount, 10)), + "rewardLeft": json.StringNode("rewardLeft", strconv.FormatUint(incentive.RewardLeft, 10)), + "startTimestamp": json.NumberNode("startTimestamp", float64(incentive.StartTimestamp)), + "endTimestamp": json.NumberNode("endTimestamp", float64(incentive.EndTimestamp)), + "active": json.BoolNode("active", active), + "refundee": json.StringNode("refundee", incentive.Refundee), + "createdHeight": json.NumberNode("createdHeight", float64(incentive.CreatedHeight)), + "depositGnsAmount": json.NumberNode("depositGnsAmount", float64(incentive.DepositGnsAmount)), + }) + responses.AppendArray(_incentiveNode) + } + + // RETURN + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetInternalIncentives() string { + en.MintAndDistributeGns() + + apiInternalIncentives := []ApiInternalIncentive{} + + poolTier.membership.Iterate("", "", func(key string, value interface{}) bool { + poolPath := key + internalTier := value.(uint64) + apiInternalIncentives = append(apiInternalIncentives, ApiInternalIncentive{ + PoolPath: poolPath, + Tier: internalTier, + RewardPerBlock: calculateInternalRewardPerBlockByPoolPath(poolPath), + }) + return false + }) + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, incentive := range apiInternalIncentives { + _incentiveNode := json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", incentive.PoolPath), + "rewardToken": json.StringNode("rewardToken", consts.GNS_PATH), + "tier": json.NumberNode("tier", float64(incentive.Tier)), + "rewardPerBlock": json.StringNode("rewardPerBlock", incentive.RewardPerBlock), + }) + responses.AppendArray(_incentiveNode) + } + + // RETURN + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetInternalIncentivesByPoolPath(targetPoolPath string) string { + en.MintAndDistributeGns() + + apiInternalIncentives := []ApiInternalIncentive{} + + tier := poolTier.CurrentTier(targetPoolPath) + if tier == 0 { + return "" + } + + apiInternalIncentives = append(apiInternalIncentives, ApiInternalIncentive{ + PoolPath: targetPoolPath, + Tier: tier, + RewardPerBlock: calculateInternalRewardPerBlockByPoolPath(targetPoolPath), + }) + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, incentive := range apiInternalIncentives { + _incentiveNode := json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", incentive.PoolPath), + "rewardToken": json.StringNode("rewardToken", consts.GNS_PATH), + "tier": json.NumberNode("tier", float64(incentive.Tier)), + "rewardPerBlock": json.StringNode("rewardPerBlock", incentive.RewardPerBlock), + }) + responses.AppendArray(_incentiveNode) + } + + // RETURN + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetInternalIncentivesByTiers(targetTier uint64) string { + en.MintAndDistributeGns() + + apiInternalIncentives := []ApiInternalIncentive{} + + poolTier.membership.Iterate("", "", func(key string, value interface{}) bool { + poolPath := key + internalTier := value.(uint64) + if internalTier != targetTier { + return false + } + + apiInternalIncentives = append(apiInternalIncentives, ApiInternalIncentive{ + PoolPath: poolPath, + Tier: internalTier, + RewardPerBlock: calculateInternalRewardPerBlockByPoolPath(poolPath), + }) + + return false + }) + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, incentive := range apiInternalIncentives { + _incentiveNode := json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", incentive.PoolPath), + "rewardToken": json.StringNode("rewardToken", consts.GNS_PATH), + "tier": json.NumberNode("tier", float64(incentive.Tier)), + "rewardPerBlock": json.StringNode("rewardPerBlock", incentive.RewardPerBlock), + }) + responses.AppendArray(_incentiveNode) + } + + // RETURN + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func makeRewardTokensArray(rewardsTokenList []string) []*json.Node { + rewardsTokenArray := make([]*json.Node, len(rewardsTokenList)) + for i, rewardToken := range rewardsTokenList { + rewardsTokenArray[i] = json.StringNode("", rewardToken) + } + return rewardsTokenArray +} + +func calculateInternalRewardPerBlockByPoolPath(poolPath string) string { + reward := poolTier.CurrentRewardPerPool(poolPath) + return ufmt.Sprintf("%d", reward) +} + +// LpTokenReward represents the rewards associated with a specific LP token +type LpTokenReward struct { + LpTokenId uint64 `json:"lpTokenId"` // The ID of the LP token + Address string `json:"address"` // The address associated with the LP token + Rewards []ApiReward `json:"rewards"` +} + +// Reward represents a single reward for a staked LP token +type ApiReward struct { + IncentiveType string `json:"incentiveType"` // The type of incentive (INTERNAL or EXTERNAL) + IncentiveId string `json:"incentiveId"` // The unique identifier of the incentive + TargetPoolPath string `json:"targetPoolPath"` // The path of the target pool for the reward + RewardTokenPath string `json:"rewardTokenPath"` // The pathe of the reward token + RewardTokenAmount uint64 `json:"rewardTokenAmount"` // The amount of the reward token + StakeTimestamp int64 `json:"stakeTimestamp"` // The timestamp when the LP token was staked + StakeHeight int64 `json:"stakeHeight"` // The block height when the LP token was staked + IncentiveStart int64 `json:"incentiveStart"` // The timestamp when the incentive started +} + +// Stake represents a single stake +type ApiStake struct { + TokenId uint64 `json:"tokenId"` // The ID of the staked LP token + Owner std.Address `json:"owner"` // The address of the owner of the staked LP token + NumberOfStakes uint64 `json:"numberOfStakes"` // The number of times this LP token has been staked + StakeTimestamp int64 `json:"stakeTimestamp"` // The timestamp when the LP token was staked + StakeHeight int64 `json:"stakeHeight"` // The block height when the LP token was staked + TargetPoolPath string `json:"targetPoolPath"` // The path of the target pool for the stake +} + +// ResponseQueryBase contains basic information about a query response. +type ResponseQueryBase struct { + Height int64 `json:"height"` // The block height at the time of the query + Timestamp int64 `json:"timestamp"` // The timestamp at the time of the query +} + +// ResponseApiGetRewards represents the API response for getting rewards. +type ResponseApiGetRewards struct { + Stat ResponseQueryBase `json:"stat"` // Basic query information + Response []LpTokenReward `json:"response"` // A slice of LpTokenReward structs +} + +// ResponseApiGetRewardByLpTokenId represents the API response for getting rewards for a specific LP token. +type ResponseApiGetRewardByLpTokenId struct { + Stat ResponseQueryBase `json:"stat"` // Basic query information + Response LpTokenReward `json:"response"` // The LpTokenReward for the specified LP token +} + +// ResponseApiGetStakes represents the API response for getting stakes. +type ResponseApiGetStakes struct { + Stat ResponseQueryBase `json:"stat"` // Basic query information + Response []ApiStake `json:"response"` // A slice of Stake structs +} + +func ApiGetRewardsByLpTokenId(targetLpTokenId uint64) string { + en.MintAndDistributeGns() + + deposit := deposits.Get(targetLpTokenId) + + reward := calcPositionReward(std.GetHeight(), targetLpTokenId) + + rewards := []ApiReward{} + + if reward.Internal > 0 { + rewards = append(rewards, ApiReward{ + IncentiveType: "INTERNAL", + IncentiveId: "", + TargetPoolPath: deposit.targetPoolPath, + RewardTokenPath: consts.GNS_PATH, + RewardTokenAmount: reward.Internal, + StakeTimestamp: deposit.stakeTimestamp, + StakeHeight: deposit.stakeHeight, + IncentiveStart: deposit.stakeTimestamp, + }) + } + + for incentiveId, externalReward := range reward.External { + rewards = append(rewards, ApiReward{ + IncentiveType: "EXTERNAL", + IncentiveId: incentiveId, + TargetPoolPath: deposit.targetPoolPath, + RewardTokenPath: consts.GNS_PATH, + RewardTokenAmount: externalReward, + StakeTimestamp: deposit.stakeTimestamp, + StakeHeight: deposit.stakeHeight, + IncentiveStart: deposit.stakeTimestamp, + }) + } + + qb := ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + } + + r := ResponseApiGetRewards{ + Stat: qb, + Response: []LpTokenReward{ + { + LpTokenId: targetLpTokenId, + Address: deposit.owner.String(), + Rewards: rewards, + }, + }, + } + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, reward := range r.Response { + _rewardNode := json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(reward.LpTokenId)), + "address": json.StringNode("address", reward.Address), + "rewards": json.ArrayNode("rewards", makeRewardsArray(reward.Rewards)), + }) + responses.AppendArray(_rewardNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetRewardsByAddress(targetAddress string) string { + en.MintAndDistributeGns() + + lpTokenRewards := []LpTokenReward{} + + stakers.IterateAll(std.Address(targetAddress), func(tokenId uint64, deposit *Deposit) bool { + rewards := []ApiReward{} + + reward := calcPositionReward(std.GetHeight(), tokenId) + + // get internal gns reward + if reward.Internal > 0 { + rewards = append(rewards, ApiReward{ + IncentiveType: "INTERNAL", + IncentiveId: "", + TargetPoolPath: deposit.targetPoolPath, + RewardTokenPath: consts.GNS_PATH, + RewardTokenAmount: reward.Internal, + StakeTimestamp: deposit.stakeTimestamp, + StakeHeight: deposit.stakeHeight, + IncentiveStart: deposit.stakeTimestamp, + }) + } + + // find all external reward list for poolPath which lpTokenId is staked + for incentiveId, externalReward := range reward.External { + rewards = append(rewards, ApiReward{ + IncentiveType: "EXTERNAL", + IncentiveId: incentiveId, + TargetPoolPath: deposit.targetPoolPath, + RewardTokenPath: consts.GNS_PATH, + RewardTokenAmount: externalReward, + StakeTimestamp: deposit.stakeTimestamp, + StakeHeight: deposit.stakeHeight, + IncentiveStart: deposit.stakeTimestamp, + }) + } + lpTokenReward := LpTokenReward{ + LpTokenId: tokenId, + Address: deposit.owner.String(), + Rewards: rewards, + } + lpTokenRewards = append(lpTokenRewards, lpTokenReward) + + return false + }) + + qb := ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + } + + r := ResponseApiGetRewards{ + Stat: qb, + Response: lpTokenRewards, + } + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, reward := range r.Response { + _rewardNode := json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(reward.LpTokenId)), + "address": json.StringNode("address", reward.Address), + "rewards": json.ArrayNode("rewards", makeRewardsArray(reward.Rewards)), + }) + responses.AppendArray(_rewardNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetStakes() string { + en.MintAndDistributeGns() + + stakes := []ApiStake{} + deposits.Iterate(0, math.MaxUint64, func(tokenId uint64, deposit *Deposit) bool { + stakes = append(stakes, ApiStake{ + TokenId: tokenId, + Owner: deposit.owner, + StakeTimestamp: deposit.stakeTimestamp, + StakeHeight: deposit.stakeHeight, + TargetPoolPath: deposit.targetPoolPath, + }) + return false + }) + + qb := ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + } + + r := ResponseApiGetStakes{ + Stat: qb, + Response: stakes, + } + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, stake := range r.Response { + _stakeNode := json.ObjectNode("", map[string]*json.Node{ + "tokenId": json.NumberNode("tokenId", float64(stake.TokenId)), + "owner": json.StringNode("owner", stake.Owner.String()), + "stakeTimestamp": json.NumberNode("stakeTimestamp", float64(stake.StakeTimestamp)), + "stakeHeight": json.NumberNode("stakeHeight", float64(stake.StakeHeight)), + "targetPoolPath": json.StringNode("targetPoolPath", stake.TargetPoolPath), + }) + responses.AppendArray(_stakeNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetStakesByLpTokenId(targetLpTokenId uint64) string { + en.MintAndDistributeGns() + + stakes := []ApiStake{} + + deposit := deposits.Get(targetLpTokenId) + stakes = append(stakes, ApiStake{ + TokenId: targetLpTokenId, + Owner: deposit.owner, + StakeTimestamp: deposit.stakeTimestamp, + StakeHeight: deposit.stakeHeight, + TargetPoolPath: deposit.targetPoolPath, + }) + + qb := ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + } + + r := ResponseApiGetStakes{ + Stat: qb, + Response: stakes, + } + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, stake := range r.Response { + _stakeNode := json.ObjectNode("", map[string]*json.Node{ + "tokenId": json.NumberNode("tokenId", float64(stake.TokenId)), + "owner": json.StringNode("owner", stake.Owner.String()), + "numberOfStakes": json.NumberNode("numberOfStakes", float64(stake.NumberOfStakes)), + "stakeTimestamp": json.NumberNode("stakeTimestamp", float64(stake.StakeTimestamp)), + "stakeHeight": json.NumberNode("stakeHeight", float64(stake.StakeHeight)), + "targetPoolPath": json.StringNode("targetPoolPath", stake.TargetPoolPath), + }) + responses.AppendArray(_stakeNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +func ApiGetStakesByAddress(targetAddress string) string { + en.MintAndDistributeGns() + + stakes := []ApiStake{} + + stakers.IterateAll(std.Address(targetAddress), func(tokenId uint64, deposit *Deposit) bool { + stakes = append(stakes, ApiStake{ + TokenId: tokenId, + Owner: deposit.owner, + StakeTimestamp: deposit.stakeTimestamp, + StakeHeight: deposit.stakeHeight, + TargetPoolPath: deposit.targetPoolPath, + }) + + return false + }) + + qb := ResponseQueryBase{ + Height: std.GetHeight(), + Timestamp: time.Now().Unix(), + } + + r := ResponseApiGetStakes{ + Stat: qb, + Response: stakes, + } + + // STAT NODE + _stat := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("height", float64(std.GetHeight())), + "timestamp": json.NumberNode("timestamp", float64(time.Now().Unix())), + }) + + // RESPONSE (ARRAY) NODE + responses := json.ArrayNode("", []*json.Node{}) + for _, stake := range r.Response { + _stakeNode := json.ObjectNode("", map[string]*json.Node{ + "tokenId": json.NumberNode("tokenId", float64(stake.TokenId)), + "owner": json.StringNode("owner", stake.Owner.String()), + "stakeTimestamp": json.NumberNode("stakeTimestamp", float64(stake.StakeTimestamp)), + "stakeHeight": json.NumberNode("stakeHeight", float64(stake.StakeHeight)), + "targetPoolPath": json.StringNode("targetPoolPath", stake.TargetPoolPath), + }) + responses.AppendArray(_stakeNode) + } + + node := json.ObjectNode("", map[string]*json.Node{ + "stat": _stat, + "response": responses, + }) + + b, err := json.Marshal(node) + if err != nil { + panic(err.Error()) + } + + return string(b) +} + +// for off chain to check if lpTokenId is staked via RPC +func IsStaked(tokenId uint64) bool { + return deposits.Has(tokenId) +} + +func makeRewardsArray(rewards []ApiReward) []*json.Node { + rewardsArray := make([]*json.Node, len(rewards)) + + for i, reward := range rewards { + rewardsArray[i] = json.ObjectNode("", map[string]*json.Node{ + "incentiveType": json.StringNode("incentiveType", reward.IncentiveType), + "incentiveId": json.StringNode("incentiveId", reward.IncentiveId), + "targetPoolPath": json.StringNode("targetPoolPath", reward.TargetPoolPath), + "rewardTokenPath": json.StringNode("rewardTokenPath", reward.RewardTokenPath), + "rewardTokenAmount": json.NumberNode("rewardTokenAmount", float64(reward.RewardTokenAmount)), + "stakeTimestamp": json.NumberNode("stakeTimestamp", float64(reward.StakeTimestamp)), + "stakeHeight": json.NumberNode("stakeHeight", float64(reward.StakeHeight)), + "incentiveStart": json.NumberNode("incentiveStart", float64(reward.IncentiveStart)), + }) + } + return rewardsArray +} diff --git a/contract/r/gnoswap/staker/calculate_pool_position_reward.gno b/contract/r/gnoswap/staker/calculate_pool_position_reward.gno new file mode 100644 index 000000000..75c044891 --- /dev/null +++ b/contract/r/gnoswap/staker/calculate_pool_position_reward.gno @@ -0,0 +1,166 @@ +package staker + +import ( + "gno.land/p/gnoswap/consts" + + u256 "gno.land/p/gnoswap/uint256" +) + +// Q96 +var _q96 = u256.MustFromDecimal(consts.Q96) + +func isAbleToCalculateEmissionReward(prev int64, current int64) bool { + if prev >= current { + return false + } + return true +} + +// Reward is a struct for storing reward for a position. +// Internal reward is the GNS reward, external reward is the reward for other incentives. +// Penalties are the amount that is deducted from the reward due to the position's warmup. +type Reward struct { + Internal uint64 + InternalPenalty uint64 + External map[string]uint64 // Incentive ID -> TokenAmount + ExternalPenalty map[string]uint64 // Incentive ID -> TokenAmount +} + +// calculate position reward by each warmup +func calcPositionRewardByWarmups(currentHeight int64, tokenId uint64) []Reward { + rewards := CalcPositionReward(CalcPositionRewardParam{ + CurrentHeight: currentHeight, + Deposits: deposits, + Pools: pools, + PoolTier: poolTier, + TokenId: tokenId, + }) + + return rewards +} + +// calculate total position rewards and penalties +func calcPositionReward(currentHeight int64, tokenId uint64) Reward { + rewards := CalcPositionReward(CalcPositionRewardParam{ + CurrentHeight: currentHeight, + Deposits: deposits, + Pools: pools, + PoolTier: poolTier, + TokenId: tokenId, + }) + + internal := uint64(0) + for _, reward := range rewards { + internal += reward.Internal + } + + internalPenalty := uint64(0) + for _, reward := range rewards { + internalPenalty += reward.InternalPenalty + } + + externalReward := make(map[string]uint64) + for _, reward := range rewards { + if reward.External != nil { + for incentive, reward := range reward.External { + externalReward[incentive] += reward + } + } + } + + externalPenalty := make(map[string]uint64) + for _, reward := range rewards { + if reward.ExternalPenalty != nil { + for incentive, penalty := range reward.ExternalPenalty { + externalPenalty[incentive] += penalty + } + } + } + + return Reward{ + Internal: internal, + InternalPenalty: internalPenalty, + External: externalReward, + ExternalPenalty: externalPenalty, + } +} + +// CalcPositionRewardParam is a struct for calculating position reward +type CalcPositionRewardParam struct { + // Environmental variables + CurrentHeight int64 + Deposits *Deposits + Pools *Pools + PoolTier *PoolTier + + // Position variables + TokenId uint64 +} + +func CalcPositionReward(param CalcPositionRewardParam) []Reward { + // cache per-pool rewards in the internal incentive(tiers) + param.PoolTier.cacheReward(param.CurrentHeight, param.Pools) + + deposit := param.Deposits.Get(param.TokenId) + poolPath := deposit.targetPoolPath + + pool, ok := param.Pools.Get(poolPath) + if !ok { + pool = NewPool(poolPath, param.CurrentHeight) + param.Pools.Set(poolPath, pool) + } + + lastCollectHeight := deposit.lastCollectHeight + + // Initializes reward/penalty arrays for rewards and penalties for each warmup + internalRewards := make([]uint64, len(deposit.warmups)) + internalPenalties := make([]uint64, len(deposit.warmups)) + externalRewards := make([]map[string]uint64, len(deposit.warmups)) + externalPenalties := make([]map[string]uint64, len(deposit.warmups)) + + if param.PoolTier.CurrentTier(poolPath) != 0 { + // Internal incentivized pool. + // Calculate reward for each warmup + internalRewards, internalPenalties = pool.RewardStateOf(deposit).CalculateInternalReward(lastCollectHeight, param.CurrentHeight) + } + + // All active incentives + allIncentives := pool.incentives.GetAllInHeights(lastCollectHeight, param.CurrentHeight) + + for i := range externalRewards { + externalRewards[i] = make(map[string]uint64) + externalPenalties[i] = make(map[string]uint64) + } + + for incentiveId, incentive := range allIncentives { + // External incentivized pool. + // Calculate reward for each warmup + externalReward, externalPenalty := pool.RewardStateOf(deposit).CalculateExternalReward(int64(lastCollectHeight), int64(param.CurrentHeight), incentive) + + for i := range externalReward { + externalRewards[i][incentiveId] = externalReward[i] + externalPenalties[i][incentiveId] = externalPenalty[i] + } + } + + rewards := make([]Reward, len(internalRewards)) + for i := range internalRewards { + rewards[i] = Reward{ + Internal: internalRewards[i], + InternalPenalty: internalPenalties[i], + External: externalRewards[i], + ExternalPenalty: externalPenalties[i], + } + } + + return rewards +} + +// calculates internal unclaimable reward for the pool +func ProcessUnClaimableReward(poolPath string, endHeight int64) uint64 { + pool, ok := pools.Get(poolPath) + if !ok { + return 0 + } + return pool.processUnclaimableReward(poolTier, endHeight) +} diff --git a/contract/r/gnoswap/staker/callback.gno b/contract/r/gnoswap/staker/callback.gno new file mode 100644 index 000000000..77fba4884 --- /dev/null +++ b/contract/r/gnoswap/staker/callback.gno @@ -0,0 +1,21 @@ +package staker + +import ( + "std" +) + +// callbackStakerEmissionChange is called by emission when +// - msPerBlock is changed +// - staker emission % is changed +// it does NOT get called in regards of halving(manually handled in cacheReward). +// +// Initially, it is passed to emission in staker.gno#init(). +// +// For the parameter `emission`, which is a per-block emission for staker contract, +// - It first caches the reward until the current block height. +// - Then, it updates the `currentEmission` of the poolTier, +// which will be applied for future blocks thereafter. +func callbackStakerEmissionChange(emission uint64) { + poolTier.cacheReward(std.GetHeight(), pools) + poolTier.currentEmission = emission +} diff --git a/contract/r/gnoswap/staker/doc.gno b/contract/r/gnoswap/staker/doc.gno new file mode 100644 index 000000000..26c04f4f0 --- /dev/null +++ b/contract/r/gnoswap/staker/doc.gno @@ -0,0 +1,3 @@ +// Package staker implements LP token staking functionality with both internal +// and external GNS emissions and external incentive rewards. +package staker diff --git a/contract/r/gnoswap/staker/errors.gno b/contract/r/gnoswap/staker/errors.gno new file mode 100644 index 000000000..b203d29e8 --- /dev/null +++ b/contract/r/gnoswap/staker/errors.gno @@ -0,0 +1,47 @@ +package staker + +import ( + "errors" + + "gno.land/p/demo/ufmt" +) + +var ( + errNoPermission = errors.New("[GNOSWAP-STAKER-001] caller has no permission") + errPoolNotFound = errors.New("[GNOSWAP-STAKER-002] pool not found") + errAlreadyRegistered = errors.New("[GNOSWAP-STAKER-003] already registered token") + errInsufficientReward = errors.New("[GNOSWAP-STAKER-004] insufficient reward") + errWrapUnwrap = errors.New("[GNOSWAP-STAKER-005] wrap, unwrap failed") + errWugnotMinimum = errors.New("[GNOSWAP-STAKER-006] can not wrapless than minimum amount") + errInvalidInput = errors.New("[GNOSWAP-STAKER-007] invalid input data") + errInvalidUnstakingFee = errors.New("[GNOSWAP-STAKER-008] invalid unstaking fee") + errAlreadyStaked = errors.New("[GNOSWAP-STAKER-009] already staked position") + errNonIncentivizedPool = errors.New("[GNOSWAP-STAKER-010] pool is not incentivized") + errOutOfRange = errors.New("[GNOSWAP-STAKER-011] out of range") + errCannotEndIncentive = errors.New("[GNOSWAP-STAKER-012] can not end incentive") + errInvalidIncentiveStartTime = errors.New("[GNOSWAP-STAKER-013] invalid incentive start time") + errInvalidIncentiveEndTime = errors.New("[GNOSWAP-STAKER-014] invalid incentive end time") + errCannotUseForExternalReward = errors.New("[GNOSWAP-STAKER-015] can not use for external reward") + errMinTier = errors.New("[GNOSWAP-STAKER-016] emission minimum tier is 1") + errDefaultPoolTier1 = errors.New("[GNOSWAP-STAKER-017] can not delete default pool tier 1") + errDefaultExternalToken = errors.New("[GNOSWAP-STAKER-018] can not delete default external token") + errInvalidPoolPath = errors.New("[GNOSWAP-STAKER-019] invalid pool path") + errInvalidPoolTier = errors.New("[GNOSWAP-STAKER-020] invalid pool tier") + errAlreadyHasTier = errors.New("[GNOSWAP-STAKER-021] pool already has emission target") + errDataNotFound = errors.New("[GNOSWAP-STAKER-022] requested data not found") + errCalculationError = errors.New("[GNOSWAP-STAKER-023] unexpected calculation error") + errZeroLiquidity = errors.New("[GNOSWAP-STAKER-024] zero liquidity") + errInvalidIncentiveDuration = errors.New("[GNOSWAP-STAKER-025] invalid incentive duration") + errNotAllowedForExternalReward = errors.New("[GNOSWAP-STAKER-026] not allowed for external reward") + errInvalidWarmUpPercent = errors.New("[GNOSWAP-STAKER-027] invalid warm-up duration") + errInvalidTickCross = errors.New("[GNOSWAP-STAKER-028] invalid tick cross") + errIncentiveAlreadyExists = errors.New("[GNOSWAP-STAKER-029] incentive already exists") + errIncentiveNotFound = errors.New("[GNOSWAP-STAKER-030] incentive not found") + errWarmUpAmountNotFound = errors.New("[GNOSWAP-STAKER-031] warm-up amount not found") + errOverflow = errors.New("[GNOSWAP-STAKER-032] overflow") +) + +func addDetailToError(err error, detail string) string { + finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) + return finalErr.Error() +} diff --git a/contract/r/gnoswap/staker/external_deposit_fee.gno b/contract/r/gnoswap/staker/external_deposit_fee.gno new file mode 100644 index 000000000..2ea3eeb58 --- /dev/null +++ b/contract/r/gnoswap/staker/external_deposit_fee.gno @@ -0,0 +1,81 @@ +package staker + +import ( + "std" + + "gno.land/r/gnoswap/v1/common" +) + +var ( + depositGnsAmount = uint64(1_000_000_000) // 1_000 GNS +) + +// GetDepositGnsAmount returns the current deposit amount in GNS. +// +// Returns: +// - uint64: The deposit amount in GNS. +func GetDepositGnsAmount() uint64 { + return depositGnsAmount +} + +// SetDepositGnsAmountByAdmin allows an admin to set the deposit amount in GNS. +// +// This function validates the caller as an admin using `common.AdminOnly`. +// If successful, it updates the deposit amount and emits an event with details +// of the change. +// +// Parameters: +// - amount (uint64): The new deposit amount in GNS. +// +// Panics: +// - If the caller is not an admin. +func SetDepositGnsAmountByAdmin(amount uint64) { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err.Error()) + } + + prevDepositGnsAmount := getDepositGnsAmount() + setDepositGnsAmount(amount) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "SetDepositGnsAmountByAdmin", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevAmount", formatUint(prevDepositGnsAmount), + "newAmount", formatUint(amount), + ) +} + +// SetDepositGnsAmount modifies the deposit gns amount +// Only governance contract can execute this function via proposal +// ref: https://docs.gnoswap.io/contracts/staker/external_deposit_fee.gno +func SetDepositGnsAmount(amount uint64) { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(err.Error()) + } + + prevDepositGnsAmount := getDepositGnsAmount() + setDepositGnsAmount(amount) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "SetDepositGnsAmount", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevAmount", formatUint(prevDepositGnsAmount), + "newAmount", formatUint(amount), + ) +} + +func setDepositGnsAmount(amount uint64) { + depositGnsAmount = amount +} + +func getDepositGnsAmount() uint64 { + return depositGnsAmount +} diff --git a/contract/r/gnoswap/staker/external_deposit_fee_test.gno b/contract/r/gnoswap/staker/external_deposit_fee_test.gno new file mode 100644 index 000000000..76ebbe0e9 --- /dev/null +++ b/contract/r/gnoswap/staker/external_deposit_fee_test.gno @@ -0,0 +1,113 @@ +package staker + +import ( + "std" + "testing" + + "gno.land/p/gnoswap/consts" +) + +func TestGetDepositGnsAmount(t *testing.T) { + expected := uint64(1_000_000_000) + actual := GetDepositGnsAmount() + + if actual != expected { + t.Errorf("GetDepositGnsAmount() = %d; want %d", actual, expected) + } +} + +func TestSetDepositGnsAmountByAdmin(t *testing.T) { + tests := []struct { + name string + caller std.Address + newAmount uint64 + shouldPanic bool + }{ + { + name: "Success - Admin sets deposit amount", + caller: consts.ADMIN, + newAmount: 2_000_000_000, + shouldPanic: false, + }, + { + name: "Failure - Non-admin tries to set deposit amount", + caller: std.Address("user1"), + newAmount: 2_000_000_000, + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetOrigCaller(tt.caller) + + defer func() { + r := recover() + if r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } else if tt.shouldPanic { + t.Errorf("Expected panic but did not occur") + } + }() + + SetDepositGnsAmountByAdmin(tt.newAmount) + + if !tt.shouldPanic { + actual := GetDepositGnsAmount() + if actual != tt.newAmount { + t.Errorf("SetDepositGnsAmountByAdmin() = %d; want %d", actual, tt.newAmount) + } + } + }) + } +} + +func TestSetDepositGnsAmount(t *testing.T) { + tests := []struct { + name string + caller std.Address + newAmount uint64 + shouldPanic bool + }{ + { + name: "Success - Governance sets deposit amount", + caller: consts.GOV_GOVERNANCE_ADDR, + newAmount: 3_000_000_000, + shouldPanic: false, + }, + { + name: "Failure - Non-governance tries to set deposit amount", + caller: std.Address("user2"), + newAmount: 3_000_000_000, + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetOrigCaller(tt.caller) + + defer func() { + r := recover() + if r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } else if tt.shouldPanic { + t.Errorf("Expected panic but did not occur") + } + }() + + SetDepositGnsAmount(tt.newAmount) + + if !tt.shouldPanic { + actual := GetDepositGnsAmount() + if actual != tt.newAmount { + t.Errorf("SetDepositGnsAmount() = %d; want %d", actual, tt.newAmount) + } + } + }) + } +} diff --git a/contract/r/gnoswap/staker/external_token_list.gno b/contract/r/gnoswap/staker/external_token_list.gno new file mode 100644 index 000000000..09c56a1cd --- /dev/null +++ b/contract/r/gnoswap/staker/external_token_list.gno @@ -0,0 +1,77 @@ +package staker + +import ( + "std" + + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +// defaultAllowed is the list of default allowed tokens to create external incentive +var defaultAllowed = []string{consts.GNS_PATH, consts.GNOT} + +// allowedTokens is a slice of all allowed token paths, including the default and added tokens. +var allowedTokens = []string{} + +func init() { + allowedTokens = defaultAllowed +} + +// AddToken adds a new token path to the list of allowed tokens +// Only the admin can add a new token. +// +// Parameters: +// - tokenPath (string): The path of the token to add +// +// Panics: +// - If the caller is not the admin +func AddToken(tokenPath string) { + caller := getPrevAddr() + if err := common.AdminOnly(caller); err != nil { + panic(err.Error()) + } + + // if exist just return + for _, t := range allowedTokens { + if t == tokenPath { + return + } + } + + allowedTokens = append(allowedTokens, tokenPath) +} + +// RemoveToken removes a token path from the list of allowed tokens. +// Only the admin can remove a token. +// +// Default tokens can not be removed. +// +// Parameters: +// - tokenPath (string): The path of the token to remove +// +// Panics: +// - If the caller is not the admin +func RemoveToken(tokenPath string) { + caller := std.PrevRealm().Addr() + if err := common.AdminOnly(caller); err != nil { + panic(err.Error()) + } + + // if default token, can not remove + isDefault := contains(defaultAllowed, tokenPath) + if isDefault { + panic(addDetailToError( + errDefaultExternalToken, + ufmt.Sprintf("external_token_list.gno__RemoveToken() || can not remove default token(%s)", tokenPath), + )) + } + + for i, t := range allowedTokens { + if t == tokenPath { + allowedTokens = append(allowedTokens[:i], allowedTokens[i+1:]...) + return + } + } +} diff --git a/contract/r/gnoswap/staker/external_token_list_test.gno b/contract/r/gnoswap/staker/external_token_list_test.gno new file mode 100644 index 000000000..80ea771e0 --- /dev/null +++ b/contract/r/gnoswap/staker/external_token_list_test.gno @@ -0,0 +1,123 @@ +package staker + +/* +func TestAddToken(t *testing.T) { + adminAddr := consts.ADMIN + nonAdminAddr := testutils.TestAddress("nonadmin") + std.TestSetOrigCaller(adminAddr) + + tests := []struct { + name string + caller std.Address + tokenPath string + expected []string + shouldPanic bool + }{ + { + name: "Admin adds a new token", + caller: adminAddr, + tokenPath: "newTokenPath", + expected: append(defaultAllowed, "newTokenPath"), + shouldPanic: false, + }, + { + name: "Non-admin tries to add a token", + caller: nonAdminAddr, + tokenPath: "unauthorizedToken", + expected: defaultAllowed, + shouldPanic: true, + }, + { + name: "Admin adds an existing token", + caller: adminAddr, + tokenPath: consts.GNS_PATH, + expected: defaultAllowed, + shouldPanic: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetOrigCaller(tt.caller) + + defer func() { + if r := recover(); r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } + }() + + AddToken(tt.tokenPath) + + if !tt.shouldPanic { + for _, token := range tt.expected { + if !contains(allowedTokens, token) { + t.Errorf("Expected token %s not found in allowedTokens", token) + } + } + } + }) + } +} + +func TestRemoveToken(t *testing.T) { + adminAddr := consts.ADMIN + nonAdminAddr := testutils.TestAddress("nonadmin") + std.TestSetOrigCaller(adminAddr) + + tests := []struct { + name string + caller std.Address + tokenPath string + expected []string + shouldPanic bool + }{ + { + name: "Admin removes a custom token", + caller: adminAddr, + tokenPath: "customToken", + expected: defaultAllowed, + shouldPanic: false, + }, + { + name: "Non-admin tries to remove a token", + caller: nonAdminAddr, + tokenPath: "unauthorizedToken", + expected: allowedTokens, + shouldPanic: true, + }, + { + name: "Admin tries to remove a default token", + caller: adminAddr, + tokenPath: consts.GNOT, + expected: allowedTokens, + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + std.TestSetOrigCaller(tt.caller) + + defer func() { + if r := recover(); r != nil { + if !tt.shouldPanic { + t.Errorf("Unexpected panic: %v", r) + } + } + }() + + RemoveToken(tt.tokenPath) + + if !tt.shouldPanic { + for _, token := range tt.expected { + if !contains(allowedTokens, token) { + t.Errorf("Expected token %s not found in allowedTokens", token) + } + } + } + }) + } +} +*/ diff --git a/contract/r/gnoswap/staker/filetests/average_block_time_change_from_gns_filetest.gno b/contract/r/gnoswap/staker/filetests/average_block_time_change_from_gns_filetest.gno new file mode 100644 index 000000000..67057774d --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/average_block_time_change_from_gns_filetest.gno @@ -0,0 +1,205 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 + +// POSITIONs: +// 1. in-range ( will be unstaked ) + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) +// - halving changes + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/testutils" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + // g1v4u8getjdeskcsmjv4shgmmjta047h6lua7mup + externalCreatorAddr = testutils.TestAddress("externalCreator") + externalCreatorUser = externalCreatorAddr + externalCreatorRealm = std.NewUserRealm(externalCreatorAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 + + poolPath = "gno.land/r/demo/wugnot:gno.land/r/gnoswap/v1/gns:3000" +) + +func main() { + testInit() + testCreatePool() + testMintWugnotGnsPos01() + testStakeTokenPos01() + testGnsAvgBlockTime() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) + + // set unstaking fee to 0 + sr.SetUnstakingFeeByAdmin(0) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testStakeTokenPos01() { + std.TestSetRealm(adminRealm) + + // stake + gnft.Approve(consts.STAKER_ADDR, tokenIdFrom(1)) + sr.StakeToken(1) + + before := gns.BalanceOf(adminUser) + + std.TestSkipHeights(100) + sr.CollectReward(1, false) + + after := gns.BalanceOf(adminUser) + + diff := after - before + if diff <= 0 { + panic("reward can not be 0 for 100 blocks") + } +} + +func testGnsAvgBlockTime() { + std.TestSetRealm(adminRealm) + + before := gns.BalanceOf(adminUser) + std.TestSkipHeights(1) + sr.CollectReward(1, false) + after := gns.BalanceOf(adminUser) + + diff := after - before // actual reward for 1 block when avg block time was 2000ms(2s) + if !isInErrorRange(3210616, diff) { + panic("expected about 3210616") + } + + gns.SetAvgBlockTimeInMsByAdmin(4000) // orig block time was 2000ms(2s), change it to 4s + sr.CollectReward(1, false) // clear reward + + before = gns.BalanceOf(adminUser) + std.TestSkipHeights(1) + sr.CollectReward(1, false) // actual reward for 1 block when avg block time was 4000ms(4s) + after = gns.BalanceOf(adminUser) + diff = after - before + if !isInErrorRange(6421232, diff) { + panic("expected about 6421232") // block time has doubled, so reward should be doubled too + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} + +func isInErrorRange(expected uint64, actual uint64) bool { + lowerBound := expected * 999999 / 1000000 + upperBound := expected * 1000001 / 1000000 + return actual >= lowerBound && actual <= upperBound +} diff --git a/contract/r/gnoswap/staker/filetests/no_position_to_give_reward_filetest.gno b/contract/r/gnoswap/staker/filetests/no_position_to_give_reward_filetest.gno new file mode 100644 index 000000000..533873367 --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/no_position_to_give_reward_filetest.gno @@ -0,0 +1,176 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 + +// POSITIONs: +// 1. in-range ( will be unstaked ) + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) +// - external bar ( gnot:gns:3000 ) + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/testutils" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + // g1v4u8getjdeskcsmjv4shgmmjta047h6lua7mup + externalCreatorAddr = testutils.TestAddress("externalCreator") + externalCreatorUser = externalCreatorAddr + externalCreatorRealm = std.NewUserRealm(externalCreatorAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 + + poolPath = "gno.land/r/demo/wugnot:gno.land/r/gnoswap/v1/gns:3000" +) + +func main() { + testInit() + testCreatePool() + testMintWugnotGnsPos01() + testStakeTokenPos01() + + testUnstakePos01() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testStakeTokenPos01() { + std.TestSetRealm(adminRealm) + + // stake + gnft.Approve(consts.STAKER_ADDR, tokenIdFrom(1)) + sr.StakeToken(1) +} + +func testUnstakePos01() { + std.TestSetRealm(adminRealm) + + oldCommunityGns := gns.BalanceOf(consts.COMMUNITY_POOL_ADDR) + + // this position-01 has been staked and staked same time + sr.UnstakeToken(1, false) + + afterCommunityGns := gns.BalanceOf(consts.COMMUNITY_POOL_ADDR) + + // therefore community pool should receive no-position-staked pool's reward + increased := afterCommunityGns - oldCommunityGns + if increased != 10702054 { // one block's reward + panic("community pool should receive no-position-staked pool's reward") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} diff --git a/contract/r/gnoswap/staker/filetests/pool_add_to_tier2_and_change_to_tier3_internal_filetest.gno b/contract/r/gnoswap/staker/filetests/pool_add_to_tier2_and_change_to_tier3_internal_filetest.gno new file mode 100644 index 000000000..4f822de24 --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/pool_add_to_tier2_and_change_to_tier3_internal_filetest.gno @@ -0,0 +1,246 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 +// 2. bar:baz:100 + +// POSITIONs: +// 1. in-range ( gnot:gns:3000 ) +// 2. in-range ( bar:baz:100 ) + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) +// - internal tier 2 ( bar:baz:100 ) -> will be changed to tier 3 + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 +) + +func main() { + testInit() + testCreatePool() + + testMintAndStakeWugnotGnsPos01() + + testSetPoolTier2() // new pool is set to tier 2 + testMintAndStakeBarBazPos02() + + testChangePoolTier3() // pool is changed to tier 3 +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) + + // set unstaking fee to 0 + sr.SetUnstakingFeeByAdmin(0) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) + pl.CreatePool( + barPath, + bazPath, + fee100, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintAndStakeWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + sr.StakeToken(1) + + std.TestSkipHeights(1) + { + // check reward for position 01 (gnot:gns:3000 tier 01) + std.TestSetRealm(adminRealm) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + if !isInErrorRange(3210615, diff) { + panic("expected about 3210615") // 10702054 * 100%(tier1 ratio) * 30%(warmUp) + } + } +} + +func testSetPoolTier2() { + std.TestSetRealm(adminRealm) + + std.TestSkipHeights(1) + sr.SetPoolTierByAdmin("gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:100", 2) +} + +func testMintAndStakeBarBazPos02() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + baz.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + bazPath, + fee100, + int32(-50), + int32(50), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(2)) + sr.StakeToken(2) + + std.TestSkipHeights(1) + { + // check reward for position 02 (bar:baz:100 tier 02) + std.TestSetRealm(adminRealm) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(2, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + if !isInErrorRange(963184, diff) { // 10702054 * 30%(tier2 ratio) * 30% (warmUp) + panic("expected about 963184") + } + } +} + +func testChangePoolTier3() { + std.TestSetRealm(adminRealm) + + sr.ChangePoolTierByAdmin("gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:100", 3) + + std.TestSkipHeights(1) + { + // check reward for position 02 (bar:baz:100 tier 02) + std.TestSetRealm(adminRealm) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(2, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + if !isInErrorRange(642122, diff) { // 10702054 * 20%(tier3 ratio) * 30% (warmUp) + panic("expected about 642122") + } + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} + +func isInErrorRange(expected uint64, actual uint64) bool { + lowerBound := expected * 999999 / 1000000 + upperBound := expected * 1000001 / 1000000 + return actual >= lowerBound && actual <= upperBound +} diff --git a/contract/r/gnoswap/staker/filetests/pool_add_to_tier2_and_removed_internal_filetest.gno b/contract/r/gnoswap/staker/filetests/pool_add_to_tier2_and_removed_internal_filetest.gno new file mode 100644 index 000000000..495d35f26 --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/pool_add_to_tier2_and_removed_internal_filetest.gno @@ -0,0 +1,266 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 +// 2. bar:baz:100 + +// POSITIONs: +// 1. in-range ( gnot:gns:3000 ) +// 2. in-range ( bar:baz:100 ) + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) +// - internal tier 2 ( bar:baz:100 ) -> will be removed from internal tiers + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 +) + +func main() { + testInit() + testCreatePool() + + testMintAndStakeWugnotGnsPos01() + + testSetPoolTier2() // new pool is set to tier 2 + testMintAndStakeBarBazPos02() + + testRemovePoolTier() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) + + // set unstaking fee to 0 + sr.SetUnstakingFeeByAdmin(0) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) + pl.CreatePool( + barPath, + bazPath, + fee100, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintAndStakeWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + sr.StakeToken(1) + + // check reward for position 01 (gnot:gns:3000 tier 01) + std.TestSetRealm(adminRealm) + std.TestSkipHeights(1) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + if !isInErrorRange(3210615, diff) { + panic("expected about 3210615") // 10702054 * 100%(tier1 ratio) * 30%(warmUp) + } +} + +func testSetPoolTier2() { + std.TestSetRealm(adminRealm) + + sr.SetPoolTierByAdmin("gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:100", 2) +} + +func testMintAndStakeBarBazPos02() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + baz.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + pn.Mint( + barPath, + bazPath, + fee100, + int32(-50), + int32(50), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(2)) + sr.StakeToken(2) + std.TestSkipHeights(1) + + { + // check reward for position 01 (gnot:gns:3000 tier 01) + std.TestSetRealm(adminRealm) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + if !isInErrorRange(2247430, diff) { + panic("expected about 2247430") // 10702054 * 70%(tier1 ratio) * 30%(warmUp) + } + } + + { + // check reward for position 02 (bar:baz:100 tier 02) + std.TestSetRealm(adminRealm) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(2, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + if !isInErrorRange(963184, diff) { + panic("expected about 963184") // 10702054 * 30%(tier2 ratio) * 30% (warmUp) + } + } +} + +func testRemovePoolTier() { + std.TestSetRealm(adminRealm) + + sr.RemovePoolTierByAdmin("gno.land/r/onbloc/bar:gno.land/r/onbloc/baz:100") + std.TestSkipHeights(1) + + { + // check reward for position 01 (gnot:gns:3000 tier 01) + std.TestSetRealm(adminRealm) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + if !isInErrorRange(3210615, diff) { + panic("expected about 3210615") // 10702054 * 100%(tier1 ratio) * 30%(warmUp) + } + } + + { + // check reward for position 02 (bar:baz:100 tier 02) + std.TestSetRealm(adminRealm) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(2, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + if diff != 0 { + panic("expected about 0") // bar:baz:100 is removed from internal tiers + } + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} + +func isInErrorRange(expected uint64, actual uint64) bool { + lowerBound := expected * 999999 / 1000000 + upperBound := expected * 1000001 / 1000000 + return actual >= lowerBound && actual <= upperBound +} diff --git a/contract/r/gnoswap/staker/filetests/position_inrange_change_by_swap_external_filetest.gno b/contract/r/gnoswap/staker/filetests/position_inrange_change_by_swap_external_filetest.gno new file mode 100644 index 000000000..a00a1ee2c --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/position_inrange_change_by_swap_external_filetest.gno @@ -0,0 +1,307 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. bar:qux:100 + +// POSITIONs: +// 1. in-range -> out-range -> in-range +// 2. (always) in-range + +// REWARDs: +// - external bar ( bar:qux:100 ) + +package staker_test + +import ( + "std" + "strconv" + "time" + + "gno.land/p/demo/grc/grc721" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + _ "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/qux" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + rr "gno.land/r/gnoswap/v1/router" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 + + poolPath = "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100" +) + +func main() { + testInit() + testCreatePool() + testMintBarQuxPos01() + testMintBarQuxPos02() + testCreateExternalIncentive() + testStakeTokenPos01AndPos02() + testMakeExternalBarStart() + + testCheckReward01() // both positions are in-range + + testMakePosition1OutRangeBySwap() + testCheckReward02() // position-01 is out-range + + testMakePosition1InRangeBySwap() // position-01 is in-range again + testCheckReward03() // both positions are in-range +} + +func testInit() { + std.TestSetRealm(adminRealm) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + barPath, + quxPath, + fee100, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintBarQuxPos01() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + quxPath, + fee100, + int32(-50), + int32(50), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testMintBarQuxPos02() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + quxPath, + fee100, + int32(-1000), + int32(1000), + "500000", + "500000", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testCreateExternalIncentive() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.STAKER_ADDR, consts.UINT64_MAX) + gns.Approve(consts.STAKER_ADDR, depositGnsAmount) + + std.TestSkipHeights(1) + sr.CreateExternalIncentive( + "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100", + barPath, + 9000000000, + 1234569600, + 1234569600+TIMESTAMP_90DAYS, + ) +} + +func testStakeTokenPos01AndPos02() { + std.TestSetRealm(adminRealm) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + gnft.Approve(stakerAddr, tokenIdFrom(2)) + + std.TestSkipHeights(1) + sr.StakeToken(1) + sr.StakeToken(2) +} + +func testMakeExternalBarStart() { + externalStartTime := int64(1234569600) + nowTime := time.Now().Unix() + timeLeft := externalStartTime - nowTime + + blockAvgTime := consts.BLOCK_GENERATION_INTERVAL + blockLeft := timeLeft / blockAvgTime + + std.TestSkipHeights(int64(blockLeft)) // skip until external bar starts + std.TestSkipHeights(10) // skip bit more to see reward calculation + +} + +func testCheckReward01() { + std.TestSetRealm(adminRealm) + + oldBar := bar.BalanceOf(adminUser) + sr.CollectReward(1, false) + newBar := bar.BalanceOf(adminUser) + diff := newBar - oldBar + + if diff == 0 { + panic("position 01 is in-range, should have reward") + } +} + +func testMakePosition1OutRangeBySwap() { + std.TestSetRealm(adminRealm) + + poolTick := pl.PoolGetSlot0Tick(poolPath) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + bar.Approve(consts.ROUTER_ADDR, consts.UINT64_MAX) + qux.Approve(consts.ROUTER_ADDR, consts.UINT64_MAX) + + tokenIn, tokenOut := rr.ExactInSwapRoute( + barPath, // inputToken + quxPath, // outputToken + "100000", // finalAmountIn + poolPath, // RouteArr + "100", // quoteArr + "0", // amountOutMin + max_timeout, // timeout + ) + std.TestSkipHeights(1) + + newPoolTick := pl.PoolGetSlot0Tick(poolPath) + println("oldPoolTick", poolTick) + println("newPoolTick", newPoolTick) + println() +} + +func testCheckReward02() { + std.TestSetRealm(adminRealm) + + oldBar := bar.BalanceOf(adminUser) + sr.CollectReward(1, false) + newBar := bar.BalanceOf(adminUser) + diff := newBar - oldBar + + if diff != 0 { + panic("position 01 is out-range, should not have reward") + } +} + +func testMakePosition1InRangeBySwap() { + std.TestSetRealm(adminRealm) + + poolTick := pl.PoolGetSlot0Tick(poolPath) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + bar.Approve(consts.ROUTER_ADDR, consts.UINT64_MAX) + qux.Approve(consts.ROUTER_ADDR, consts.UINT64_MAX) + + tokenIn, tokenOut := rr.ExactInSwapRoute( + quxPath, + barPath, + "100000", + "gno.land/r/onbloc/qux:gno.land/r/onbloc/bar:100", + "100", + "0", + max_timeout, + ) + + newPoolTick := pl.PoolGetSlot0Tick(poolPath) + println("oldPoolTick", poolTick) + println("newPoolTick", newPoolTick) + println() +} + +func testCheckReward03() { + std.TestSkipHeights(1) + + std.TestSetRealm(adminRealm) + + oldBar := bar.BalanceOf(adminUser) + sr.CollectReward(1, false) + newBar := bar.BalanceOf(adminUser) + diff := newBar - oldBar + + if diff <= 0 { + panic("position 01 is in-range, should have reward") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} diff --git a/contract/r/gnoswap/staker/filetests/position_inrange_change_by_swap_internal_filetest.gno b/contract/r/gnoswap/staker/filetests/position_inrange_change_by_swap_internal_filetest.gno new file mode 100644 index 000000000..89c206044 --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/position_inrange_change_by_swap_internal_filetest.gno @@ -0,0 +1,284 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 + +// POSITIONs: +// 1. in-range -> out-range -> in-range +// 2. (always) in-range + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + rr "gno.land/r/gnoswap/v1/router" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 + + poolPath = "gno.land/r/demo/wugnot:gno.land/r/gnoswap/v1/gns:3000" +) + +func main() { + testInit() + testCreatePool() + testMintWugnotGnsPos01() + testMintWugnotGnsPos02() + + testStakeTokenPos01AndPos02() + + testCheckReward01() // both positions are in-range + + testMakePosition1OutRangeBySwap() + testCheckReward02() // position-01 is out-range + + testMakePosition1InRangeBySwap() // position-01 is in-range again + testCheckReward03() // both positions are in-range +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testMintWugnotGnsPos02() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-1020), + int32(1020), + "500000", + "500000", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testStakeTokenPos01AndPos02() { + std.TestSetRealm(adminRealm) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + gnft.Approve(stakerAddr, tokenIdFrom(2)) + + std.TestSkipHeights(1) + sr.StakeToken(1) + sr.StakeToken(2) +} + +func testCheckReward01() { + std.TestSkipHeights(1) + std.TestSetRealm(adminRealm) + + oldGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + newGns := gns.BalanceOf(adminUser) + diff := newGns - oldGns + + if diff == 0 { + panic("position 01 is in-range, should have reward") + } +} + +func testMakePosition1OutRangeBySwap() { + std.TestSetRealm(adminRealm) + + poolTick := pl.PoolGetSlot0Tick(poolPath) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + wugnot.Approve(consts.ROUTER_ADDR, consts.UINT64_MAX) + gns.Approve(consts.ROUTER_ADDR, consts.UINT64_MAX) + + tokenIn, tokenOut := rr.ExactInSwapRoute( + wugnotPath, // inputToken + gnsPath, // outputToken + "100000", // finalAmountIn + poolPath, // RouteArr + "100", // quoteArr + "0", // amountOutMin + max_timeout, + ) + + newPoolTick := pl.PoolGetSlot0Tick(poolPath) + println("oldPoolTick", poolTick) + println("newPoolTick", newPoolTick) + println() +} + +func testCheckReward02() { + std.TestSkipHeights(1) + std.TestSetRealm(adminRealm) + + oldGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + newGns := gns.BalanceOf(adminUser) + diff := newGns - oldGns + + if diff != 0 { + panic("position 01 is out-range, should not have reward") + } +} + +func testMakePosition1InRangeBySwap() { + std.TestSetRealm(adminRealm) + + poolTick := pl.PoolGetSlot0Tick(poolPath) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + wugnot.Approve(consts.ROUTER_ADDR, consts.UINT64_MAX) + gns.Approve(consts.ROUTER_ADDR, consts.UINT64_MAX) + + tokenIn, tokenOut := rr.ExactInSwapRoute( + gnsPath, + wugnotPath, + "100000", + "gno.land/r/gnoswap/v1/gns:gno.land/r/demo/wugnot:3000", + "100", + "0", + max_timeout, + ) + + newPoolTick := pl.PoolGetSlot0Tick(poolPath) + println("oldPoolTick", poolTick) + println("newPoolTick", newPoolTick) + println() +} + +func testCheckReward03() { + std.TestSkipHeights(1) + std.TestSetRealm(adminRealm) + + oldGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + newGns := gns.BalanceOf(adminUser) + diff := newGns - oldGns + + if diff == 0 { + panic("position 01 is in-range, should have reward") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} diff --git a/contract/r/gnoswap/staker/filetests/reward_for_user_collect_change_by_collecting_reward_external_filetest.gno b/contract/r/gnoswap/staker/filetests/reward_for_user_collect_change_by_collecting_reward_external_filetest.gno new file mode 100644 index 000000000..bb0856dde --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/reward_for_user_collect_change_by_collecting_reward_external_filetest.gno @@ -0,0 +1,261 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 + +// POSITIONs: +// 1. in-range +// 2. in-range + +// REWARDs: +// - external bar ( bar:qux:100 ) +// - external qux ( qux:gns:100 ) + +package staker_test + +import ( + "std" + "strconv" + "time" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/json" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/qux" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 +) + +func main() { + testInit() + testCreatePool() + + testCreateExternalIncentiveBar() + testCreateExternalIncentiveQux() + + testMintAndStakeBarQuxPos01() + testMintAndStakeBarQuxPos02() + + testMakeExternalBarAndQuxStart() + + testCollectRewardPos01() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + barPath, + quxPath, + fee100, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testCreateExternalIncentiveBar() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.STAKER_ADDR, consts.UINT64_MAX) + gns.Approve(consts.STAKER_ADDR, depositGnsAmount) + + std.TestSkipHeights(1) + sr.CreateExternalIncentive( + "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100", + barPath, + 9000000000, + 1234569600, + 1234569600+TIMESTAMP_90DAYS, + ) +} + +func testCreateExternalIncentiveQux() { + std.TestSetRealm(adminRealm) + + qux.Approve(consts.STAKER_ADDR, consts.UINT64_MAX) + gns.Approve(consts.STAKER_ADDR, depositGnsAmount) + + std.TestSkipHeights(1) + sr.CreateExternalIncentive( + "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100", + quxPath, + 18000000000, + 1234569600, + 1234569600+TIMESTAMP_90DAYS, + ) +} + +func testMintAndStakeBarQuxPos01() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + quxPath, + fee100, + int32(-50), + int32(50), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + sr.StakeToken(1) +} + +func testMintAndStakeBarQuxPos02() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + quxPath, + fee100, + int32(-50), + int32(50), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(2)) + sr.StakeToken(2) +} + +func testMakeExternalBarAndQuxStart() { + externalStartTime := int64(1234569600) + nowTime := time.Now().Unix() + timeLeft := externalStartTime - nowTime + + blockAvgTime := consts.BLOCK_GENERATION_INTERVAL + blockLeft := timeLeft / blockAvgTime + + std.TestSkipHeights(int64(blockLeft)) // skip until external bar starts + std.TestSkipHeights(10) // skip bit more to see reward calculation + + if extractReward(1) != extractReward(2) { + panic("reward is not equal") + } +} + +func testCollectRewardPos01() { + std.TestSetRealm(adminRealm) + + std.TestSkipHeights(10) + sr.CollectReward(1, false) + + reward01 := extractReward(1) + if reward01 != 0 { + panic("reward collected, should be 0 for checking") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} + +func extractReward(tokenId uint64) uint64 { + apiReward := sr.ApiGetRewardsByLpTokenId(tokenId) + rawReward, _ := json.Unmarshal([]byte(apiReward)) + rawRewardObject, _ := rawReward.GetKey("response") + arrReward, _ := rawRewardObject.GetArray() + + reward, _ := arrReward[0].GetKey("rewards") + rewardArr, _ := reward.GetArray() + if len(rewardArr) == 0 { + return 0 + } + rewardTokenAmount, _ := rewardArr[0].GetKey("rewardTokenAmount") + + rewardTokenAmountInt, _ := strconv.ParseUint(rewardTokenAmount.String(), 10, 64) + return rewardTokenAmountInt +} diff --git a/contract/r/gnoswap/staker/filetests/reward_for_user_collect_change_by_collecting_reward_internal_filetest.gno b/contract/r/gnoswap/staker/filetests/reward_for_user_collect_change_by_collecting_reward_internal_filetest.gno new file mode 100644 index 000000000..4b93e19e3 --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/reward_for_user_collect_change_by_collecting_reward_internal_filetest.gno @@ -0,0 +1,214 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 + +// POSITIONs: +// 1. in-range +// 2. in-range (will be unstaked) + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/json" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 +) + +func main() { + testInit() + testCreatePool() + + testMintAndStakeWugnotGnsPos01() // position-01 is in-range + testMintAndStakeWugnotGnsPos02() // position-02 is in-range + + testCollectRewardPos01() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // short warm-up period + sr.SetWarmUp(100, 901) + sr.SetWarmUp(70, 301) + sr.SetWarmUp(50, 151) + sr.SetWarmUp(30, 1) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintAndStakeWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + sr.StakeToken(1) +} + +func testMintAndStakeWugnotGnsPos02() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(2)) + sr.StakeToken(2) +} + +func testCollectRewardPos01() { + std.TestSetRealm(adminRealm) + std.TestSkipHeights(1) + + apiRewardBefore := extractReward(1) + sr.CollectReward(1, false) + apiRewardAfter := extractReward(1) + + if apiRewardBefore == apiRewardAfter { + panic("can not be same") + } + + if apiRewardAfter != 0 { + panic("reward should be 0") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} + +func extractReward(tokenId uint64) uint64 { + apiReward := sr.ApiGetRewardsByLpTokenId(tokenId) + rawReward, _ := json.Unmarshal([]byte(apiReward)) + rawRewardObject, _ := rawReward.GetKey("response") + arrReward, _ := rawRewardObject.GetArray() + + reward, _ := arrReward[0].GetKey("rewards") + rewardArr, _ := reward.GetArray() + if len(rewardArr) == 0 { + return 0 + } + rewardTokenAmount, _ := rewardArr[0].GetKey("rewardTokenAmount") + + rewardTokenAmountInt, _ := strconv.ParseUint(rewardTokenAmount.String(), 10, 64) + return rewardTokenAmountInt +} diff --git a/contract/r/gnoswap/staker/filetests/single_gns_external_ends_filetest.gno b/contract/r/gnoswap/staker/filetests/single_gns_external_ends_filetest.gno new file mode 100644 index 000000000..bb4294f9e --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/single_gns_external_ends_filetest.gno @@ -0,0 +1,272 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. bar:qux:100 + +// POSITIONs: +// 1. in-range + +// REWARDs: +// - external gns 90 days ( bar:qux:100 ) + +package staker_test + +import ( + "std" + "strconv" + "time" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/testutils" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/qux" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + // g1v4u8getjdeskcsmjv4shgmmjta047h6lua7mup + externalCreatorAddr = testutils.TestAddress("externalCreator") + externalCreatorUser = externalCreatorAddr + externalCreatorRealm = std.NewUserRealm(externalCreatorAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 + + poolPath = "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100" +) + +func main() { + testInit() + testCreatePool() + testMintBarQuxPos01() + testMintBarQuxPos02() + + testCreateExternalIncentiveGns() + testMakeExternalStart() + + testStakeTokenPos01() + testStakeTokenPos02() + + testCollectRewardPos01AndPos02() + + testEndExternalGns() + testCollectRewardPos01AndPos02AfterEnd() +} + +func testInit() { + std.TestSetRealm(adminRealm) + sr.SetUnstakingFeeByAdmin(0) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + barPath, + quxPath, + fee100, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintBarQuxPos01() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + quxPath, + fee100, + int32(-50), + int32(50), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testMintBarQuxPos02() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + quxPath, + fee100, + int32(-50), + int32(50), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testCreateExternalIncentiveGns() { + std.TestSetRealm(adminRealm) + gns.Transfer(externalCreatorUser, 9000000000) + gns.Transfer(externalCreatorUser, depositGnsAmount) + + std.TestSetRealm(externalCreatorRealm) // creator + gns.Approve(consts.STAKER_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + sr.CreateExternalIncentive( + "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100", + gnsPath, + 9000000000, + 1234569600, + 1234569600+TIMESTAMP_90DAYS, + ) +} + +func testMakeExternalStart() { + externalStartTime := int64(1234569600) + nowTime := time.Now().Unix() + timeLeft := externalStartTime - nowTime + + blockAvgTime := consts.BLOCK_GENERATION_INTERVAL + blockLeft := timeLeft / blockAvgTime + + std.TestSkipHeights(int64(blockLeft)) // skip until external bar starts + std.TestSkipHeights(10) // skip bit more to see reward calculation +} + +func testStakeTokenPos01() { + std.TestSetRealm(adminRealm) + + std.TestSkipHeights(1) + gnft.Approve(stakerAddr, tokenIdFrom(1)) + sr.StakeToken(1) +} + +func testStakeTokenPos02() { + std.TestSetRealm(adminRealm) + + std.TestSkipHeights(1) + gnft.Approve(stakerAddr, tokenIdFrom(2)) + sr.StakeToken(2) +} + +func testCollectRewardPos01AndPos02() { + std.TestSetRealm(adminRealm) + + std.TestSkipHeights(1) + sr.CollectReward(1, false) + sr.CollectReward(2, false) +} + +func testEndExternalGns() { + externalEndTime := (1234569600 + TIMESTAMP_90DAYS) + nowTime := time.Now().Unix() + timeLeft := externalEndTime - nowTime + + blockAvgTime := consts.BLOCK_GENERATION_INTERVAL + blockLeft := timeLeft / blockAvgTime + + std.TestSkipHeights(int64(blockLeft)) // skip until external gns ends + std.TestSkipHeights(10) // skip bit more to see reward calculation + + std.TestSetRealm(externalCreatorRealm) + gnsBalanceBeforeEnds := gns.BalanceOf(externalCreatorUser) + sr.EndExternalIncentive( + externalCreatorAddr, + "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100", + gnsPath, + 1234569600, + 1234569600+TIMESTAMP_90DAYS, + 126, + ) + gnsBalanceAfterEnds := gns.BalanceOf(externalCreatorUser) + if gnsBalanceAfterEnds != 1000028693 { + // 1000000000 = depositGnsAmount + // 28693 = penalty + panic("expected 1000028693") + } +} + +func testCollectRewardPos01AndPos02AfterEnd() { + std.TestSetRealm(adminRealm) + std.TestSkipHeights(1) + + before := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + sr.CollectReward(2, false) + after := gns.BalanceOf(adminUser) + diff := after - before + if diff == 0 { + panic("reward can not be zero") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} diff --git a/contract/r/gnoswap/staker/filetests/single_position_stake_unstake_restake_filetest.gno b/contract/r/gnoswap/staker/filetests/single_position_stake_unstake_restake_filetest.gno new file mode 100644 index 000000000..3e0ed2102 --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/single_position_stake_unstake_restake_filetest.gno @@ -0,0 +1,222 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 + +// POSITIONs: +// 1. in-range ( stake -> unstake -> restake) + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/testutils" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + // g1v4u8getjdeskcsmjv4shgmmjta047h6lua7mup + externalCreatorAddr = testutils.TestAddress("externalCreator") + externalCreatorUser = externalCreatorAddr + externalCreatorRealm = std.NewUserRealm(externalCreatorAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 + + poolPath = "gno.land/r/demo/wugnot:gno.land/r/gnoswap/v1/gns:3000" +) + +func main() { + testInit() + testCreatePool() + testMintWugnotGnsPos01() + testMintWugnotGnsPos02() + + testStakeTokenPos01AndPos02() + + testUnstakeTokenPos01() + + testStakeTokenPos01Again() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testMintWugnotGnsPos02() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-1020), + int32(1020), + "500000", + "500000", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testStakeTokenPos01AndPos02() { + std.TestSetRealm(adminRealm) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + gnft.Approve(stakerAddr, tokenIdFrom(2)) + + std.TestSkipHeights(1) + sr.StakeToken(1) + sr.StakeToken(2) +} + +func testUnstakeTokenPos01() { + std.TestSetRealm(adminRealm) + + std.TestSkipHeights(1) + + beforeGns := gns.BalanceOf(adminUser) + sr.UnstakeToken(1, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + + if diff == 0 { + panic("position 01 was in-range, should have reward") + } +} + +func testStakeTokenPos01Again() { + std.TestSetRealm(adminRealm) + + std.TestSkipHeights(1) + gnft.Approve(stakerAddr, tokenIdFrom(1)) + sr.StakeToken(1) + + // reward check for staked > unstaked > staked position + std.TestSkipHeights(1) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + + if diff == 0 { + panic("position 01 in-range, should have reward") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} diff --git a/contract/r/gnoswap/staker/filetests/single_position_stake_unstake_same_block_filetest.gno b/contract/r/gnoswap/staker/filetests/single_position_stake_unstake_same_block_filetest.gno new file mode 100644 index 000000000..392dee663 --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/single_position_stake_unstake_same_block_filetest.gno @@ -0,0 +1,166 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 + +// POSITIONs: +// 1. in-range ( stake -> unstake in same block) + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/testutils" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + // g1v4u8getjdeskcsmjv4shgmmjta047h6lua7mup + externalCreatorAddr = testutils.TestAddress("externalCreator") + externalCreatorUser = externalCreatorAddr + externalCreatorRealm = std.NewUserRealm(externalCreatorAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 + + poolPath = "gno.land/r/demo/wugnot:gno.land/r/gnoswap/v1/gns:3000" +) + +func main() { + testInit() + testCreatePool() + testMintWugnotGnsPos01() + testStakeAndUnstakeTokenPos01() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testStakeAndUnstakeTokenPos01() { + std.TestSetRealm(adminRealm) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + + beforeGns := gns.BalanceOf(adminUser) + std.TestSkipHeights(1) + sr.StakeToken(1) + sr.UnstakeToken(1, false) + + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + + if diff != 0 { + panic("position 01 is staked and unstaked in same block, should not have reward") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} diff --git a/contract/r/gnoswap/staker/filetests/staked_liquidity_change_by_staking_unstaking_external_filetest.gno b/contract/r/gnoswap/staker/filetests/staked_liquidity_change_by_staking_unstaking_external_filetest.gno new file mode 100644 index 000000000..63fdfaa10 --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/staked_liquidity_change_by_staking_unstaking_external_filetest.gno @@ -0,0 +1,262 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. bar:qux:100 + +// POSITIONs: +// 1. in-range +// 2. (will be unstaked) in-range + +// REWARDs: +// - external bar ( bar:qux:100 ) + +package staker_test + +import ( + "std" + "strconv" + "time" + + "gno.land/p/demo/grc/grc721" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + _ "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/qux" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 + + poolPath = "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100" +) + +func main() { + testInit() + testCreatePool() + testMintBarQuxPos01() + testMintBarQuxPos02() + + testCreateExternalIncentive() + + testStakeTokenPos01() + testStakeTokenPos02() + testMakeExternalBarStart() + + testCollectReward() + testUnstakeTokenPos02() + testCollectRewardAfterUnstake() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + sr.SetUnstakingFeeByAdmin(0) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + barPath, + quxPath, + fee100, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintBarQuxPos01() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + quxPath, + fee100, + int32(-50), + int32(50), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testMintBarQuxPos02() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + qux.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + barPath, + quxPath, + fee100, + int32(-50), + int32(50), + "50", + "50", + "1", + "1", + max_timeout, + adminAddr, + adminAddr, + ) +} + +func testCreateExternalIncentive() { + std.TestSetRealm(adminRealm) + + bar.Approve(consts.STAKER_ADDR, consts.UINT64_MAX) + gns.Approve(consts.STAKER_ADDR, depositGnsAmount) + + std.TestSkipHeights(1) + sr.CreateExternalIncentive( + "gno.land/r/onbloc/bar:gno.land/r/onbloc/qux:100", + barPath, + 9000000000, + 1234569600, + 1234569600+TIMESTAMP_90DAYS, + ) +} + +func testStakeTokenPos01() { + std.TestSetRealm(adminRealm) + + std.TestSkipHeights(1) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + sr.StakeToken(1) +} + +func testStakeTokenPos02() { + std.TestSetRealm(adminRealm) + + gnft.Approve(stakerAddr, tokenIdFrom(2)) + sr.StakeToken(2) +} + +func testMakeExternalBarStart() { + externalStartTime := int64(1234569600) + nowTime := time.Now().Unix() + + timeLeft := externalStartTime - nowTime + + blockAvgTime := consts.BLOCK_GENERATION_INTERVAL + blockLeft := timeLeft / blockAvgTime + std.TestSkipHeights(int64(blockLeft)) // skip until external bar starts + std.TestSkipHeights(10) // skip bit more to see reward calculation + + // check reward for position 01 (in-range) + std.TestSetRealm(adminRealm) + beforeBar := bar.BalanceOf(adminUser) + sr.CollectReward(1, false) + afterBar := bar.BalanceOf(adminUser) + diff := afterBar - beforeBar + if diff == 0 { + panic("position 01 in-range, should have reward") + } +} + +func testCollectReward() { + std.TestSetRealm(adminRealm) + + oldBar := bar.BalanceOf(adminUser) + std.TestSkipHeights(1) + sr.CollectReward(1, false) + newBar := bar.BalanceOf(adminUser) + diff := newBar - oldBar + if !isInErrorRange(diff, 347) { // 2314(rewardPerBlock) * 50%(position01Ratio) * 30%(warmUp) + panic("expected about 347") + } +} + +func testUnstakeTokenPos02() { + std.TestSetRealm(adminRealm) + sr.UnstakeToken(2, false) // position 02 is unstaked, position 01 the only staked position +} + +func testCollectRewardAfterUnstake() { + std.TestSetRealm(adminRealm) + + oldBar := bar.BalanceOf(adminUser) + std.TestSkipHeights(1) + sr.CollectReward(1, false) + newBar := bar.BalanceOf(adminUser) + diff := newBar - oldBar + if !isInErrorRange(diff, 693) { // 2314(rewardPerBlock) * 30%(warmUp) + panic("expected about 693") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} + +func isInErrorRange(expected uint64, actual uint64) bool { + lowerBound := expected * 999999 / 1000000 + upperBound := expected * 1000001 / 1000000 + return actual >= lowerBound && actual <= upperBound +} diff --git a/contract/r/gnoswap/staker/filetests/staked_liquidity_change_by_staking_unstaking_internal_filetest.gno b/contract/r/gnoswap/staker/filetests/staked_liquidity_change_by_staking_unstaking_internal_filetest.gno new file mode 100644 index 000000000..877279b0e --- /dev/null +++ b/contract/r/gnoswap/staker/filetests/staked_liquidity_change_by_staking_unstaking_internal_filetest.gno @@ -0,0 +1,222 @@ +// PKGPATH: gno.land/r/gnoswap/v1/staker_test + +// POOLs: +// 1. gnot:gns:3000 + +// POSITIONs: +// 1. in-range +// 2. (will be untaked) in-range + +// REWARDs: +// - internal tier 1 ( gnot:gns:3000 ) + +package staker_test + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/demo/wugnot" + "gno.land/r/gnoswap/v1/gns" + + "gno.land/r/gnoswap/v1/gnft" + + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + sr "gno.land/r/gnoswap/v1/staker" +) + +var ( + adminAddr = consts.ADMIN + adminUser = adminAddr + adminRealm = std.NewUserRealm(adminAddr) + + stakerAddr = consts.STAKER_ADDR + stakerUser = stakerAddr + stakerRealm = std.NewCodeRealm(consts.STAKER_PATH) + + wugnotAddr = consts.WUGNOT_ADDR + + fooPath = "gno.land/r/onbloc/foo" + barPath = "gno.land/r/onbloc/bar" + bazPath = "gno.land/r/onbloc/baz" + quxPath = "gno.land/r/onbloc/qux" + oblPath = "gno.land/r/onbloc/obl" + + gnsPath = "gno.land/r/gnoswap/v1/gns" + wugnotPath = "gno.land/r/demo/wugnot" + + fee100 uint32 = 100 + fee500 uint32 = 500 + fee3000 uint32 = 3000 + + max_timeout int64 = 9999999999 + + // external incentive deposit fee + depositGnsAmount uint64 = 1_000_000_000 // 1_000 GNS + + TIMESTAMP_90DAYS int64 = 90 * 24 * 60 * 60 + TIMESTAMP_180DAYS int64 = 180 * 24 * 60 * 60 + TIMESTAMP_365DAYS int64 = 365 * 24 * 60 * 60 +) + +func main() { + testInit() + testCreatePool() + + testMintAndStakeWugnotGnsPos01() + testMintAndStakeWugnotGnsPos02() + + testCollectReward() + testUnstakeTokenPos02() + testCollectRewardAfterUnstake() +} + +func testInit() { + std.TestSetRealm(adminRealm) + + // prepare wugnot + std.TestIssueCoins(adminAddr, std.Coins{{"ugnot", 100_000_000_000_000}}) + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(adminAddr, wugnotAddr, std.Coins{{"ugnot", 50_000_000_000_000}}) + std.TestSetOrigSend(std.Coins{{"ugnot", 50_000_000_000_000}}, nil) + wugnot.Deposit() + std.TestSetOrigSend(nil, nil) + + sr.SetUnstakingFeeByAdmin(0) +} + +func testCreatePool() { + std.TestSetRealm(adminRealm) + + pl.SetPoolCreationFeeByAdmin(0) + + std.TestSkipHeights(1) + pl.CreatePool( + wugnotPath, + gnsPath, + fee3000, + common.TickMathGetSqrtRatioAtTick(0).ToString(), // 79228162514264337593543950337 + ) +} + +func testMintAndStakeWugnotGnsPos01() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(1)) + sr.StakeToken(1) +} + +func testMintAndStakeWugnotGnsPos02() { + std.TestSetRealm(adminRealm) + + wugnot.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + gns.Approve(consts.POOL_ADDR, consts.UINT64_MAX) + + std.TestSkipHeights(1) + pn.Mint( + wugnotPath, + gnsPath, + fee3000, + int32(-60), + int32(60), + "100", + "100", + "0", + "0", + max_timeout, + adminAddr, + adminAddr, + ) + + gnft.Approve(stakerAddr, tokenIdFrom(2)) + sr.StakeToken(2) + std.TestSkipHeights(1) +} + +func testCollectReward() { + std.TestSetRealm(adminRealm) + + // clear reward + sr.CollectReward(1, false) + sr.CollectReward(2, false) + + std.TestSkipHeights(1) + beforeGns := gns.BalanceOf(adminUser) + sr.CollectReward(1, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + + if !isInErrorRange(diff, 1605307) { // 10702054 * 50% * 30% + panic("expected about 1605307") + } +} + +func testUnstakeTokenPos02() { + std.TestSetRealm(adminRealm) + + sr.UnstakeToken(2, false) +} + +func testCollectRewardAfterUnstake() { + std.TestSetRealm(adminRealm) + + beforeGns := gns.BalanceOf(adminUser) + std.TestSkipHeights(1) + sr.CollectReward(1, false) + afterGns := gns.BalanceOf(adminUser) + diff := afterGns - beforeGns + + if !isInErrorRange(diff, 3210615) { // 10702054 * 30% + panic("expected about 3210615") + } +} + +func tokenIdFrom(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic("tokenId is nil") + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic("unsupported tokenId type") + } +} + +func isInErrorRange(expected uint64, actual uint64) bool { + lowerBound := expected * 999999 / 1000000 + upperBound := expected * 1000001 / 1000000 + return actual >= lowerBound && actual <= upperBound +} diff --git a/contract/r/gnoswap/staker/getter.gno b/contract/r/gnoswap/staker/getter.gno new file mode 100644 index 000000000..1bfbf58f2 --- /dev/null +++ b/contract/r/gnoswap/staker/getter.gno @@ -0,0 +1,716 @@ +package staker + +import ( + "std" + "strconv" + "time" + + "gno.land/p/demo/json" + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/gns" +) + +// GetPoolByPoolPath retrieves the pool's path using the provided pool path. +// +// Parameters: +// - poolPath (string): The path of the pool to retrieve. +// +// Returns: +// - string: The `poolPath` of the found pool. +// +// Panics: +// - If the pool corresponding to the given `poolPath` does not exist in the `pools` collection. +func GetPoolByPoolPath(poolPath string) *Pool { + pool, ok := pools.Get(poolPath) + if !ok { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("poolPath(%s) pool does not exist", poolPath)), + ) + } + + return pool +} + +// GetPoolIncentiveIdList returns the list of incentive IDs for a given pool +// +// Parameters: +// - poolPath (string): The path of the pool to get incentives for +// +// Returns: +// - A slice og incentive IDs associated with the pool +// +// Panics: +// - If the pool incentives do not exist for the given pool path +func GetPoolIncentiveIdList(poolPath string) []string { + pool := GetPoolByPoolPath(poolPath) + + ids := []string{} + pool.incentives.byTime.Iterate("", "", func(key string, value interface{}) bool { + ids = append(ids, key) + return true + }) + + return ids +} + +// GetIncentive retrieves an external incentive by its ID from a specified pool. +// +// This function searches for an incentive within the specified pool's incentives, +// identified by the given `incentiveId`. If the pool or the incentive does not exist, +// the function panics with an appropriate error message. +// +// Parameters: +// - poolPath (string): The path of the pool containing the incentive. +// - incentiveId (string): The ID of the incentive to retrieve. +// +// Returns: +// - *ExternalIncentive: A pointer to the retrieved ExternalIncentive. +// +// Panics: +// - If the pool corresponding to the given `poolPath` does not exist. +// - If the incentive with the specified `incentiveId` does not exist. +func GetIncentive(poolPath string, incentiveId string) *ExternalIncentive { + pool := GetPoolByPoolPath(poolPath) + + incentive, exist := pool.incentives.byTime.Get(incentiveId) + if !exist { + panic(ufmt.Sprintf("incentiveId(%s) incentive does not exist", incentiveId)) + } + + return incentive.(*ExternalIncentive) +} + +// GetIncentiveStartTimestamp returns the start timestamp for a given incentive +// +// Parameters: +// - incentiveId (string): The ID of the incentive +// +// Returns: +// - int64: The start timestamp of the incentive +// +// Panics: +// - If the incentive does not exist for the given incentiveId +func GetIncentiveStartTimestamp(poolPath string, incentiveId string) int64 { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.startTimestamp +} + +// GetIncentiveEndTimestamp returns the end timestamp for a given incentive +// +// Parameters: +// - incentiveId (string): The ID of the incentive +// +// Returns: +// - int64: The end timestamp of the incentive +// +// Panics: +// - If the incentive does not exist for the given incentiveId +func GetIncentiveEndTimestamp(poolPath string, incentiveId string) int64 { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.endTimestamp +} + +// GetTargetPoolPathByIncentiveId returns the target pool path for a given incentive +// +// Parameters: +// - incentiveId (string): The ID of the incentive +// +// Returns: +// - The target pool path (string) associated with the incentive +// +// Panics: +// - If the incentive does nor exist for the given incentive ID +func GetTargetPoolPathByIncentiveId(poolPath string, incentiveId string) string { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.targetPoolPath +} + +// GetCreatedHeightOfIncentive retrieves the creation height of a specified incentive. +// +// Parameters: +// - poolPath (string): The path of the pool containing the incentive. +// - incentiveId (string): The ID of the incentive to retrieve the creation height for. +// +// Returns: +// - int64: The creation height of the specified incentive. +func GetCreatedHeightOfIncentive(poolPath string, incentiveId string) int64 { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.createdHeight +} + +// GetIncentiveRewardToken returns the reward token for a given incentive +// +// Parameters: +// - incentiveId (string): The ID of the incentive +// +// Returns: +// - The reward token (string) associated with the incentive +// +// Panics: +// - If the incentive does not exist for the given incentiveId +func GetIncentiveRewardToken(poolPath string, incentiveId string) string { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.rewardToken +} + +// GetIncentiveRewardAmount returns the reward amount for a given incentive as a Uint256 +// +// Parameters: +// - incentiveId (string): The ID of the incentive +// +// Returns: +// - *u256.Uint: The reward amount associated with the incentive +// +// Panics: +// - If the incentive does not exist for the given incentiveId +func GetIncentiveRewardAmount(poolPath string, incentiveId string) *u256.Uint { + incentive := GetIncentive(poolPath, incentiveId) + + return u256.NewUint(incentive.rewardAmount) +} + +// GetIncentiveRewardAmountAsString returns the reward amount for a given incentive as a string +// +// Parameters: +// - incentiveId (string): The ID of the incentive +// +// Returns: +// - string: The reward amount associated with the incentive as a string +// +// Panics: +// - If the incentive does not exist for the given incentiveId +func GetIncentiveRewardAmountAsString(poolPath string, incentiveId string) string { + rewardAmount := GetIncentiveRewardAmount(incentiveId, poolPath) + + return rewardAmount.ToString() +} + +// GetIncentiveStartHeight retrieves the start height of a specified incentive. +// +// This function looks up an incentive within the specified pool using the `poolPath` and `incentiveId`. +// It then returns the height at which the incentive starts. +// +// Parameters: +// - poolPath (string): The path of the pool containing the incentive. +// - incentiveId (string): The ID of the incentive to retrieve the start height for. +// +// Returns: +// - int64: The starting height of the specified incentive. +func GetIncentiveStartHeight(poolPath string, incentiveId string) int64 { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.startHeight +} + +// GetIncentiveEndHeight retrieves the end height of a specified incentive. +// +// This function looks up an incentive within the specified pool using the `poolPath` and `incentiveId`. +// It then returns the height at which the incentive ends. +// +// Parameters: +// - poolPath (string): The path of the pool containing the incentive. +// - incentiveId (string): The ID of the incentive to retrieve the end height for. +// +// Returns: +// - int64: The ending height of the specified incentive. +func GetIncentiveEndHeight(poolPath string, incentiveId string) int64 { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.endHeight +} + +// GetIncentiveRewardPerBlock retrieves the reward per block of a specified incentive. +// +// This function looks up an incentive within the specified pool using the `poolPath` and `incentiveId`. +// It then returns the reward amount distributed per block for the incentive. +// +// Parameters: +// - poolPath (string): The path of the pool containing the incentive. +// - incentiveId (string): The ID of the incentive to retrieve the reward per block for. +// +// Returns: +// - uint64: The reward amount distributed per block for the specified incentive. +func GetIncentiveRewardPerBlock(poolPath string, incentiveId string) uint64 { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.rewardPerBlock +} + +// GetIncentiveRefundee returns the refundee address for a given incentive +// +// Parameters: +// - incentiveId (string): The ID of the incentive +// +// Returns: +// - std.Address: The refundee address of the incentive +// +// Panics: +// - If the incentive does not exist for the given incentiveId +func GetIncentiveRefundee(poolPath string, incentiveId string) std.Address { + incentive := GetIncentive(poolPath, incentiveId) + + return incentive.refundee +} + +// GetDeposit retrieves a deposit associated with a given LP token ID. +// +// This function looks up the deposit record for the specified LP token ID in the global `deposits` collection. +// If the deposit does not exist, it panics with an appropriate error message. +// +// Parameters: +// - lpTokenId (uint64): The unique identifier of the LP token for which the deposit is retrieved. +// +// Returns: +// - *Deposit: A pointer to the `Deposit` object associated with the given LP token ID. +func GetDeposit(lpTokenId uint64) *Deposit { + deposit := deposits.Get(lpTokenId) + if deposit == nil { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("lpTokenId(%d) deposit does not exist", lpTokenId)), + ) + } + + return deposit +} + +// GetDepositOwner returns the owner address of a deposit for a given LP token ID +// +// Parameters: +// - lpTokenId (uint64): The ID of the LP token +// +// Returns: +// - std.Address: The owner address of the deposit +// +// Panics: +// - If the deposit does not exist for the given lpTokenId +func GetDepositOwner(lpTokenId uint64) std.Address { + deposit := GetDeposit(lpTokenId) + + return deposit.owner +} + +// GetDepositStakeTimestamp returns the stake timestamp for a given LP token ID +// +// Parameters: +// - lpTokenId (uint64): The ID of the LP token +// +// Returns: +// - int64: The stake timestamp of the deposit +// +// Panics: +// - If the deposit does not exist for the given lpTokenId +func GetDepositStakeTimestamp(lpTokenId uint64) int64 { + deposit := GetDeposit(lpTokenId) + + return deposit.stakeTimestamp +} + +// GetDepositStakeHeight retrieves the stake height of a deposit associated with a given LP token ID. +// +// This function looks up the deposit record for the specified LP token ID and returns the height +// at which the deposit was staked. It relies on the `GetDeposit` function to fetch the deposit. +// If the deposit does not exist, the function panics. +// +// Parameters: +// - lpTokenId (uint64): The unique identifier of the LP token for which the stake height is retrieved. +// +// Returns: +// - int64: The stake height of the deposit associated with the given LP token ID. +func GetDepositStakeHeight(lpTokenId uint64) int64 { + deposit := GetDeposit(lpTokenId) + + return deposit.stakeHeight +} + +// GetDepositTargetPoolPath returns the target pool path for a given LP token ID +// +// Parameters: +// - lpTokenId (uint64): The ID of the LP token +// +// Returns: +// - string: The target pool path of the deposit +// +// Panics: +// - If the deposit does not exist for the given lpTokenId +func GetDepositTargetPoolPath(lpTokenId uint64) string { + deposit := GetDeposit(lpTokenId) + + return deposit.targetPoolPath +} + +// GetDepositTickLower retrieves the lower tick of a deposit associated with a given LP token ID. +// +// This function looks up the deposit record for the specified LP token ID and returns +// the lower tick value of the deposit. If the deposit does not exist, the function panics. +// +// Parameters: +// - lpTokenId (uint64): The unique identifier of the LP token for which the lower tick is retrieved. +// +// Returns: +// - int32: The lower tick value of the deposit. +func GetDepositTickLower(lpTokenId uint64) int32 { + deposit := GetDeposit(lpTokenId) + + return deposit.tickLower +} + +// GetDepositTickUpper retrieves the upper tick of a deposit associated with a given LP token ID. +// +// This function looks up the deposit record for the specified LP token ID and returns +// the upper tick value of the deposit. If the deposit does not exist, the function panics. +// +// Parameters: +// - lpTokenId (uint64): The unique identifier of the LP token for which the upper tick is retrieved. +// +// Returns: +// - int32: The upper tick value of the deposit. +func GetDepositTickUpper(lpTokenId uint64) int32 { + deposit := GetDeposit(lpTokenId) + + return deposit.tickUpper +} + +// GetDepositLiquidity retrieves the liquidity of a deposit associated with a given LP token ID. +// +// This function looks up the deposit record for the specified LP token ID and returns +// the liquidity value as a `*u256.Uint` object. If the deposit does not exist, the function panics. +// +// Parameters: +// - lpTokenId (uint64): The unique identifier of the LP token for which the liquidity is retrieved. +// +// Returns: +// - *u256.Uint: The liquidity value of the deposit. +func GetDepositLiquidity(lpTokenId uint64) *u256.Uint { + deposit := GetDeposit(lpTokenId) + + return deposit.liquidity +} + +// GetDepositLiquidityAsString retrieves the liquidity of a deposit associated with a given LP token ID +// and returns it as a string. +// +// This function looks up the deposit record for the specified LP token ID, retrieves its liquidity, +// and converts it to a string. If the deposit does not exist, the function panics. +// +// Parameters: +// - lpTokenId (uint64): The unique identifier of the LP token for which the liquidity is retrieved. +// +// Returns: +// - string: The liquidity value of the deposit as a string. +func GetDepositLiquidityAsString(lpTokenId uint64) string { + liquidity := GetDepositLiquidity(lpTokenId) + + return liquidity.ToString() +} + +// GetDepositLastCollectHeight retrieves the last collection height of a deposit associated with a given LP token ID. +// +// This function looks up the deposit record for the specified LP token ID and returns +// the last height at which rewards were collected. If the deposit does not exist, the function panics. +// +// Parameters: +// - lpTokenId (uint64): The unique identifier of the LP token for which the last collection height is retrieved. +// +// Returns: +// - uint64: The last collection height of the deposit. +func GetDepositLastCollectHeight(lpTokenId uint64) int64 { + deposit := GetDeposit(lpTokenId) + + return deposit.lastCollectHeight +} + +// GetDepositWarmUp retrieves the warm-up records of a deposit associated with a given LP token ID. +// +// This function looks up the deposit record for the specified LP token ID and returns +// the list of warm-up records associated with the deposit. If the deposit does not exist, the function panics. +// +// Parameters: +// - lpTokenId (uint64): The unique identifier of the LP token for which the warm-up records are retrieved. +// +// Returns: +// - []Warmup: A slice of warm-up records associated with the deposit. +func GetDepositWarmUp(lpTokenId uint64) []Warmup { + deposit := GetDeposit(lpTokenId) + + return deposit.warmups +} + +// GetPoolTier returns the tier of a given pool +// +// Parameters: +// - poolPath (string): The path of the pool +// +// Returns: +// - uint64: The tier of the pool +// +// Panics: +// - If the pool tier does not exist for the given poolPath +func GetPoolTier(poolPath string) uint64 { + return poolTier.CurrentTier(poolPath) +} + +// GetPoolTierRatio retrieves the current reward ratio for the specified pool tier. +// +// This function calculates the reward ratio of a pool by determining its tier +// using `GetPoolTier` and then fetching the tier's current ratio based on the current block height. +// +// Parameters: +// - poolPath (string): The path of the pool whose tier ratio is retrieved. +// +// Returns: +// - uint64: The current reward ratio for the specified pool tier. +func GetPoolTierRatio(poolPath string) uint64 { + tier := GetPoolTier(poolPath) + return poolTier.tierRatio.Get(tier) +} + +// GetPoolTierCount retrieves the current count of pools in the specified tier. +// +// This function checks the number of pools in the given tier based on the current block height. +// If the tier is `0`, it returns `0`. +// +// Parameters: +// - tier (uint64): The tier for which the pool count is retrieved. +// +// Returns: +// - uint64: The current count of pools in the specified tier. +func GetPoolTierCount(tier uint64) uint64 { + if tier == 0 { + return 0 + } + return uint64(poolTier.CurrentCount(tier)) +} + +// GetPoolReward retrieves the current reward amount for the specified pool tier. +// +// This function fetches the reward for the given tier based on the current block height. +// +// Parameters: +// - tier (uint64): The tier for which the reward is retrieved. +// +// Returns: +// - uint64: The current reward amount for the specified tier. +func GetPoolReward(tier uint64) uint64 { + return poolTier.CurrentReward(tier) +} + +// GetExternalIncentive retrieves an external incentive by its ID. +// +// This function looks up an external incentive in the global `externalIncentives` tree using the specified `incentiveId`. +// If the incentive does not exist, the function panics with an appropriate error message. +// +// Parameters: +// - incentiveId (string): The ID of the incentive to retrieve. +// +// Returns: +// - *ExternalIncentive: A pointer to the retrieved `ExternalIncentive` object. +func GetExternalIncentive(incentiveId string) *ExternalIncentive { + incentive, exist := externalIncentives.tree.Get(incentiveId) + if !exist { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("incentiveId(%s) incentive does not exist", incentiveId)), + ) + } + + return incentive.(*ExternalIncentive) +} + +// GetExternalIncentiveByPoolPath retrieves all external incentives associated with a specific pool path. +// +// This function iterates through all external incentives in the global `externalIncentives` tree +// and returns a list of incentives whose `targetPoolPath` matches the specified `poolPath`. +// +// Parameters: +// - poolPath (string): The path of the pool for which incentives are retrieved. +// +// Returns: +// - []ExternalIncentive: A slice of `ExternalIncentive` objects associated with the specified pool path. +func GetExternalIncentiveByPoolPath(poolPath string) []ExternalIncentive { + incentives := []ExternalIncentive{} + externalIncentives.tree.Iterate("", "", func(key string, value interface{}) bool { + incentive := value.(*ExternalIncentive) + if incentive.targetPoolPath == poolPath { + incentives = append(incentives, *incentive) + } + return false + }) + + return incentives +} + +func GetPrintExternalInfo() string { + externalDebug := ApiExternalDebugInfo{ + Height: std.GetHeight(), + Time: time.Now().Unix(), + } + + externalPositions := []ApiExternalDebugPosition{} + deposits.Iterate(uint64(0), uint64(deposits.Size()), func(tokenId uint64, deposit *Deposit) bool { + externalPosition := ApiExternalDebugPosition{ + LpTokenId: tokenId, + StakedHeight: deposit.stakeHeight, + StakedTimestamp: deposit.stakeTimestamp, + } + + externalIncentivesList := []ApiExternalDebugIncentive{} + externalIncentives.tree.Iterate("", "", func(key string, value interface{}) bool { + incentive := value.(*ExternalIncentive) + if incentive.targetPoolPath == deposit.targetPoolPath { + externalIncentive := ApiExternalDebugIncentive{ + PoolPath: incentive.targetPoolPath, + IncentiveId: key, + RewardToken: incentive.rewardToken, + RewardAmount: strconv.FormatUint(incentive.rewardAmount, 10), + RewardLeft: strconv.FormatUint(incentive.rewardLeft, 10), + StartTimestamp: incentive.startTimestamp, + EndTimestamp: incentive.endTimestamp, + RewardPerBlock: strconv.FormatUint(incentive.rewardPerBlock, 10), + Refundee: incentive.refundee, + tokenAmountFull: incentive.depositGnsAmount, + tokenAmountToGive: incentive.RewardSpent(uint64(std.GetHeight())), + StartHeight: incentive.startHeight, + EndHeight: incentive.endHeight, + } + + externalIncentivesList = append(externalIncentivesList, externalIncentive) + } + return false + }) + + externalPosition.Incentive = externalIncentivesList + externalPositions = append(externalPositions, externalPosition) + return false + }) + + externalDebug.Position = externalPositions + + // JSON Serialization + node := json.ObjectNode("", map[string]*json.Node{ + "height": json.NumberNode("", float64(externalDebug.Height)), + "time": json.NumberNode("", float64(externalDebug.Time)), + "position": json.ArrayNode("", makeExternalPositionsNode(externalDebug.Position)), + }) + + b, err := json.Marshal(node) + if err != nil { + return "JSON MARSHAL ERROR" + } + + return string(b) +} + +func makeExternalPositionsNode(positions []ApiExternalDebugPosition) []*json.Node { + externalPositions := make([]*json.Node, 0) + + for _, externalPosition := range positions { + incentives := make([]*json.Node, 0) + for _, incentive := range externalPosition.Incentive { + stakedOrExternalDuration := std.GetHeight() - max(incentive.StartHeight, externalPosition.StakedHeight) + + incentives = append(incentives, json.ObjectNode("", map[string]*json.Node{ + "poolPath": json.StringNode("poolPath", incentive.PoolPath), + "rewardToken": json.StringNode("rewardToken", incentive.RewardToken), + "rewardAmount": json.StringNode("rewardAmount", incentive.RewardAmount), + "rewardLeft": json.StringNode("rewardLeft", incentive.RewardLeft), + "startTimestamp": json.NumberNode("startTimestamp", float64(incentive.StartTimestamp)), + "endTimestamp": json.NumberNode("endTimestamp", float64(incentive.EndTimestamp)), + "rewardPerBlock": json.StringNode("rewardPerBlock", incentive.RewardPerBlock), + "stakedOrExternalDuration": json.NumberNode("stakedOrExternalDuration", float64(stakedOrExternalDuration)), + "tokenAmountFull": json.NumberNode("tokenAmountFull", float64(incentive.tokenAmountFull)), + "tokenAmountToGive": json.NumberNode("tokenAmountToGive", float64(incentive.tokenAmountToGive)), + })) + } + + externalPositions = append(externalPositions, json.ObjectNode("", map[string]*json.Node{ + "lpTokenId": json.NumberNode("lpTokenId", float64(externalPosition.LpTokenId)), + "stakedHeight": json.NumberNode("stakedHeight", float64(externalPosition.StakedHeight)), + "stakedTimestamp": json.NumberNode("stakedTimestamp", float64(externalPosition.StakedTimestamp)), + "incentive": json.ArrayNode("", incentives), + })) + } + + return externalPositions +} + +type currentExternalInfo struct { + height int64 + time int64 + externalIncentives []ExternalIncentive +} + +type ApiExternalDebugInfo struct { + Height int64 `json:"height"` + Time int64 `json:"time"` + Position []ApiExternalDebugPosition `json:"pool"` +} + +type ApiExternalDebugPosition struct { + LpTokenId uint64 `json:"lpTokenId"` + StakedHeight int64 `json:"stakedHeight"` + StakedTimestamp int64 `json:"stakedTimestamp"` + Incentive []ApiExternalDebugIncentive `json:"incentive"` +} + +type ApiExternalDebugIncentive struct { + PoolPath string `json:"poolPath"` + IncentiveId string `json:"incentiveId"` + RewardToken string `json:"rewardToken"` + RewardAmount string `json:"rewardAmount"` + RewardLeft string `json:"rewardLeft"` + StartTimestamp int64 `json:"startTimestamp"` + EndTimestamp int64 `json:"endTimestamp"` + RewardPerBlockX96 string `json:"rewardPerBlockX96"` + RewardPerBlock string `json:"rewardPerBlock"` + Refundee std.Address `json:"refundee"` + StartHeight int64 `json:"startHeight"` + EndHeight int64 `json:"endHeight"` + // FROM positionExternal -> externalRewards + tokenAmountX96 *u256.Uint `json:"tokenAmountX96"` + tokenAmount uint64 `json:"tokenAmount"` + tokenAmountFull uint64 `json:"tokenAmountFull"` + tokenAmountToGive uint64 `json:"tokenAmountToGive"` + // FROM externalWarmUpAmount + full30 uint64 `json:"full30"` + give30 uint64 `json:"give30"` + full50 uint64 `json:"full50"` + give50 uint64 `json:"give50"` + full70 uint64 `json:"full70"` + give70 uint64 `json:"give70"` + full100 uint64 `json:"full100"` +} + +// DEBUG INTERNAL (GNS EMISSION) +type currentInfo struct { + height int64 + time int64 + gnsStaker uint64 + gnsDevOps uint64 + gnsCommunityPool uint64 + gnsGovStaker uint64 + gnsProtocolFee uint64 + gnsADMIN uint64 +} + +func getCurrentInfo() currentInfo { + return currentInfo{ + height: std.GetHeight(), + time: time.Now().Unix(), + gnsStaker: gns.BalanceOf(consts.STAKER_ADDR), + gnsDevOps: gns.BalanceOf(consts.DEV_OPS), + gnsCommunityPool: gns.BalanceOf(consts.COMMUNITY_POOL_ADDR), + gnsGovStaker: gns.BalanceOf(consts.GOV_STAKER_ADDR), + gnsProtocolFee: gns.BalanceOf(consts.PROTOCOL_FEE_ADDR), + gnsADMIN: gns.BalanceOf(consts.ADMIN), + } +} diff --git a/contract/r/gnoswap/staker/gno.mod b/contract/r/gnoswap/staker/gno.mod new file mode 100644 index 000000000..0f1a1142b --- /dev/null +++ b/contract/r/gnoswap/staker/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/staker diff --git a/contract/r/gnoswap/staker/incentive_id.gno b/contract/r/gnoswap/staker/incentive_id.gno new file mode 100644 index 000000000..567aef1f4 --- /dev/null +++ b/contract/r/gnoswap/staker/incentive_id.gno @@ -0,0 +1,115 @@ +package staker + +import ( + "encoding/base64" + "std" + + "gno.land/p/demo/ufmt" +) + +// incentiveIdCompute generates a unique incentive ID by combining caller address, +// target pool path, reward token, and time/height information. +// +// The generated ID is a Base64-encoded string of the following format: +// "caller:targetPoolPath:rewardToken:startTimestamp:endTimestamp:height". +// +// Parameters: +// - caller (std.Address): The address of the caller initiating the incentive. +// - targetPoolPath (string): The target pool path for the incentive. +// - rewardToken (string): The reward token associated with the incentive. +// - startTimestamp (int64): The starting timestamp of the incentive. +// - endTimestamp (int64): The ending timestamp of the incentive. +// - height (int64): The blockchain height at which the incentive is created. +// +// Returns: +// - string: A Base64-encoded string representing the unique incentive ID. +// +// Example: +// Input: caller="g1xyz", targetPoolPath="pool1", rewardToken="gns", startTimestamp=12345, endTimestamp=67890, height=1000 +// Output: "ZzF4eXo6cG9vbDE6Z25zOjEyMzQ1OjY3ODkwOjEwMDA=" +func incentiveIdCompute(caller std.Address, targetPoolPath, rewardToken string, startTimestamp, endTimestamp, height int64) string { + key := ufmt.Sprintf("%s:%s:%s:%d:%d:%d", caller.String(), targetPoolPath, rewardToken, startTimestamp, endTimestamp, height) + + encoded := base64.StdEncoding.EncodeToString([]byte(key)) + return encoded +} + +// incentiveIdByTime generates a unique incentive ID based on time intervals, +// creator address, and reward token. +// +// The generated ID is a plain string in the following format: +// "startTime:endTime:creator:rewardToken". +// +// Parameters: +// - startTime (uint64): The starting time of the incentive. +// - endTime (uint64): The ending time of the incentive. +// - creator (std.Address): The address of the incentive creator. +// - rewardToken (string): The reward token associated with the incentive. +// +// Returns: +// - string: A plain string representing the incentive ID. +// +// Example: +// Input: startTime=12345, endTime=67890, creator="g1xyz", rewardToken="gns" +// Output: "000000000000012345:000000000000067890:g1xyz:gns" +func incentiveIdByTime(startTime, endTime int64, creator std.Address, rewardToken string) string { + startTimeEncode := EncodeUint(uint64(startTime)) + endTimeEncode := EncodeUint(uint64(endTime)) + creatorEncode := creator.String() + + byTimeId := []byte(startTimeEncode) + byTimeId = append(byTimeId, ':') + byTimeId = append(byTimeId, endTimeEncode...) + byTimeId = append(byTimeId, ':') + byTimeId = append(byTimeId, creatorEncode...) + byTimeId = append(byTimeId, ':') + byTimeId = append(byTimeId, rewardToken...) + + return string(byTimeId) +} + +// incentiveIdByHeight generates two unique incentive IDs based on height intervals, +// creator address, and reward token. +// +// The first ID (byHeightId) has the format: +// "startHeight:endHeight:creator:rewardToken". +// +// The second ID (byCreatorId) has the format: +// "creator:startHeight:endHeight:rewardToken". +// +// Parameters: +// - startHeight (uint64): The starting blockchain height of the incentive. +// - endHeight (uint64): The ending blockchain height of the incentive. +// - creator (std.Address): The address of the incentive creator. +// - rewardToken (string): The reward token associated with the incentive. +// +// Returns: +// - string: The first ID (byHeightId). +// - string: The second ID (byCreatorId). +// +// Example: +// Input: startHeight=123, endHeight=456, creator="g1xyz", rewardToken="gns" +// Output: "000000000000000123:000000000000000456:g1xyz:gns", "g1xyz:000000000000000123:000000000000000456:gns" +func incentiveIdByHeight(startHeight, endHeight int64, creator std.Address, rewardToken string) (string, string) { + startHeightEncode := EncodeUint(uint64(startHeight)) + endHeightEncode := EncodeUint(uint64(endHeight)) + creatorEncode := creator.String() + + byHeightId := []byte(startHeightEncode) + byHeightId = append(byHeightId, ':') + byHeightId = append(byHeightId, endHeightEncode...) + byHeightId = append(byHeightId, ':') + byHeightId = append(byHeightId, creatorEncode...) + byHeightId = append(byHeightId, ':') + byHeightId = append(byHeightId, rewardToken...) + + byCreatorId := []byte(creatorEncode) + byCreatorId = append(byCreatorId, ':') + byCreatorId = append(byCreatorId, startHeightEncode...) + byCreatorId = append(byCreatorId, ':') + byCreatorId = append(byCreatorId, endHeightEncode...) + byCreatorId = append(byCreatorId, ':') + byCreatorId = append(byCreatorId, rewardToken...) + + return string(byHeightId), string(byCreatorId) +} \ No newline at end of file diff --git a/contract/r/gnoswap/staker/incentive_id_test.gno b/contract/r/gnoswap/staker/incentive_id_test.gno new file mode 100644 index 000000000..2f54572aa --- /dev/null +++ b/contract/r/gnoswap/staker/incentive_id_test.gno @@ -0,0 +1,56 @@ +package staker + +import ( + "encoding/base64" + "std" + "testing" +) + +func TestIncentiveIdCompute(t *testing.T) { + caller := std.Address("g1xyz") + targetPoolPath := "pool1" + rewardToken := "gns" + startTimestamp := int64(12345) + endTimestamp := int64(67890) + height := int64(1000) + + expected := base64.StdEncoding.EncodeToString([]byte("g1xyz:pool1:gns:12345:67890:1000")) + actual := incentiveIdCompute(caller, targetPoolPath, rewardToken, startTimestamp, endTimestamp, height) + + if actual != expected { + t.Errorf("incentiveIdCompute() = %s; want %s", actual, expected) + } +} + +func TestIncentiveIdByTime(t *testing.T) { + startTime := int64(12345) + endTime := int64(67890) + creator := std.Address("g1xyz") + rewardToken := "gns" + + expected := "00000000000000012345:00000000000000067890:g1xyz:gns" + actual := incentiveIdByTime(startTime, endTime, creator, rewardToken) + + if actual != expected { + t.Errorf("incentiveIdByTime() = %s; want %s", actual, expected) + } +} + +func TestIncentiveIdByHeight(t *testing.T) { + startHeight := int64(123) + endHeight := int64(456) + creator := std.Address("g1xyz") + rewardToken := "gns" + + expectedByHeight := "00000000000000000123:00000000000000000456:g1xyz:gns" + expectedByCreator := "g1xyz:00000000000000000123:00000000000000000456:gns" + + actualByHeight, actualByCreator := incentiveIdByHeight(startHeight, endHeight, creator, rewardToken) + + if actualByHeight != expectedByHeight { + t.Errorf("incentiveIdByHeight() byHeightId = %s; want %s", actualByHeight, expectedByHeight) + } + if actualByCreator != expectedByCreator { + t.Errorf("incentiveIdByHeight() byCreatorId = %s; want %s", actualByCreator, expectedByCreator) + } +} diff --git a/contract/r/gnoswap/staker/manage_pool_tier_and_warmup.gno b/contract/r/gnoswap/staker/manage_pool_tier_and_warmup.gno new file mode 100644 index 000000000..e0aeafa57 --- /dev/null +++ b/contract/r/gnoswap/staker/manage_pool_tier_and_warmup.gno @@ -0,0 +1,182 @@ +package staker + +import ( + "std" + + "gno.land/p/demo/ufmt" + + "gno.land/r/gnoswap/v1/common" + + en "gno.land/r/gnoswap/v1/emission" + pl "gno.land/r/gnoswap/v1/pool" +) + +const ( + NOT_EMISSION_TARGET_TIER uint64 = 0 +) + +func SetPoolTierByAdmin(poolPath string, tier uint64) { + assertOnlyAdmin() + assertMustNotHalted() + assertPoolMustExist(poolPath) + + setPoolTier(poolPath, tier) +} + +func SetPoolTier(poolPath string, tier uint64) { + assertOnlyGovernance() + assertMustNotHalted() + assertPoolMustExist(poolPath) + + setPoolTier(poolPath, tier) +} + +func setPoolTier(poolPath string, tier uint64) { + en.MintAndDistributeGns() + + currentHeight := std.GetHeight() + pools.GetOrCreate(poolPath) + poolTier.changeTier(currentHeight, pools, poolPath, tier) + + prevAddr, prevRealm := getPrev() + std.Emit( + "SetPoolTier", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "poolPath", poolPath, + "tier", formatUint(tier), + "height", formatInt(currentHeight), + ) +} + +func ChangePoolTierByAdmin(poolPath string, tier uint64) { + assertOnlyAdmin() + assertMustNotHalted() + assertPoolMustExist(poolPath) + + changePoolTier(poolPath, tier) +} + +func ChangePoolTier(poolPath string, tier uint64) { + assertOnlyGovernance() + assertMustNotHalted() + assertPoolMustExist(poolPath) + + changePoolTier(poolPath, tier) +} + +func changePoolTier(poolPath string, tier uint64) { + en.MintAndDistributeGns() + + prevTier := poolTier.CurrentTier(poolPath) + currentHeight := std.GetHeight() + poolTier.changeTier(currentHeight, pools, poolPath, tier) + + prevAddr, prevRealm := getPrev() + std.Emit( + "ChangePoolTier", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "poolPath", poolPath, + "prevTier", formatUint(prevTier), + "newTier", formatUint(tier), + "height", formatInt(currentHeight), + ) +} + +func RemovePoolTierByAdmin(poolPath string) { + assertOnlyAdmin() + assertMustNotHalted() + assertPoolMustExist(poolPath) + + removePoolTier(poolPath) +} + +func RemovePoolTier(poolPath string) { + assertOnlyGovernance() + assertMustNotHalted() + assertPoolMustExist(poolPath) + + removePoolTier(poolPath) +} + +func removePoolTier(poolPath string) { + en.MintAndDistributeGns() + + currentHeight := std.GetHeight() + poolTier.changeTier(currentHeight, pools, poolPath, NOT_EMISSION_TARGET_TIER) + + prevAddr, prevRealm := getPrev() + std.Emit( + "RemovePoolTier", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "poolPath", poolPath, + "height", formatInt(currentHeight), + ) +} + +func SetWarmUpByAdmin(pct, blockDuration int64) { + assertOnlyAdmin() + assertMustNotHalted() + + setWarmUp(pct, blockDuration) +} + +func SetWarmUp(pct, blockDuration int64) { + assertOnlyGovernance() + assertMustNotHalted() + + setWarmUp(pct, blockDuration) +} + +func setWarmUp(pct, blockDuration int64) { + en.MintAndDistributeGns() + + modifyWarmup(pctToIndex(pct), blockDuration) + + prevAddr, prevRealm := getPrev() + std.Emit( + "SetWarmUp", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "pct", formatUint(pct), + "blockDuration", formatInt(blockDuration), + ) +} + +func pctToIndex(pct int64) int { + switch pct { + case 30: + return 0 + case 50: + return 1 + case 70: + return 2 + case 100: + return 3 + default: + panic("staker.gno__pctToIndex() || pct is not valid") + } +} + +func assertPoolMustExist(poolPath string) { + if !(pl.DoesPoolPathExist(poolPath)) { + panic(addDetailToError( + errInvalidPoolPath, + ufmt.Sprintf("pool(%s) does not exist", poolPath), + )) + } +} + +func assertOnlyAdmin() { + common.AdminOnly(getPrevAddr()) +} + +func assertOnlyGovernance() { + common.GovernanceOnly(getPrevAddr()) +} + +func assertMustNotHalted() { + common.IsHalted() +} diff --git a/contract/r/gnoswap/staker/mint_stake.gno b/contract/r/gnoswap/staker/mint_stake.gno new file mode 100644 index 000000000..297f88b26 --- /dev/null +++ b/contract/r/gnoswap/staker/mint_stake.gno @@ -0,0 +1,74 @@ +package staker + +import ( + "std" + + pn "gno.land/r/gnoswap/v1/position" + + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + + "gno.land/r/gnoswap/v1/gnft" + + "gno.land/p/demo/grc/grc721" +) + +// MintAndStake mints LP tokens and stakes them in a single transaction. +// Returns tokenId, liquidity, amount0, amount1, poolPath +// ref: https://docs.gnoswap.io/contracts/staker/mint_stake.gno#mintandstake +func MintAndStake( + token0 string, + token1 string, + fee uint32, + tickLower int32, + tickUpper int32, + amount0Desired string, // *u256.Uint + amount1Desired string, // *u256.Uint + amount0Min string, // *u256.Uint + amount1Min string, // *u256.Uint + deadline int64, +) (uint64, string, string, string, string) { + + // if one click native + if token0 == consts.GNOT || token1 == consts.GNOT { + // check sent ugnot + sent := std.GetOrigSend() + ugnotSent := uint64(sent.AmountOf("ugnot")) + + // not enough ugnot sent + if ugnotSent < consts.UGNOT_MIN_DEPOSIT_TO_WRAP { + panic(addDetailToError( + errWugnotMinimum, + ufmt.Sprintf("mint_stake.gno__MintAndStake() || too less ugnot sent(%d), minimum:%d", ugnotSent, consts.UGNOT_MIN_DEPOSIT_TO_WRAP), + )) + } + + // send it over to position to wrap + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(consts.STAKER_ADDR, consts.POSITION_ADDR, std.Coins{{Denom: "ugnot", Amount: int64(ugnotSent)}}) + } + + tokenId, liquidity, amount0, amount1 := pn.Mint( + token0, + token1, + fee, + tickLower, + tickUpper, + amount0Desired, + amount1Desired, + amount0Min, + amount1Min, + deadline, + consts.STAKER_ADDR, + std.PrevRealm().Addr(), + ) + + // at this point, staker has minted token + toTid := grc721.TokenID(ufmt.Sprintf("%d", tokenId)) + gnft.SetTokenURIByImageURI(toTid) + + poolPath, _, _ := StakeToken(tokenId) // poolPath, stakedAmount0, stakedAmount1 + + return tokenId, liquidity, amount0, amount1, poolPath +} diff --git a/contract/r/gnoswap/staker/protocol_fee_unstaking.gno b/contract/r/gnoswap/staker/protocol_fee_unstaking.gno new file mode 100644 index 000000000..f42ef7708 --- /dev/null +++ b/contract/r/gnoswap/staker/protocol_fee_unstaking.gno @@ -0,0 +1,160 @@ +package staker + +import ( + "std" + + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/gns" + + pf "gno.land/r/gnoswap/v1/protocol_fee" +) + +var ( + unstakingFee = uint64(100) // 1% +) + +// handleUnStakingFee calculates and applies the unstaking fee. +// +// The function deducts a fee from the unstaked amount based on the `unstakingFee` rate, +// sends the fee to the protocol fee address, and emits an event indicating the fee transfer. +// +// Parameters: +// - tokenPath (string): The token path (e.g., the contract managing the token). +// - amount (uint64): The total unstaked amount. +// - internal (bool): Indicates if the fee is for internal or external use. +// - tokenId (uint64): The token ID associated with the unstaking action. +// - poolPath (string): The pool path related to the unstaking. +// +// Returns: +// - uint64: The amount after deducting the unstaking fee. +func handleUnStakingFee( + tokenPath string, + amount uint64, + internal bool, + tokenId uint64, + poolPath string, +) uint64 { + if unstakingFee == 0 { + return amount + } + + feeAmount := amount * unstakingFee / 10000 + if feeAmount == 0 { + return amount + } + + prevAddr, prevPkgPath := getPrev() + + if internal { + // staker contract has fee + gns.Transfer(consts.PROTOCOL_FEE_ADDR, feeAmount) + pf.AddToProtocolFee(consts.GNS_PATH, feeAmount) + + std.Emit( + "ProtocolFeeInternalReward", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "fromPositionId", formatUint(tokenId), + "fromPoolPath", poolPath, + "feeTokenPath", consts.GNS_PATH, + "feeAmount", formatUint(feeAmount), + ) + } else { + // external contract has fee + teller := common.GetTokenTeller(tokenPath) + teller.Transfer(consts.PROTOCOL_FEE_ADDR, feeAmount) + pf.AddToProtocolFee(tokenPath, feeAmount) + + std.Emit( + "ProtocolFeeExternalReward", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "fromPositionId", formatUint(tokenId), + "fromPoolPath", poolPath, + "feeTokenPath", tokenPath, + "feeAmount", formatUint(feeAmount), + ) + } + + return amount - feeAmount +} + +// GetUnstakingFee returns current rate of unstaking fee +// ref: https://docs.gnoswap.io/contracts/staker/protocol_fee_unstaking.gno#getunstakingfee +func GetUnstakingFee() uint64 { + return unstakingFee +} + +// SetUnStakingFeeByAdmin sets the unstaking fee rate by an admin. +// +// This function ensures that only admins can modify the unstaking fee. It validates +// the input fee and emits an event indicating the change. +// +// Parameters: +// - fee (uint64): The new unstaking fee rate in basis points (bps). +// +// Panics: +// - If the caller is not an admin. +func SetUnStakingFeeByAdmin(fee uint64) { + caller := std.PrevRealm().Addr() + if err := common.AdminOnly(caller); err != nil { + panic(err.Error()) + } + + prevUnStakingFee := getUnStakingFee() + + setUnStakingFee(fee) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "SetUnStakingFeeByAdmin", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevFee", formatUint(prevUnStakingFee), + "newFee", formatUint(fee), + ) +} + +// SetUnStakingFee modifies the unstaking fee +// Only governance contract can execute this function via proposal +// ref: https://docs.gnoswap.io/contracts/staker/protocol_fee_unstaking.gno#setunstakingfee +func SetUnStakingFee(fee uint64) { + caller := getPrevAddr() + if err := common.GovernanceOnly(caller); err != nil { + panic(addDetailToError(err.Error())) + } + + prevUnStakingFee := getUnStakingFee() + + setUnStakingFee(fee) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "SetUnStakingFee", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "prevFee", formatUint(prevUnStakingFee), + "newFee", formatUint(fee), + ) +} + +func setUnStakingFee(fee uint64) { + // 10000 (bps) = 100% + if fee > 10000 { + panic(addDetailToError( + errInvalidUnstakingFee, + ufmt.Sprintf("reward_fee.gno__SetUnstakingFee() || fee(%d) must be in range 0 ~ 10000", fee), + )) + } + + unstakingFee = fee +} + +func getUnStakingFee() uint64 { + return unstakingFee +} diff --git a/contract/r/gnoswap/staker/protocol_fee_unstaking_test.gno b/contract/r/gnoswap/staker/protocol_fee_unstaking_test.gno new file mode 100644 index 000000000..6229ac58f --- /dev/null +++ b/contract/r/gnoswap/staker/protocol_fee_unstaking_test.gno @@ -0,0 +1,80 @@ +package staker + +import ( + "testing" +) + +func TestHandleUnstakingFee(t *testing.T) { + tests := []struct { + name string + tokenPath string + amount uint64 + internal bool + tokenId uint64 + poolPath string + expectedFee uint64 + expectedNet uint64 + }{ + { + name: "No fee configured", + tokenPath: gnsPath, + amount: 10000, + internal: true, + tokenId: 1, + poolPath: "pool1", + expectedFee: 0, + expectedNet: 10000, + }, + { + name: "Standard fee", + tokenPath: gnsPath, + amount: 10000, + internal: false, + tokenId: 1, + poolPath: "pool1", + expectedFee: 100, + expectedNet: 9900, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + unstakingFee = tc.expectedFee // Set the fee globally for the test + + netAmount := handleUnStakingFee(tc.tokenPath, tc.amount, tc.internal, tc.tokenId, tc.poolPath) + + if netAmount != tc.expectedNet { + t.Errorf("Expected netAmount %d, got %d", tc.expectedNet, netAmount) + } + }) + } +} + +func TestSetUnstakingFee(t *testing.T) { + tests := []struct { + name string + fee uint64 + shouldPanic bool + }{ + {name: "Valid fee", fee: 500, shouldPanic: false}, + {name: "Excessive fee", fee: 10001, shouldPanic: true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil && !tc.shouldPanic { + t.Errorf("Unexpected panic for fee %d: %v", tc.fee, r) + } + }() + + setUnStakingFee(tc.fee) + + if !tc.shouldPanic { + if unstakingFee != tc.fee { + t.Errorf("Expected fee %d, got %d", tc.fee, unstakingFee) + } + } + }) + } +} diff --git a/contract/r/gnoswap/staker/query.gno b/contract/r/gnoswap/staker/query.gno new file mode 100644 index 000000000..98f73fcce --- /dev/null +++ b/contract/r/gnoswap/staker/query.gno @@ -0,0 +1,125 @@ +package staker + +import ( + "std" + + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" +) + +type PoolData struct { + PoolPath string + Tier uint64 + ActiveIncentives []string + StakedLiquidity *u256.Uint +} + +type IncentiveData struct { + IncentiveID string + StartTimestamp int64 + EndTimestamp int64 + RewardToken string + RewardAmount *u256.Uint + Refundee std.Address + PoolPath string +} + +type DepositData struct { + TokenID uint64 + Owner std.Address + TargetPoolPath string + StakeTimestamp int64 + Liquidity *u256.Uint + WarmupCount int +} + +// QueryPoolData returns combined pool data including tier, incentives and current staked liquidity +func QueryPoolData(poolPath string) (*PoolData, error) { + pool, exist := pools.Get(poolPath) + if !exist { + return nil, ufmt.Errorf("pool %s not found", poolPath) + } + + currentHeight := std.GetHeight() + tier := poolTier.CurrentTier(poolPath) + + ictvIds := filterActiveIncentives(pool, currentHeight) + + return &PoolData{ + PoolPath: poolPath, + Tier: tier, + ActiveIncentives: ictvIds, + StakedLiquidity: pool.CurrentStakedLiquidity(std.GetHeight()), + }, nil +} + +// QueryIncentiveData returns detailed information about a specific incentive +func QueryIncentiveData(incentiveId string) (*IncentiveData, error) { + var found bool + var data IncentiveData + + pools.tree.Iterate("", "", func(key string, value interface{}) bool { + pool := value.(*Pool) + + pool.incentives.byTime.Iterate("", "", func(key string, value interface{}) bool { + if key == incentiveId { + ictv := value.(*ExternalIncentive) + data = IncentiveData{ + IncentiveID: incentiveId, + StartTimestamp: ictv.startTimestamp, + EndTimestamp: ictv.endTimestamp, + RewardToken: ictv.rewardToken, + RewardAmount: u256.NewUint(ictv.rewardAmount), + Refundee: ictv.refundee, + PoolPath: pool.poolPath, + } + found = true + return true + } + return false + }) + + return found + }) + + if !found { + return nil, ufmt.Errorf("incentiveId(%s) incentive does not exist", incentiveId) + } + + return &data, nil +} + +// QueryDepositData returns detailed information about a specific deposit +func QueryDepositData(lpTokenId uint64) (*DepositData, error) { + deposit := deposits.Get(lpTokenId) + if deposit == nil { + return nil, ufmt.Errorf("tokenId(%d) deposit does not exist", lpTokenId) + } + + return &DepositData{ + TokenID: lpTokenId, + Owner: deposit.owner, + TargetPoolPath: deposit.targetPoolPath, + StakeTimestamp: deposit.stakeTimestamp, + Liquidity: deposit.liquidity, + WarmupCount: len(deposit.warmups), + }, nil +} + +func filterActiveIncentives(pool *Pool, currentHeight int64) []string { + ictvIds := make([]string, 0) + pool.incentives.byHeight.Iterate("", "", func(key string, value interface{}) bool { + ictv := value.(*ExternalIncentive) + if ictv.startHeight <= currentHeight && currentHeight < ictv.endHeight { + ictvId := incentiveIdByTime( + ictv.startHeight, + ictv.endHeight, + ictv.refundee, + ictv.rewardToken, + ) + ictvIds = append(ictvIds, ictvId) + } + return false + }) + return ictvIds +} diff --git a/contract/r/gnoswap/staker/query_test.gno b/contract/r/gnoswap/staker/query_test.gno new file mode 100644 index 000000000..0a8c3ba97 --- /dev/null +++ b/contract/r/gnoswap/staker/query_test.gno @@ -0,0 +1,82 @@ +package staker + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestQueryPoolData(t *testing.T) { + testAddr := testutils.TestAddress("test_address") + poolPath := "test_pool" + + pool := NewPool(poolPath, std.GetHeight()) + pools.Set(poolPath, pool) + + // Add test incentive to pool + startTime := time.Now().Unix() + endTime := startTime + 3600 // 1 hour + rewardToken := "TEST_TOKEN" + rewardAmount := uint64(1000) + + incentive := &ExternalIncentive{ + startTimestamp: startTime, + endTimestamp: endTime, + startHeight: std.GetHeight(), + endHeight: std.GetHeight() + 100, + rewardToken: rewardToken, + rewardAmount: rewardAmount, + refundee: testAddr, + } + + pool.incentives.create(testAddr, incentive) + + data, err := QueryPoolData(poolPath) + if err != nil { + t.Errorf("QueryPoolData failed: %v", err) + } + + uassert.Equal(t, data.PoolPath, poolPath) + uassert.Equal(t, len(data.ActiveIncentives), 1) + + _, err = QueryPoolData("non_existent_pool") + uassert.Error(t, err) + + // reset + pools.tree.Remove(poolPath) +} + +func TestQueryDepositData(t *testing.T) { + testAddr := testutils.TestAddress("test_address") + poolPath := "test_pool" + lpTokenId := uint64(1) + + deposit := &Deposit{ + owner: testAddr, + targetPoolPath: poolPath, + stakeTimestamp: time.Now().Unix(), + liquidity: u256.NewUint(1000), + warmups: make([]Warmup, 2), + } + + deposits.Set(lpTokenId, deposit) + + // Test query + data, err := QueryDepositData(lpTokenId) + if err != nil { + t.Errorf("QueryDepositData failed: %v", err) + } + + uassert.Equal(t, data.TokenID, lpTokenId) + uassert.Equal(t, data.Owner, testAddr) + uassert.Equal(t, data.TargetPoolPath, poolPath) + uassert.Equal(t, data.StakeTimestamp, deposit.stakeTimestamp) + uassert.Equal(t, data.WarmupCount, 2) + + // reset + deposits.Remove(lpTokenId) +} diff --git a/contract/r/gnoswap/staker/reward_calculation.gno b/contract/r/gnoswap/staker/reward_calculation.gno new file mode 100644 index 000000000..7f1941dc7 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation.gno @@ -0,0 +1,61 @@ +package staker + +import ( + "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +// liquidityMathAddDelta calculates the new liquidity by applying the delta liquidity to the current liquidity. +// If delta liquidity is negative, it subtracts the absolute value of delta liquidity from the current liquidity. +// If delta liquidity is positive, it adds the absolute value of delta liquidity to the current liquidity. +// +// Parameters: +// - x: The current liquidity as a uint256 value. +// - y: The delta liquidity as a signed int256 value. +// +// Returns: +// - The new liquidity as a uint256 value. +// +// Notes: +// - If `x` or `y` is nil, the function panics with an appropriate error message. +// - If `y` is negative, its absolute value is subtracted from `x`. +// - The result must be less than `x`. Otherwise, the function panics to prevent underflow. +// +// - If `y` is positive, it is added to `x`. +// - The result must be greater than or equal to `x`. Otherwise, the function panics to prevent overflow. +// +// - The function ensures correctness by validating the results of the arithmetic operations. +func liquidityMathAddDelta(x *u256.Uint, y *i256.Int) *u256.Uint { + if x == nil || y == nil { + panic(addDetailToError( + errInvalidInput, + "x or y is nil", + )) + } + + var z *u256.Uint + + // Subtract or add based on the sign of y + if y.Lt(i256.Zero()) { + absDelta := y.Abs() + z = new(u256.Uint).Sub(x, absDelta) + if z.Gte(x) { + panic(addDetailToError( + errCalculationError, + ufmt.Sprintf("Condition failed: (z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + )) + } + } else { + z = new(u256.Uint).Add(x, y.Abs()) + if z.Lt(x) { + panic(addDetailToError( + errCalculationError, + ufmt.Sprintf("Condition failed: (z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + )) + } + } + + return z +} diff --git a/contract/r/gnoswap/staker/reward_calculation_canonical_env_test.gno b/contract/r/gnoswap/staker/reward_calculation_canonical_env_test.gno new file mode 100644 index 000000000..e71afef13 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_canonical_env_test.gno @@ -0,0 +1,625 @@ +package staker + +// "Canonical" implementation of reward calculation. +// Used for testing and reference. + +import ( + "errors" + "math" + "std" + "testing" + + ufmt "gno.land/p/demo/ufmt" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + en "gno.land/r/gnoswap/v1/emission" +) + +type canonicalPool struct { + poolPath string + tier uint64 + tick int32 + incentive []*ExternalIncentive + + tickCrossHook func(poolPath string, tickId int32, zeroForOne bool) +} + +func (self *canonicalPool) InternalReward(emission uint64, ratio TierRatio, count uint64) uint64 { + switch self.tier { + case 0: + return 0 + case 1: + return emission * ratio.Tier1 / count / 100 + case 2: + return emission * ratio.Tier2 / count / 100 + case 3: + return emission * ratio.Tier3 / count / 100 + default: + panic("invalid tier") + } +} + +func (self *canonicalPool) ExternalReward(currentHeight int64) map[string]uint64 { + reward := make(map[string]uint64) + + for _, incentive := range self.incentive { + if incentive.startHeight > int64(currentHeight) || incentive.endHeight < int64(currentHeight) { + continue + } + + reward[incentive.rewardToken] += incentive.rewardPerBlock + } + + return reward +} + +type canonicalRewardState struct { + t *testing.T + + global *emulatedGlobalState + + Pool map[string]*canonicalPool + tickCrossHook func(poolPath string, tickId int32, zeroForOne bool) + + Reward []map[uint64]Reward // blockNumber -> depositId -> reward + + emissionUpdates *UintTree + + MsPerBlock int64 + PerBlockEmission uint64 + + CurrentTimestamp int64 + + // emulated reward claimed by unstake + emulatedClaimedReward map[uint64]uint64 +} + +func NewCanonicalRewardState(t *testing.T, pools *Pools, deposits *Deposits, tickCrossHook func(pools *Pools, height func() int64) func(poolPath string, tickId int32, zeroForOne bool)) *canonicalRewardState { + std.TestSkipHeights(1) + + result := &canonicalRewardState{ + t: t, + global: &emulatedGlobalState{ + poolTier: poolTier, + pools: pools, + deposits: deposits, + }, + Pool: make(map[string]*canonicalPool), + Reward: make([]map[uint64]Reward, 124), + emissionUpdates: NewUintTree(), + MsPerBlock: 1000, + PerBlockEmission: 1000000000, + CurrentTimestamp: 0, + + emulatedClaimedReward: make(map[uint64]uint64), + } + result.tickCrossHook = tickCrossHook(pools, func() int64 { + return int64(result.CurrentHeight()) + }) + + // Technically, the block number being provided to the NewPoolTier should be 124, but because of duplicate creation of the default pool, we put 123 here to just make rewardCache work for the testing. + result.global.poolTier = NewPoolTier(pools, 123, test_gnousdc, func() uint64 { return result.PerBlockEmission }, func(start, end int64) ([]int64, []uint64) { + heights := make([]int64, 0) + emissions := make([]uint64, 0) + result.emissionUpdates.Iterate(start, end, func(key int64, value interface{}) bool { + heights = append(heights, key) + emissions = append(emissions, value.(uint64)) + return false + }) + return heights, emissions + }) + + result.NextBlock() // must skip height 0 + + result.SetEmissionUpdate(1000000000) + + if std.GetHeight() != result.CurrentHeight() { + panic(ufmt.Sprintf("height mismatch: %d != %d", std.GetHeight(), result.CurrentHeight())) + } + + return result +} + +type emulatedGlobalState struct { + poolTier *PoolTier + pools *Pools + deposits *Deposits +} + +func (self *canonicalRewardState) isInRange(deposit *Deposit) bool { + tick := self.Pool[deposit.targetPoolPath].tick + return deposit.tickLower <= tick && tick < deposit.tickUpper +} + +func (self *canonicalRewardState) SetEmissionUpdate(emission uint64) { + self.emissionUpdates.Set(self.CurrentHeight(), emission) + self.PerBlockEmission = emission +} + +func (self *canonicalRewardState) LiquidityPerPool() (map[string]*u256.Uint) { + liquidity := make(map[string]*u256.Uint) + self.global.deposits.Iterate(0, math.MaxUint64, func(tokenId uint64, deposit *Deposit) bool { + if !self.isInRange(deposit) { + return false + } + + poolLiquidity, ok := liquidity[deposit.targetPoolPath] + if !ok { + poolLiquidity = u256.Zero() + } + + poolLiquidity = poolLiquidity.Add(poolLiquidity, deposit.liquidity) + liquidity[deposit.targetPoolPath] = poolLiquidity + return false + }) + + return liquidity +} + +func (self *canonicalRewardState) InternalRewardPerPool(emission uint64) map[string]uint64 { + reward := make(map[string]uint64) + tierCount := []uint64{0, 0, 0, 0} + + for _, pool := range self.Pool { + tierCount[pool.tier]++ + } + ratio := TierRatioFromCounts(tierCount[1], tierCount[2], tierCount[3]) + + for _, pool := range self.Pool { + reward[pool.poolPath] = pool.InternalReward(emission, ratio, tierCount[pool.tier]) + } + + return reward +} + +func (self *canonicalRewardState) ExternalRewardPerPool(currentHeight int64) map[string]map[string]uint64 { + reward := make(map[string]map[string]uint64) + + for _, pool := range self.Pool { + reward[pool.poolPath] = pool.ExternalReward(currentHeight) + } + + return reward +} + +func (self *canonicalRewardState) CurrentHeight() int64 { + return int64(len(self.Reward)) // due to testing requirement +} + +// Process block with canonical reward calculation +func (self *canonicalRewardState) CalculateCanonicalReward() map[uint64]Reward { + currentHeight := self.CurrentHeight() + 1 + rewards := make(map[uint64]Reward) + + liquidityPerPool := self.LiquidityPerPool() + internalRewardPerPool := self.InternalRewardPerPool(self.PerBlockEmission) + externalRewardPerPool := self.ExternalRewardPerPool(int64(currentHeight)) + + self.global.deposits.Iterate(0, math.MaxUint64, func(tokenId uint64, deposit *Deposit) bool { + if !self.isInRange(deposit) { + return false + } + + warmup := deposit.warmups[deposit.FindWarmup(int64(currentHeight))] + internal, internalPenalty := warmup.Apply(internalRewardPerPool[deposit.targetPoolPath], deposit.liquidity, liquidityPerPool[deposit.targetPoolPath]) + poolExternals := externalRewardPerPool[deposit.targetPoolPath] + externals := make(map[string]uint64) + externalPenalties := make(map[string]uint64) + for key, value := range poolExternals { + external, externalPenalty := warmup.Apply(value, deposit.liquidity, liquidityPerPool[deposit.targetPoolPath]) + externals[key] = external + externalPenalties[key] = externalPenalty + } + rewards[tokenId] = Reward{ + Internal: internal, + External: externals, + InternalPenalty: internalPenalty, + ExternalPenalty: externalPenalties, + } + return false + }) + + return rewards +} + +func (self *canonicalRewardState) NextBlock() { + self.Reward = append(self.Reward, self.CalculateCanonicalReward()) + self.CurrentTimestamp += self.MsPerBlock + std.TestSkipHeights(1) + if int64(self.CurrentHeight()) != std.GetHeight() { + panic(ufmt.Sprintf("height mismatch: %d != %d", self.CurrentHeight(), std.GetHeight())) + } +} + +func (self *canonicalRewardState) NextBlockNoCanonical() { + self.Reward = append(self.Reward, nil) // just placeholder + self.CurrentTimestamp += self.MsPerBlock + std.TestSkipHeights(1) + if int64(self.CurrentHeight()) != std.GetHeight() { + panic(ufmt.Sprintf("height mismatch: %d != %d", self.CurrentHeight(), std.GetHeight())) + } +} + +func (self *canonicalRewardState) UnclaimableExternalRewardOf(depositId uint64, incentiveId string) uint64 { + pool, ok := self.global.pools.Get(self.global.deposits.Get(depositId).targetPoolPath) + if !ok { + panic("pool not found") + } + + return pool.incentives.calculateUnclaimableReward(incentiveId) +} + +func (self *canonicalRewardState) CanonicalRewardOf(depositId uint64) Reward { + return self.Reward[self.CurrentHeight()-1][depositId] +} + +func (self *canonicalRewardState) SafeCanonicalRewardOf(depositId uint64) (Reward, bool) { + rewards := self.Reward[self.CurrentHeight()-1] + reward, ok := rewards[depositId] + return reward, ok +} + +func (self *canonicalRewardState) CanonicalRewardOfHeight(depositId uint64, height uint64) Reward { + return self.Reward[height][depositId] +} + +func (self *canonicalRewardState) EmulateCalcPositionReward(tokenId uint64) ([]uint64, []uint64, []map[string]uint64, []map[string]uint64) { + currentHeight := self.CurrentHeight() + + // cache per-tier and per-pool rewards + self.global.poolTier.cacheReward(currentHeight, self.global.pools) + + deposit := self.global.deposits.Get(tokenId) + + + poolPath := deposit.targetPoolPath + + pool, ok := self.global.pools.Get(poolPath) + if !ok { + pool = NewPool(poolPath, currentHeight) + self.global.pools.Set(poolPath, pool) + } + + lastCollectHeight := deposit.lastCollectHeight + + internalRewards, internalPenalties := pool.RewardStateOf(deposit).CalculateInternalReward(int64(lastCollectHeight), int64(currentHeight)) + + externalRewards := make([]map[string]uint64, 4) + externalPenalties := make([]map[string]uint64, 4) + for i := range externalRewards { + externalRewards[i] = make(map[string]uint64) + externalPenalties[i] = make(map[string]uint64) + } + + allIncentives := pool.incentives.GetAllInHeights(int64(lastCollectHeight), int64(currentHeight)) + + for incentiveId, incentive := range allIncentives { + externalReward, externalPenalty := pool.RewardStateOf(deposit).CalculateExternalReward(int64(lastCollectHeight), int64(currentHeight), incentive) + + for i := range externalReward { + externalRewards[i][incentive.incentiveId] = externalReward[i] + externalPenalties[i][incentive.incentiveId] = externalPenalty[i] + } + } + + return internalRewards, internalPenalties, externalRewards, externalPenalties +} + +func (self *canonicalRewardState) EmulatedRewardOf(depositId uint64) Reward { + if !self.global.deposits.Has(depositId) { + claimed := self.emulatedClaimedReward[depositId] + self.emulatedClaimedReward[depositId] = 0 + return Reward{ + Internal: claimed, + InternalPenalty: 0, + External: make(map[string]uint64), + ExternalPenalty: make(map[string]uint64), + } + } + + rewards, penalties, externalRewards, externalPenalties := self.EmulateCalcPositionReward(depositId) + + deposit := self.global.deposits.Get(depositId) + deposit.lastCollectHeight = self.CurrentHeight() + + internal := uint64(0) + for _, reward := range rewards { + internal += reward + } + claimed, ok := self.emulatedClaimedReward[depositId] + if ok { + internal += claimed + self.emulatedClaimedReward[depositId] = 0 + } + internalPenalty := uint64(0) + for _, penalty := range penalties { + internalPenalty += penalty + } + external := make(map[string]uint64) + for _, er := range externalRewards { + for incentiveId, reward := range er { + external[incentiveId] += reward + } + } + externalPenalty := make(map[string]uint64) + for _, ep := range externalPenalties { + for incentiveId, penalty := range ep { + externalPenalty[incentiveId] += penalty + } + } + + return Reward{ + Internal: internal, + InternalPenalty: internalPenalty, + External: external, + ExternalPenalty: externalPenalty, + } +} + +// Emulation of staker.gno public entrypoints +func (self *canonicalRewardState) StakeToken(tokenId uint64, targetPoolPath string, owner std.Address, tickLower int32, tickUpper int32, liquidity *u256.Uint) error { + currentHeight := self.CurrentHeight() + pool, ok := self.global.pools.Get(targetPoolPath) + if !ok { + panic("should not happen 1") + } + + deposit := &Deposit{ + owner: owner, + stakeHeight: int64(currentHeight), + targetPoolPath: targetPoolPath, + tickLower: tickLower, + tickUpper: tickUpper, + liquidity: liquidity, + lastCollectHeight: currentHeight, + warmups: InstantiateWarmup(currentHeight), + } + canonicalPool, ok := self.Pool[deposit.targetPoolPath] + if !ok { + return errors.New("pool not found") + } + if canonicalPool.tier == 0 && len(canonicalPool.incentive) == 0 { + return errors.New("pool has no tier or incentive") + } + + // update global state + self.global.deposits.Set(tokenId, deposit) + + self.global.poolTier.cacheReward(currentHeight, self.global.pools) + + signedLiquidity := i256.FromUint256(deposit.liquidity) + if self.isInRange(deposit) { + pool.modifyDeposit(signedLiquidity, currentHeight, canonicalPool.tick) + } + // historical tick must be set regardless of the deposit's range + pool.historicalTick.Set(currentHeight, canonicalPool.tick) + + pool.ticks.Get(deposit.tickLower).modifyDepositLower(currentHeight, canonicalPool.tick, signedLiquidity) + pool.ticks.Get(deposit.tickUpper).modifyDepositUpper(currentHeight, canonicalPool.tick, signedLiquidity) + + return nil +} + +func (self *canonicalRewardState) UnstakeToken(tokenId uint64) { + deposit := self.global.deposits.Get(tokenId) + + currentHeight := self.CurrentHeight() + + canonicalPool, ok := self.Pool[deposit.targetPoolPath] + if !ok { + panic("should not happen 2") + } + + // Emulating CollectReward() + reward := self.EmulatedRewardOf(tokenId) + self.emulatedClaimedReward[tokenId] += reward.Internal + + // update global state + // we will not gonna actually remove the deposit in sake of logic simplicity + self.global.deposits.Remove(tokenId) + + pool, ok := self.global.pools.Get(deposit.targetPoolPath) + if !ok { + panic("should not happen 3") + } + signedLiquidity := i256.FromUint256(deposit.liquidity) + signedLiquidity = signedLiquidity.Neg(signedLiquidity) + if self.isInRange(deposit) { + pool.modifyDeposit(signedLiquidity, currentHeight, canonicalPool.tick) + } + pool.ticks.Get(deposit.tickLower).modifyDepositLower(currentHeight, canonicalPool.tick, signedLiquidity) + pool.ticks.Get(deposit.tickUpper).modifyDepositUpper(currentHeight, canonicalPool.tick, signedLiquidity) +} + +func newExternalIncentiveByHeight( + targetPoolPath string, + rewardToken string, + rewardAmount uint64, + startTimestamp int64, + endTimestamp int64, + startHeight int64, + endHeight int64, + refundee std.Address, +) *ExternalIncentive { + rewardPerBlock := rewardAmount / uint64(endHeight-startHeight) + + byTimeId := incentiveIdByTime(startTimestamp, endTimestamp, refundee, rewardToken) + + return &ExternalIncentive{ + incentiveId: byTimeId, + targetPoolPath: targetPoolPath, + rewardToken: rewardToken, + rewardAmount: rewardAmount, + startTimestamp: startTimestamp, + endTimestamp: endTimestamp, + startHeight: startHeight, + endHeight: endHeight, + rewardPerBlock: rewardPerBlock, + refundee: refundee, + createdHeight: startHeight, + depositGnsAmount: 0, + } +} + +func (self *canonicalRewardState) CreateExternalIncentive(targetPoolPath string, rewardToken string, rewardAmount uint64, startTimestamp, endTimestamp, startHeight, endHeight int64, refundee std.Address) string { + incentive := newExternalIncentiveByHeight(targetPoolPath, rewardToken, rewardAmount, startTimestamp, endTimestamp, startHeight, endHeight, refundee) + + // update canonical state + pool, ok := self.Pool[targetPoolPath] + if !ok { + self.Pool[targetPoolPath] = &canonicalPool{ + poolPath: targetPoolPath, + tier: 0, + tick: 0, + incentive: make([]*ExternalIncentive, 0), + tickCrossHook: self.tickCrossHook, + } + } + pool.incentive = append(pool.incentive, incentive) + + // update global state + self.global.pools.GetOrCreate(targetPoolPath).incentives.create(refundee, incentive) + + return incentive.incentiveId +} + +func (self *canonicalRewardState) ChangePoolTier(poolPath string, tier uint64) { + // update canonical state + pool, ok := self.Pool[poolPath] + if !ok { + pool = &canonicalPool{ + poolPath: poolPath, + tier: tier, + tick: 0, + incentive: make([]*ExternalIncentive, 0), + tickCrossHook: self.tickCrossHook, + } + self.Pool[poolPath] = pool + } + pool.tier = tier + + // update global state + if !self.global.pools.Has(poolPath) { + self.global.pools.Set(poolPath, NewPool(poolPath, self.CurrentHeight())) + } + self.global.poolTier.changeTier(self.CurrentHeight(), self.global.pools, poolPath, tier) +} + +func (self *canonicalRewardState) CreatePool(poolPath string, initialTier uint64, initialTick int32) { + self.Pool[poolPath] = &canonicalPool{ + poolPath: poolPath, + tier: initialTier, + tick: initialTick, + incentive: make([]*ExternalIncentive, 0), + tickCrossHook: self.tickCrossHook, + } + self.global.pools.Set(poolPath, NewPool(poolPath, self.CurrentHeight())) + self.global.poolTier.changeTier(self.CurrentHeight(), self.global.pools, poolPath, initialTier) +} + +func (self *canonicalRewardState) MoveTick(poolPath string, tick int32) { + pool, ok := self.Pool[poolPath] + if !ok { + panic("should not happen 4") + } + globalPool, ok := self.global.pools.Get(poolPath) + if !ok { + panic("should not happen 5") + } + + if pool.tick == tick { + return + } + + self.t.Logf(" [%d] (%d->%d) %s", self.CurrentHeight(), pool.tick, tick, pool.poolPath) + + zeroForOne := tick < pool.tick // true if moving left, false if moving right + if zeroForOne { + // backward + for i := pool.tick; i > tick; i-- { + // uninitialized tick + if !globalPool.ticks.Has(i) { + continue + } + + // update global state + pool.tickCrossHook(pool.poolPath, i, zeroForOne) + } + } else { + // forward + for i := pool.tick + 1; i <= tick; i++ { + // uninitialized tick + if !globalPool.ticks.Has(i) { + continue + } + + // update global state + pool.tickCrossHook(pool.poolPath, i, zeroForOne) + } + } + + // update canonical state + pool.tick = tick +} + +// Testing helpers + +func (self *canonicalRewardState) AssertCanonicalInternalRewardPerPool(poolPath string, expected uint64) { + internalRewardPerPool := self.InternalRewardPerPool(self.PerBlockEmission) + actual := internalRewardPerPool[poolPath] + if actual != expected { + panic(ufmt.Sprintf("internal reward per pool mismatch: expected %d, got %d", expected, actual)) + } +} + +// returns true if actual is within 0.0001% of expected +func isInErrorRange(expected uint64, actual uint64) bool { + maxSafeValue := uint64(math.MaxUint64 / 100001) + var lowerBound, upperBound uint64 + if expected > maxSafeValue { + lowerBound = expected / 1000000 * 999999 + upperBound = expected / 1000000 * 1000001 + } else { + lowerBound = expected * 999999 / 1000000 + upperBound = expected * 1000001 / 1000000 + } + return actual >= lowerBound && actual <= upperBound +} + +func (self *canonicalRewardState) AssertEmulatedRewardOf(depositId uint64, expected uint64) { + reward := self.EmulatedRewardOf(depositId) + if !isInErrorRange(expected, reward.Internal) { + self.t.Errorf("emulated reward of %d mismatch: expected %d, got %d", depositId, expected, reward.Internal) + panic("emulated reward mismatch") + } +} + +func (self *canonicalRewardState) AssertEmulatedExternalRewardOf(depositId uint64, incentiveId string, expected uint64) { + reward := self.EmulatedRewardOf(depositId) + if !isInErrorRange(expected, reward.External[incentiveId]) { + self.t.Errorf("!!!!emulated external reward of %d mismatch: expected %d, got %d", depositId, expected, reward.External[incentiveId]) + } +} + +func (self *canonicalRewardState) AssertCanonicalRewardOf(depositId uint64, expected uint64) { + reward := self.CanonicalRewardOf(depositId) + if !isInErrorRange(expected, reward.Internal) { + self.t.Errorf("canonical reward of %d mismatch: expected %d, got %d", depositId, expected, reward.Internal) + } +} + +func (self *canonicalRewardState) AssertEquivalence(depositId uint64) { + reward := self.CanonicalRewardOf(depositId) + emulatedReward := self.EmulatedRewardOf(depositId) + if !isInErrorRange(reward.Internal, emulatedReward.Internal) { + self.t.Errorf("canonical reward of %d mismatch: expected %d, got %d", depositId, reward.Internal, emulatedReward.Internal) + } +} + +func (self *canonicalRewardState) AssertEmulatedRewardMap(expected map[uint64]uint64) { + for key, value := range expected { + self.AssertEmulatedRewardOf(key, value) + } +} diff --git a/contract/r/gnoswap/staker/reward_calculation_canonical_test.gno b/contract/r/gnoswap/staker/reward_calculation_canonical_test.gno new file mode 100644 index 000000000..ade109f72 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_canonical_test.gno @@ -0,0 +1,1550 @@ +package staker + +// Evaluate against the canonical implementation + +import ( + "math" + "std" + "strings" + "testing" + + ufmt "gno.land/p/demo/ufmt" + + u256 "gno.land/p/gnoswap/uint256" +) + +func Setup(t *testing.T) *canonicalRewardState { + pools := NewPools() + deposits := NewDeposits() + + return NewCanonicalRewardState(t, pools, deposits, TickCrossHook) +} + +func TestCanonicalSimple(t *testing.T) { + canonical := Setup(t) + + gnousdc := test_gnousdc + canonical.CreatePool(gnousdc, 1, 150) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + canonical.NextBlock() + + // gnousdc takes the entire emission + canonical.AssertCanonicalInternalRewardPerPool(gnousdc, canonical.PerBlockEmission) + + expected := canonical.PerBlockEmission * 30 / 100 + canonical.AssertEmulatedRewardOf(0, expected) + canonical.AssertCanonicalRewardOf(0, expected) +} + +// To check precision error +func TestCanonicalLargeStakedLiquidity(t *testing.T) { + canonical := Setup(t) + + gnousdc := test_gnousdc + canonical.CreatePool(gnousdc, 1, 150) + + canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + q128, + ) + + canonical.NextBlock() + + canonical.AssertCanonicalInternalRewardPerPool(gnousdc, canonical.PerBlockEmission) + + expected := canonical.PerBlockEmission * 30 / 100 + canonical.AssertEmulatedRewardOf(0, expected) + canonical.AssertCanonicalRewardOf(0, expected) +} + +// To check precision error +func TestCanonicalLargeStakedLiquidity_2(t *testing.T) { + canonical := Setup(t) + + gnousdc := test_gnousdc + canonical.CreatePool(gnousdc, 1, 150) + + u2_30 := uint64(1073741824) + u2_33 := uint64(8589934592) + + // Estimated per-block emission for staker is estimated to not get more than 100 million, but we are stress testing. + // If more than 100 million is emitted, the overflow could occur inside of the PoolTier distribution logic. + canonical.SetEmissionUpdate(10000000000000) + + canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(u2_30), + ) + + canonical.StakeToken( + 1, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + 100, + 200, + u256.NewUint(u2_33), + ) + + canonical.NextBlock() + + canonical.AssertCanonicalInternalRewardPerPool(gnousdc, canonical.PerBlockEmission) + + expected := canonical.PerBlockEmission / 100 * 30 + canonical.AssertEmulatedRewardOf(0, expected/9) + canonical.AssertCanonicalRewardOf(0, expected/9) + + canonical.AssertEmulatedRewardOf(1, expected/9*8) + canonical.AssertCanonicalRewardOf(1, expected/9*8) +} + +// Tests simple case with tick crossing +func TestCanonicalTickCross_0(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 150) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + // The position remains inrange, no change in reward. + canonical.MoveTick(gnousdc, 120) + + canonical.NextBlock() + + // gnousdc takes the entire emission + canonical.AssertCanonicalInternalRewardPerPool(gnousdc, canonical.PerBlockEmission) + + expected := canonical.PerBlockEmission * 30 / 100 + canonical.AssertEmulatedRewardOf(0, expected) + canonical.AssertCanonicalRewardOf(0, expected) + + // The position is now outrange. The reward must be reduced. + canonical.MoveTick(gnousdc, 90) + + canonical.NextBlock() + + // gnousdc still takes the entire emission, independent from position inrangeness + canonical.AssertCanonicalInternalRewardPerPool(gnousdc, canonical.PerBlockEmission) + + canonical.AssertEmulatedRewardOf(0, 0) + canonical.AssertCanonicalRewardOf(0, 0) + + // The position remains outrange. + canonical.MoveTick(gnousdc, 80) + + canonical.NextBlock() + + // gnousdc still takes the entire emission, independent from position inrangeness + canonical.AssertCanonicalInternalRewardPerPool(gnousdc, canonical.PerBlockEmission) + + canonical.AssertEmulatedRewardOf(0, 0) + canonical.AssertCanonicalRewardOf(0, 0) + + // The tick passes through the positions range, remains outrange. + canonical.MoveTick(gnousdc, 220) + + canonical.NextBlock() + + // gnousdc still takes the entire emission, independent from position inrangeness + canonical.AssertCanonicalInternalRewardPerPool(gnousdc, canonical.PerBlockEmission) + + canonical.AssertEmulatedRewardOf(0, 0) + canonical.AssertCanonicalRewardOf(0, 0) + + // The tick goes back to the range. + canonical.MoveTick(gnousdc, 180) + + canonical.NextBlock() + + // gnousdc still takes the entire emission, independent from position inrangeness + canonical.AssertCanonicalInternalRewardPerPool(gnousdc, canonical.PerBlockEmission) + + // The position is now inrange, so it takes the entire emission + canonical.AssertEmulatedRewardOf(0, expected) + canonical.AssertCanonicalRewardOf(0, expected) +} + +// Tests tick crossing with lazy evaluation of position reward +func TestCanonicalTickCross_1(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 150) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + // The position remains inrange, no change in reward. + canonical.MoveTick(gnousdc, 120) + + canonical.NextBlockNoCanonical() + + // The position is now outrange. The reward must be reduced. + canonical.MoveTick(gnousdc, 90) + + canonical.NextBlockNoCanonical() + + // The position remains outrange. + canonical.MoveTick(gnousdc, 80) + + canonical.NextBlockNoCanonical() + + // The tick passes through the positions range, remains outrange. + canonical.MoveTick(gnousdc, 220) + + canonical.NextBlockNoCanonical() + + // The tick goes back to the range. + canonical.MoveTick(gnousdc, 180) + + canonical.NextBlockNoCanonical() + + // We check that emulated reward is lazy evaluated when we finally calculate it. + // It takes the two block's emission into account. + expected := canonical.PerBlockEmission * 30 / 100 + canonical.AssertEmulatedRewardOf(0, expected*2) +} + +// Test tick crossing with multiple positions with same tick, same liquidity +func TestCanonicalTickCross_2(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 150) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 1, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + // eligible + canonical.MoveTick(gnousdc, 120) + canonical.NextBlockNoCanonical() + + // not eligible + canonical.MoveTick(gnousdc, 90) + canonical.NextBlockNoCanonical() + + // not eligible + canonical.MoveTick(gnousdc, 80) + canonical.NextBlockNoCanonical() + + // not eligible + canonical.MoveTick(gnousdc, 220) + canonical.NextBlockNoCanonical() + + // eligible + canonical.MoveTick(gnousdc, 180) + canonical.NextBlockNoCanonical() + + // eligible + canonical.MoveTick(gnousdc, 120) + canonical.NextBlockNoCanonical() + + expected := canonical.PerBlockEmission * 30 / 100 + canonical.AssertEmulatedRewardOf(0, expected*3/2) + canonical.AssertEmulatedRewardOf(1, expected*3/2) +} + +// Test tick crossing with multiple positions with same tick, different liquidity +func TestCanonicalTickCross_3(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 150) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 1, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + 100, + 200, + u256.NewUint(3000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + // eligible + canonical.MoveTick(gnousdc, 120) + canonical.NextBlockNoCanonical() + + // not eligible + canonical.MoveTick(gnousdc, 90) + canonical.NextBlockNoCanonical() + + // not eligible + canonical.MoveTick(gnousdc, 80) + canonical.NextBlockNoCanonical() + + // not eligible + canonical.MoveTick(gnousdc, 220) + canonical.NextBlockNoCanonical() + + // eligible + canonical.MoveTick(gnousdc, 180) + canonical.NextBlockNoCanonical() + + // not eligible + canonical.MoveTick(gnousdc, 80) + canonical.NextBlockNoCanonical() + + expected := canonical.PerBlockEmission * 30 / 100 + canonical.AssertEmulatedRewardOf(0, expected*2*1/4) + canonical.AssertEmulatedRewardOf(1, expected*2*3/4) +} + +// Test tick crossing with multiple positions with different tick, same liquidity +func TestCanonicalTickCross_4(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 1, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + 200, + 400, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + expected := canonical.PerBlockEmission * 30 / 100 + + // 0 eligible, 1 not eligible (100:0) + canonical.MoveTick(gnousdc, 101) + canonical.NextBlock() + canonical.AssertCanonicalRewardOf(0, expected) + canonical.AssertCanonicalRewardOf(1, 0) + // 0 eligible, 1 not eligible (100:0) + canonical.MoveTick(gnousdc, 101) + canonical.NextBlock() + canonical.AssertCanonicalRewardOf(0, expected) + canonical.AssertCanonicalRewardOf(1, 0) + + // 0 eligible, 1 eligible (50:50) + canonical.MoveTick(gnousdc, 201) + canonical.NextBlock() + canonical.AssertCanonicalRewardOf(0, expected/2) + canonical.AssertCanonicalRewardOf(1, expected/2) + + // 0 not eligible, 1 not eligible (0:0) + canonical.MoveTick(gnousdc, 401) + canonical.NextBlock() + canonical.AssertCanonicalRewardOf(0, 0) + canonical.AssertCanonicalRewardOf(1, 0) + + // 0 not eligible, 1 eligible (0:100) + canonical.MoveTick(gnousdc, 301) + canonical.NextBlock() + canonical.AssertCanonicalRewardOf(0, 0) + canonical.AssertCanonicalRewardOf(1, expected) + + // 0 total ratio: 250 + // 1 total ratio: 150 + + canonical.AssertEmulatedRewardOf(0, expected*4*250/400) + canonical.AssertEmulatedRewardOf(1, expected*4*150/400) +} + +// Test tick crossing at tick boundaries, forward direction +func TestCanonicalTickCross_5(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + expected := canonical.PerBlockEmission * 30 / 100 + + // not eligible + // block 1 + canonical.MoveTick(gnousdc, 99) + canonical.NextBlock() + + // block 2 + canonical.AssertCanonicalRewardOf(0, 0) + canonical.AssertEmulatedRewardOf(0, 0) + + // eligible + // entered range + canonical.MoveTick(gnousdc, 100) + canonical.NextBlock() + + // block 3 + canonical.AssertCanonicalRewardOf(0, expected) + canonical.AssertEmulatedRewardOf(0, expected) + + // eligible + canonical.MoveTick(gnousdc, 101) + canonical.NextBlock() + + // block 4 + canonical.AssertCanonicalRewardOf(0, expected) + canonical.AssertEmulatedRewardOf(0, expected) + + // eligible + canonical.MoveTick(gnousdc, 299) + canonical.NextBlock() + + canonical.AssertCanonicalRewardOf(0, expected) + canonical.AssertEmulatedRewardOf(0, expected) + + // not eligible + canonical.MoveTick(gnousdc, 300) + canonical.NextBlock() + + canonical.AssertCanonicalRewardOf(0, 0) + canonical.AssertEmulatedRewardOf(0, 0) + + // not eligible + canonical.MoveTick(gnousdc, 301) + canonical.NextBlock() + + canonical.AssertCanonicalRewardOf(0, 0) + canonical.AssertEmulatedRewardOf(0, 0) +} + +// Test tick crossing at tick boundaries, backward direction +func TestCanonicalTickCross_6(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + expected := canonical.PerBlockEmission * 30 / 100 + + // not eligible + canonical.MoveTick(gnousdc, 301) + canonical.NextBlockNoCanonical() + + canonical.AssertEmulatedRewardOf(0, 0) + + // not eligible + canonical.MoveTick(gnousdc, 300) + canonical.NextBlockNoCanonical() + + canonical.AssertEmulatedRewardOf(0, 0) + + // eligible + canonical.MoveTick(gnousdc, 299) + canonical.NextBlockNoCanonical() + + canonical.AssertEmulatedRewardOf(0, expected) + + // eligible + canonical.MoveTick(gnousdc, 101) + canonical.NextBlockNoCanonical() + + canonical.AssertEmulatedRewardOf(0, expected) + + // eligible + canonical.MoveTick(gnousdc, 100) + canonical.NextBlockNoCanonical() + + canonical.AssertEmulatedRewardOf(0, expected) + + // not eligible + canonical.MoveTick(gnousdc, 99) + canonical.NextBlockNoCanonical() + + canonical.AssertEmulatedRewardOf(0, 0) +} + +func TestCanonicalExternalReward_1(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + currentHeight := int64(canonical.CurrentHeight()) + incentiveId := canonical.CreateExternalIncentive(gnousdc, gnsPath, 100000000, currentHeight+1, currentHeight+5, currentHeight+1, currentHeight+5, std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp")) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgg"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + canonical.NextBlock() + + // Incentive has been started, but but we can collect rewards accumulated until the previous block + + canonical.NextBlock() + + // now there is a single block worth of reward + + expected := uint64(100000000/4) * 30 / 100 + + canonical.AssertEmulatedExternalRewardOf(uint64(0), incentiveId, expected) +} + + +func TestCanonicalExternalReward_2(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + currentHeight := int64(canonical.CurrentHeight()) + incentiveId := canonical.CreateExternalIncentive(gnousdc, gnsPath, 100000000, currentHeight+1, currentHeight+5, currentHeight+1, currentHeight+5, std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp")) + + canonical.NextBlock() + canonical.NextBlock() + + // Incentive has been already started + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgg"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + canonical.NextBlock() + + // now there is a single block worth of reward + + expected := uint64(100000000/4) * 30 / 100 + + canonical.AssertEmulatedExternalRewardOf(uint64(0), incentiveId, expected) +} + +func TestCanonicalExternalReward_3(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + currentHeight := int64(canonical.CurrentHeight()) + incentiveId := canonical.CreateExternalIncentive(gnousdc, gnsPath, 100000000, currentHeight+1, currentHeight+5, currentHeight+1, currentHeight+5, std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp")) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgg"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + canonical.NextBlock() + + // eligible + + canonical.NextBlock() + + // eligible + + canonical.NextBlock() + + // not eligible + + canonical.MoveTick(gnousdc, 400) + + // not eligible + + canonical.NextBlock() + + // incentive has been ended + + expected := uint64(100000000/4*2) * 30 / 100 + + canonical.AssertEmulatedExternalRewardOf(uint64(0), incentiveId, expected) + + if canonical.UnclaimableExternalRewardOf(uint64(0), incentiveId) != uint64(100000000/4*2) { + t.Errorf("UnclaimableExternalRewardOf(uint64(0), incentiveId) = %d; want %d", canonical.UnclaimableExternalRewardOf(uint64(0), incentiveId), uint64(100000000/4*2)) + } +} + + +func TestCanonicalMultiplePool_1(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + gnousdc2 := GetPoolPath(wugnotPath, gnsPath, 10000) + canonical.CreatePool(gnousdc, 1, 200) + canonical.CreatePool(gnousdc2, 1, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 1, + gnousdc2, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + 200, + 400, + u256.NewUint(2000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + canonical.NextBlock() + + expected := canonical.PerBlockEmission * 30 / 100 + + canonical.AssertEmulatedRewardOf(0, expected/2) + canonical.AssertEmulatedRewardOf(1, expected/2) + + canonical.MoveTick(gnousdc, 100) + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected/2) + canonical.AssertEmulatedRewardOf(1, expected/2) + + canonical.MoveTick(gnousdc2, 100) + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected/2) + canonical.AssertEmulatedRewardOf(1, 0) + + canonical.MoveTick(gnousdc, 100) + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected/2) + canonical.AssertEmulatedRewardOf(1, 0) + + canonical.MoveTick(gnousdc, 300) + canonical.MoveTick(gnousdc2, 300) + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, 0) + canonical.AssertEmulatedRewardOf(1, expected/2) +} + +func TestCanonicalMultiplePool_2(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + gnousdc2 := GetPoolPath(wugnotPath, gnsPath, 10000) + gnousdc3 := GetPoolPath(wugnotPath, gnsPath, 30000) + canonical.CreatePool(gnousdc, 1, 200) + canonical.CreatePool(gnousdc2, 2, 200) + canonical.CreatePool(gnousdc3, 3, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 1, + gnousdc2, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + 200, + 400, + u256.NewUint(2000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 2, + gnousdc3, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgz"), + 300, + 500, + u256.NewUint(3000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + expected := canonical.PerBlockEmission * 30 / 100 + + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected*50/100) + canonical.AssertEmulatedRewardOf(1, expected*30/100) + canonical.AssertEmulatedRewardOf(2, 0) + canonical.MoveTick(gnousdc, 100) + canonical.MoveTick(gnousdc2, 100) + canonical.MoveTick(gnousdc3, 100) + + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected*50/100) + canonical.AssertEmulatedRewardOf(1, 0) + canonical.AssertEmulatedRewardOf(2, 0) + canonical.MoveTick(gnousdc, 400) + canonical.MoveTick(gnousdc2, 400) + canonical.MoveTick(gnousdc3, 400) + + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, 0) + canonical.AssertEmulatedRewardOf(1, 0) + canonical.AssertEmulatedRewardOf(2, expected*20/100) + canonical.ChangePoolTier(gnousdc2, 1) + canonical.MoveTick(gnousdc, 200) + canonical.MoveTick(gnousdc2, 300) + canonical.MoveTick(gnousdc3, 400) + + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected*40/100) + canonical.AssertEmulatedRewardOf(1, expected*40/100) + canonical.AssertEmulatedRewardOf(2, expected*20/100) + +} + +// Large number of blocks passed +func TestCanonicalLargeBlocksPassed(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 150) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + for i := 0; i < 10000; i++ { + canonical.NextBlockNoCanonical() + } + + expected := canonical.PerBlockEmission * 30 / 100 + canonical.AssertEmulatedRewardOf(0, expected*10000) +} + +func GetPoolPath(token0Path, token1Path string, fee uint32) string { + if strings.Compare(token1Path, token0Path) < 0 { + token0Path, token1Path = token1Path, token0Path + } + return ufmt.Sprintf("%s:%s:%d", token0Path, token1Path, fee) +} + +func TestCanonicalWarmup_1(t *testing.T) { + modifyWarmup(0, 10) + modifyWarmup(1, 10) + modifyWarmup(2, 10) + + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 150) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + expected := canonical.PerBlockEmission * 30 / 100 + + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected*10) + + expected = canonical.PerBlockEmission * 50 / 100 + + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected*10) + + expected = canonical.PerBlockEmission * 70 / 100 + + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected*10) + + expected = canonical.PerBlockEmission * 100 / 100 + + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.AssertEmulatedRewardOf(0, expected*10) + +} + +func TestCanonicalWarmup_2(t *testing.T) { + modifyWarmup(0, 10) + modifyWarmup(1, 10) + modifyWarmup(2, 10) + + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 150) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 200, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + expected0 := canonical.PerBlockEmission * 30 / 100 + + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + + expected1 := canonical.PerBlockEmission * 50 / 100 + + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + + expected2 := canonical.PerBlockEmission * 70 / 100 + + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + + expected3 := canonical.PerBlockEmission * 100 / 100 + + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + canonical.NextBlock() + + canonical.AssertEmulatedRewardOf(0, expected0*10+expected1*10+expected2*10+expected3*10) + +} + +// ================ +// randomized + +// Test tick crossing at tick boundaries, random direction +func TestCanonicalTickCross_7(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + expected := canonical.PerBlockEmission * 30 / 100 + + tcs := []struct { + tick int32 + expected uint64 + }{ + {99, 0}, + {100, expected}, + {101, expected}, + {299, expected}, + {300, 0}, + {301, 0}, + } + + index := 0 + for i := 0; i < 100; i++ { + index = (index + 10007) % len(tcs) + canonical.MoveTick(gnousdc, tcs[index].tick) + canonical.NextBlockNoCanonical() + canonical.AssertEmulatedRewardOf(0, tcs[index].expected) + } +} + +// Test tick crossing with multiple positions with different tick, different liquidity +// Equivalence test +func TestCanonicalTickCross_8(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 1, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + 200, + 400, + u256.NewUint(2000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 2, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgz"), + 300, + 500, + u256.NewUint(3000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + tick := int32(0) + + for i := 0; i < 500; i++ { + tick = (tick + 10007) % 700 + canonical.MoveTick(gnousdc, tick) + canonical.NextBlock() + canonical.AssertEquivalence(0) + canonical.AssertEquivalence(1) + canonical.AssertEquivalence(2) + } +} + +// Test tick crossing with multiple positions with different tick, different liquidity, emulated reward flushed for every 100 blocks +// Equivalence test +func TestCanonicalTickCross_9(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 1, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + 200, + 400, + u256.NewUint(2000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 2, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgz"), + 300, + 500, + u256.NewUint(3000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + tick := int32(0) + for i := 0; i < 20; i++ { + canonicalRewardMap := make(map[uint64]uint64) + for j := 0; j < 20; j++ { + tick = (tick + 10007) % 700 + canonical.MoveTick(gnousdc, tick) + canonical.NextBlock() + canonicalRewardMap[0] += canonical.CanonicalRewardOf(0).Internal + canonicalRewardMap[1] += canonical.CanonicalRewardOf(1).Internal + canonicalRewardMap[2] += canonical.CanonicalRewardOf(2).Internal + } + canonical.AssertEmulatedRewardMap(canonicalRewardMap) + } +} + + +// Test tick crossing with multiple positions with different tick, different liquidity, emulated reward flushed for every 100 blocks +// Equivalence test +func TestCanonicalTickCross_10(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, -200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + -300, + -100, + u256.NewUint(1000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 1, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgq"), + -400, + -200, + u256.NewUint(2000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + err = canonical.StakeToken( + 2, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgz"), + -500, + -300, + u256.NewUint(3000000000000000000), + ) + + if err != nil { + t.Errorf("StakeToken failed: %s", err.Error()) + } + + tick := int32(0) + for i := 0; i < 20; i++ { + canonicalRewardMap := make(map[uint64]uint64) + for j := 0; j < 20; j++ { + tick = (tick + 10007) % 700 + canonical.MoveTick(gnousdc, -tick) + canonical.NextBlock() + canonicalRewardMap[0] += canonical.CanonicalRewardOf(0).Internal + canonicalRewardMap[1] += canonical.CanonicalRewardOf(1).Internal + canonicalRewardMap[2] += canonical.CanonicalRewardOf(2).Internal + } + canonical.AssertEmulatedRewardMap(canonicalRewardMap) + } +} + +func TestCanonicalEmissionChange_0(t *testing.T) { + canonical := Setup(t) + + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + err := canonical.StakeToken( + 0, + gnousdc, + std.Address("gno1qyqszqgpqyqszqgpqyqszqgpqyqszqgp"), + 100, + 300, + u256.NewUint(1000000000000000000), + ) + + canonical.NextBlock() + + expected0 := canonical.PerBlockEmission * 30 / 100 + + t.Logf("expected0: %d", expected0) + + canonical.SetEmissionUpdate(canonical.PerBlockEmission * 10 / 100) + + canonical.NextBlock() + + expected1 := canonical.PerBlockEmission * 30 / 100 + + t.Logf("expected1: %d", expected1) + + canonical.AssertEmulatedRewardOf(0, expected0+expected1) +} + +type TestEventCreatePool struct { + poolPath string + initialTier uint64 + initialTick int32 +} + +func (event *TestEventCreatePool) Apply(canonical *canonicalRewardState) { + canonical.CreatePool(event.poolPath, event.initialTier, event.initialTick) +} + +func (event *TestEventCreatePool) IsValid(canonical *canonicalRewardState) bool { + _, ok := canonical.global.pools.Get(event.poolPath) + return !ok +} + +func (event *TestEventCreatePool) String() string { + return ufmt.Sprintf("CreatePool(%s, %d, %d)", event.poolPath, event.initialTier, event.initialTick) +} + +type TestEventChangeTier struct { + poolPath string + targetTier uint64 +} + +func (event *TestEventChangeTier) IsValid(canonical *canonicalRewardState) bool { + _, ok := canonical.global.pools.Get(event.poolPath) + return ok +} + +func (event *TestEventChangeTier) Apply(canonical *canonicalRewardState) { + canonical.ChangePoolTier(event.poolPath, event.targetTier) +} + +func (event *TestEventChangeTier) String() string { + return ufmt.Sprintf("ChangeTier(%s, %d)", event.poolPath, event.targetTier) +} + +type TestEventStakeToken struct { + tokenId uint64 + poolPath string + address std.Address + liquidity *u256.Uint + tickLower int32 + tickUpper int32 +} + +func (event *TestEventStakeToken) Apply(canonical *canonicalRewardState) { + if canonical.global.deposits.Has(event.tokenId) { + if canonical.global.deposits.Get(event.tokenId).liquidity.IsZero() { + canonical.StakeToken(event.tokenId, event.poolPath, event.address, event.tickLower, event.tickUpper, event.liquidity) + } else { + canonical.UnstakeToken(event.tokenId) + } + } else { + canonical.StakeToken(event.tokenId, event.poolPath, event.address, event.tickLower, event.tickUpper, event.liquidity) + } +} + +func (event *TestEventStakeToken) IsValid(canonical *canonicalRewardState) bool { + _, ok := canonical.global.pools.Get(event.poolPath) + return ok +} + +func (event *TestEventStakeToken) String() string { + return ufmt.Sprintf("StakeToken(%d, %s, %s, %d, %d, %s)", event.tokenId, event.poolPath, event.address, event.tickLower, event.tickUpper, event.liquidity.ToString()) +} + +type TestEventMoveTick struct { + poolPath string + tick int32 +} + +func (event *TestEventMoveTick) Apply(canonical *canonicalRewardState) { + canonical.MoveTick(event.poolPath, event.tick) +} + +func (event *TestEventMoveTick) IsValid(canonical *canonicalRewardState) bool { + _, ok := canonical.global.pools.Get(event.poolPath) + return ok +} + +func (event *TestEventMoveTick) String() string { + return ufmt.Sprintf("MoveTick(%s, %d)", event.poolPath, event.tick) +} + +type TestEventSetEmissionUpdate struct { + emission uint64 +} + +func (event *TestEventSetEmissionUpdate) Apply(canonical *canonicalRewardState) { + canonical.SetEmissionUpdate(event.emission) +} + +func (event *TestEventSetEmissionUpdate) IsValid(canonical *canonicalRewardState) bool { + return true +} + +func (event *TestEventSetEmissionUpdate) String() string { + return ufmt.Sprintf("SetEmissionUpdate(%d)", event.emission) +} + +type SimulationEvent interface { + IsValid(canonical *canonicalRewardState) bool + Apply(canonical *canonicalRewardState) + String() string +} + +type tokenStake struct { + tokenId uint64 + address std.Address + liquidity *u256.Uint + tickLower int32 + tickUpper int32 +} + +func events( + poolPath string, + initialTier uint64, + initialTick int32, + tokenStakes []tokenStake, + moveTicks []int32, +) []SimulationEvent { + result := []SimulationEvent{ + &TestEventCreatePool{ + poolPath: poolPath, + initialTier: initialTier, + initialTick: initialTick, + }, + } + for i := 0; i < len(tokenStakes); i++ { + result = append(result, &TestEventStakeToken{ + tokenId: tokenStakes[i].tokenId, + poolPath: poolPath, + address: tokenStakes[i].address, + liquidity: tokenStakes[i].liquidity, + tickLower: tokenStakes[i].tickLower, + tickUpper: tokenStakes[i].tickUpper, + }) + } + for i := 0; i < len(moveTicks); i++ { + result = append(result, &TestEventMoveTick{ + poolPath: poolPath, + tick: moveTicks[i], + }) + } + if poolPath == "gno.land/r/demo/wugnot:gno.land/r/gnoswap/v1/gns:3000" { + return result + } + for i := 0; i < 4; i++ { + result = append(result, &TestEventChangeTier{ + poolPath: poolPath, + targetTier: uint64(i), + }) + } + return result +} +func TestCanonicalSimulation_0(t *testing.T) { + wugnotgns := events( + "gno.land/r/demo/wugnot:gno.land/r/gnoswap/v1/gns:3000", + 1, + 200, + []tokenStake{ + {0, std.Address("gno1token0"), u256.NewUint(1000000000000000000), 100, 300}, + {1, std.Address("gno1token1"), u256.NewUint(2000000000000000000), 200, 400}, + {2, std.Address("gno1token2"), u256.NewUint(3000000000000000000), 300, 500}, + }, + []int32{50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750}, + ) + + barbaz := events( + "gno.land/r/demo/bar:gno.land/r/gnoswap/v1/baz:3000", + 0, + 200, + []tokenStake{ + {3, std.Address("gno1token3"), u256.NewUint(4000000000000000000), 50, 100}, + {4, std.Address("gno1token4"), u256.NewUint(5000000000000000000), 100, 200}, + {5, std.Address("gno1token5"), u256.NewUint(6000000000000000000), 300, 500}, + }, + []int32{50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750}, + ) + + bazquux := events( + "gno.land/r/demo/baz:gno.land/r/gnoswap/v1/quux:3000", + 0, + 200, + []tokenStake{ + {6, std.Address("gno1token6"), u256.NewUint(7000000000000000000), 400, 600}, + {7, std.Address("gno1token7"), u256.NewUint(8000000000000000000), 600, 800}, + {8, std.Address("gno1token8"), u256.NewUint(9000000000000000000), 700, 900}, + }, + []int32{50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750}, + ) + + quuxgns := events( + "gno.land/r/demo/quux:gno.land/r/gnoswap/v1/gns:3000", + 0, + 200, + []tokenStake{ + {9, std.Address("gno1token9"), u256.NewUint(1000000000000000000), 50, 100}, + {10, std.Address("gno1token10"), u256.NewUint(2000000000000000000), 100, 200}, + {11, std.Address("gno1token11"), u256.NewUint(3000000000000000000), 200, 300}, + }, + []int32{50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750}, + ) + + events := []SimulationEvent{} + events = append(events, wugnotgns...) + events = append(events, barbaz...) + events = append(events, bazquux...) + events = append(events, quuxgns...) + + canonical := Setup(t) + + // required to match with the default tier 1 + gnousdc := GetPoolPath(wugnotPath, gnsPath, 3000) + canonical.CreatePool(gnousdc, 1, 200) + + eventId := 0 + + for i := 0; i < 10; i++ { + println("=======ITERATION : ", i, "========") + + canonicalRewardMap := make(map[uint64]uint64) + for j := 0; j < 100; j++ { + var event SimulationEvent + for { + eventId = (eventId + 17) % len(events) + event = events[eventId] + if event.IsValid(canonical) { + break + } + } + println("=======HEIGHT : ", std.GetHeight(), "========") + println(" ", eventId, " : ", event.String()) + event.Apply(canonical) + canonical.NextBlock() + for i := 0; i < 12; i++ { + reward, ok := canonical.SafeCanonicalRewardOf(uint64(i)) + if !ok { + continue // no reward for this deposit + } + canonicalRewardMap[uint64(i)] += reward.Internal + } + } + + println("======ASSERTING========") + for k, v := range canonicalRewardMap { + println(" ", k, " : ", v) + } + canonical.AssertEmulatedRewardMap(canonicalRewardMap) + } +} \ No newline at end of file diff --git a/contract/r/gnoswap/staker/reward_calculation_incentives.gno b/contract/r/gnoswap/staker/reward_calculation_incentives.gno new file mode 100644 index 000000000..5dded0ad4 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_incentives.gno @@ -0,0 +1,188 @@ +package staker + +import ( + "std" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" +) + +// Incentives is a collection of external incentives for a given pool. +// +// Fields: +// - byTime: ExternalIncentive primarily indexed by startTime +// - byHeight: ExternalIncentive primarily indexed by startHeight +// - byEndHeight: ExternalIncentive primarily indexed by endHeight +// - byCreator: ExternalIncentive primarily indexed by creator +// +// - unclaimablePeriods: +// For each unclaimable period(start, end) for this pool, +// it stores (key: start) => (value: end) +// if end is 0, it means the unclaimable period is ongoing. +type Incentives struct { + byTime *avl.Tree // (startTime, endTime, creator, rewardToken) => ExternalIncentive + byHeight *avl.Tree // (startHeight, endHeight, creator, rewardToken) => ExternalIncentive + byEndHeight *avl.Tree // (endHeight, startHeight, creator, rewardToken) => ExternalIncentive + byCreator *avl.Tree // (creator, startHeight, endHeight, rewardToken) => ExternalIncentive + + unclaimablePeriods *UintTree // startHeight => endHeight +} + +// NewIncentives creates a new Incentives instance. +func NewIncentives() Incentives { + result := Incentives{ + byTime: avl.NewTree(), + byHeight: avl.NewTree(), + byEndHeight: avl.NewTree(), + byCreator: avl.NewTree(), + unclaimablePeriods: NewUintTree(), + } + + // initial unclaimable period starts, as there cannot be any staked positions yet. + result.unclaimablePeriods.Set(std.GetHeight(), int64(0)) + return result +} + +// Get incentive by time +func (self *Incentives) Get(startTime, endTime int64, creator std.Address, rewardToken string) (*ExternalIncentive, bool) { + byTimeId := incentiveIdByTime(startTime, endTime, creator, rewardToken) + value, ok := self.byTime.Get(byTimeId) + if !ok { + return nil, false + } + return value.(*ExternalIncentive), true +} + +// Get incentive by full incentiveId(by time) +func (self *Incentives) GetByIncentiveId(incentiveId string) (*ExternalIncentive, bool) { + value, ok := self.byTime.Get(incentiveId) + if !ok { + return nil, false + } + return value.(*ExternalIncentive), true +} + +// Get all incentives that is active in given [startHeight, endHeight) +func (self *Incentives) GetAllInHeights(startHeight, endHeight int64) map[string]*ExternalIncentive { + incentives := make(map[string]*ExternalIncentive) + // Iterate all incentives that has start height less than endHeight + self.byHeight.ReverseIterate( + "", + EncodeUint(uint64(endHeight)), + func(key string, value interface{}) bool { + incentive := value.(*ExternalIncentive) + if incentive.endHeight < startHeight { + // incentive is already ended + return false + } + + incentives[incentive.incentiveId] = incentive + return false + }, + ) + // Iterate all incentives that has end height greater than startHeight + self.byEndHeight.Iterate( + EncodeUint(uint64(startHeight)), + "", + func(key string, value interface{}) bool { + incentive := value.(*ExternalIncentive) + if incentive.startHeight > endHeight { + // incentive is not started yet + return false + } + + incentives[incentive.incentiveId] = incentive + return false + }, + ) + return incentives +} + +// Create a new external incentive +// Panics if the incentive already exists. +// Sets to byTime, byHeight, byEndHeight, byCreator +func (self *Incentives) create( + creator std.Address, + incentive *ExternalIncentive, +) { + byTimeId := incentiveIdByTime(incentive.startTimestamp, incentive.endTimestamp, creator, incentive.rewardToken) + if self.byTime.Has(byTimeId) { + panic(addDetailToError( + errIncentiveAlreadyExists, + ufmt.Sprintf("staker.gno__addExternalIncentive() || incentiveId(%s) already exists", byTimeId), + )) + } + + byHeightId, byCreatorId := incentiveIdByHeight(incentive.startHeight, incentive.endHeight, creator, incentive.rewardToken) + byEndHeightId, _ := incentiveIdByHeight(incentive.endHeight, incentive.startHeight, creator, incentive.rewardToken) + + self.byTime.Set(byTimeId, incentive) + self.byHeight.Set(byHeightId, incentive) + self.byEndHeight.Set(byEndHeightId, incentive) + self.byCreator.Set(byCreatorId, incentive) +} + +// starts incentive unclaimable period for this pool +func (self *Incentives) startUnclaimablePeriod(startHeight int64) { + self.unclaimablePeriods.Set(startHeight, int64(0)) +} + +// ends incentive unclaimable period for this pool +// ignores if currently not in unclaimable period +func (self *Incentives) endUnclaimablePeriod(endHeight int64) { + startHeight := int64(0) + self.unclaimablePeriods.ReverseIterate(0, endHeight, func(key int64, value interface{}) bool { + if value.(int64) != 0 { + // Already ended, no need to update + // keeping startHeight as 0 to indicate this + return true + } + startHeight = key + return true + }) + + if startHeight == 0 { + // No ongoing unclaimable period found + return + } + + if startHeight == endHeight { + self.unclaimablePeriods.Remove(startHeight) + } else { + self.unclaimablePeriods.Set(startHeight, endHeight) + } +} + +// calculate unclaimable reward by checking unclaimable periods +func (self *Incentives) calculateUnclaimableReward(incentiveId string) uint64 { + incentive, ok := self.GetByIncentiveId(incentiveId) + if !ok { + return 0 + } + + blocks := int64(0) + + self.unclaimablePeriods.ReverseIterate(0, incentive.startHeight, func(key int64, value interface{}) bool { + endHeight := value.(int64) + if endHeight == 0 { + endHeight = incentive.endHeight + } + if endHeight <= incentive.startHeight { + return true + } + blocks += endHeight - incentive.startHeight + return true + }) + + self.unclaimablePeriods.Iterate(incentive.startHeight, incentive.endHeight, func(key int64, value interface{}) bool { + startHeight := key + endHeight := value.(int64) + if endHeight == 0 { + endHeight = incentive.endHeight + } + blocks += endHeight - startHeight + return false + }) + + return uint64(blocks) * incentive.rewardPerBlock +} diff --git a/contract/r/gnoswap/staker/reward_calculation_pool.gno b/contract/r/gnoswap/staker/reward_calculation_pool.gno new file mode 100644 index 000000000..cbf48c58c --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_pool.gno @@ -0,0 +1,475 @@ +package staker + +import ( + "std" + + "gno.land/p/demo/avl" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/gnoswap/consts" +) + +var ( + // Q128 is 2^128 + q128 = u256.MustFromDecimal(consts.Q128) + // Q192 is 2^192 + q192 = u256.MustFromDecimal("6277101735386680763835789423207666416102355444464034512895") + + // pools is the global pool storage + pools *Pools +) + +func init() { + pools = NewPools() +} + +// Pools represents the global pool storage +type Pools struct { + tree *avl.Tree // string poolPath -> pool +} + +func NewPools() *Pools { + return &Pools{ + tree: avl.NewTree(), + } +} + +// Get returns the pool for the given poolPath +func (self *Pools) Get(poolPath string) (*Pool, bool) { + v, ok := self.tree.Get(poolPath) + if !ok { + return nil, false + } + return v.(*Pool), true +} + +// GetOrCreate returns the pool for the given poolPath, or creates a new pool if it does not exist +func (self *Pools) GetOrCreate(poolPath string) *Pool { + pool, ok := self.Get(poolPath) + if !ok { + pool = NewPool(poolPath, std.GetHeight()) + self.Set(poolPath, pool) + } + return pool +} + +// Set sets the pool for the given poolPath +func (self *Pools) Set(poolPath string, pool *Pool) { + self.tree.Set(poolPath, pool) +} + +// Has returns true if the pool exists for the given poolPath +func (self *Pools) Has(poolPath string) bool { + return self.tree.Has(poolPath) +} + +func (self *Pools) IterateAll(fn func(key string, pool *Pool) bool) { + self.tree.Iterate("", "", func(key string, value interface{}) bool { + return fn(key, value.(*Pool)) + }) +} + +// Pool is a struct for storing an incentivized pool information +// Each pool stores Incentives and Ticks associated with it. +// +// Fields: +// - poolPath: The path of the pool. +// +// - currentStakedLiquidity: +// The current total staked liquidity of the in-range positions for the pool. +// Updated when tick cross happens or stake/unstake happens. +// Used to calculate the global reward ratio accumulation or +// decide whether to enter/exit unclaimable period. +// +// - lastUnclaimableHeight: +// The height at which the unclaimable period started. +// Set to 0 when the pool is not in an unclaimable period. +// +// - unclaimableAcc: +// The accumulated undisributed unclaimable reward. +// Reset to 0 when processUnclaimableReward is called and sent to community pool. +// +// - rewardCache: +// The cached per-block reward emitted for this pool. +// Stores new entry only when the reward is changed. +// PoolTier.cacheReward() updates this. +// +// - incentives: The external incentives associated with the pool. +// +// - ticks: The Ticks associated with the pool. +// +// - globalRewardRatioAccumulation: +// Global ratio of BlockNumber / TotalStake accumulation(since the pool creation) +// Stores new entry only when tick cross or stake/unstake happens. +// It is used to calculate the reward for a staked position at certain height. +// +// - historicalTick: +// The historical tick for the pool at a given height. +// It does not reflect the exact tick at the blockNumber, +// but it provides correct ordering for the staked position's ticks. +// Therefore, you should not compare it for equality, only for ordering. +// Set when tick cross happens or a new position is created. +type Pool struct { + poolPath string + + stakedLiquidity *UintTree // uint64 blockNumber -> *u256.Uint(Q128) + + lastUnclaimableHeight int64 + unclaimableAcc uint64 + + rewardCache *UintTree // uint64 blockNumber -> uint64 gnsReward + + incentives Incentives + + ticks Ticks // int32 tickId -> Tick tick + + globalRewardRatioAccumulation *UintTree // uint64 blockNumber -> *u256.Uint(Q128) rewardRatioAccumulation + + historicalTick *UintTree // uint64 blockNumber -> int32 tickId +} + +// NewPool creates a new pool with the given poolPath and currentHeight. +func NewPool(poolPath string, currentHeight int64) *Pool { + pool := &Pool{ + poolPath: poolPath, + stakedLiquidity: NewUintTree(), + lastUnclaimableHeight: currentHeight, + unclaimableAcc: 0, + rewardCache: NewUintTree(), + incentives: NewIncentives(), + ticks: NewTicks(), + globalRewardRatioAccumulation: NewUintTree(), + historicalTick: NewUintTree(), + } + + pool.globalRewardRatioAccumulation.Set(currentHeight, u256.Zero()) + pool.rewardCache.Set(currentHeight, uint64(0)) + pool.stakedLiquidity.Set(currentHeight, u256.Zero()) + + return pool +} + +// Get the latest global reward ratio accumulation in [0, currentHeight] range. +// Returns the height and the accumulation. +func (self *Pool) CurrentGlobalRewardRatioAccumulation(currentHeight int64) (int64, *u256.Uint) { + var height int64 + var acc *u256.Uint + self.globalRewardRatioAccumulation.ReverseIterate(0, currentHeight, func(key int64, value interface{}) bool { + height = key + acc = value.(*u256.Uint) + return true + }) + if acc == nil { + panic("should not happen, globalRewardRatioAccumulation must be set when pool is created") + } + return height, acc +} + +// Get the latest tick in [0, currentHeight] range. +// Returns the tick. +func (self *Pool) CurrentTick(currentHeight int64) int32 { + var tick int32 + self.historicalTick.ReverseIterate(0, currentHeight, func(key int64, value interface{}) bool { + tick = value.(int32) + return true + }) + return tick +} + +func (self *Pool) CurrentStakedLiquidity(currentHeight int64) *u256.Uint { + var liquidity *u256.Uint + self.stakedLiquidity.ReverseIterate(0, currentHeight, func(key int64, value interface{}) bool { + liquidity = value.(*u256.Uint) + return true + }) + return liquidity +} + +// IsExternallyIncentivizedPool returns true if the pool has any external incentives. +func (self *Pool) IsExternallyIncentivizedPool() bool { + return self.incentives.byTime.Size() != 0 +} + +// Get the latest reward in [0, currentHeight] range. +// Returns the reward. +func (self *Pool) CurrentReward(currentHeight int64) uint64 { + var reward uint64 + self.rewardCache.ReverseIterate(0, currentHeight, func(key int64, value interface{}) bool { + reward = value.(uint64) + return true + }) + return reward +} + +// cacheReward sets the current reward for the pool +// If the pool is in unclaimable period, it will end the unclaimable period, updates the reward, and start the unclaimable period again. +func (self *Pool) cacheReward(currentHeight int64, currentTierReward uint64) { + oldTierReward := self.CurrentReward(currentHeight) + if oldTierReward == currentTierReward { + return + } + + isInUnclaimable := self.CurrentStakedLiquidity(currentHeight).IsZero() + if isInUnclaimable { + self.endUnclaimablePeriod(currentHeight) + } + + self.rewardCache.Set(currentHeight, currentTierReward) + + if isInUnclaimable { + self.startUnclaimablePeriod(currentHeight) + } +} + +// cacheInternalReward caches the current emission and updates the global reward ratio accumulation. +func (self *Pool) cacheInternalReward(currentHeight int64, currentEmission uint64) { + self.cacheReward(currentHeight, currentEmission) + + currentStakedLiquidity := self.CurrentStakedLiquidity(currentHeight) + if currentStakedLiquidity.IsZero() { + self.endUnclaimablePeriod(currentHeight) + self.startUnclaimablePeriod(currentHeight) + } + + self.updateGlobalRewardRatioAccumulation(currentHeight, currentStakedLiquidity) +} + +func (self *Pool) calculateGlobalRewardRatioAccumulation(currentHeight int64, currentStakedLiquidity *u256.Uint) *u256.Uint { + oldAccHeight, oldAcc := self.CurrentGlobalRewardRatioAccumulation(currentHeight) + blockDiff := currentHeight - oldAccHeight + if blockDiff == 0 { + return oldAcc.Clone() + } + + acc := u256.NewUint(uint64(blockDiff)) + acc = acc.Mul(acc, q128) + acc = acc.Div(acc, currentStakedLiquidity) + + return u256.Zero().Add(oldAcc, acc) +} + +// updateGlobalRewardRatioAccumulation updates the global reward ratio accumulation and returns the new accumulation. +func (self *Pool) updateGlobalRewardRatioAccumulation(currentHeight int64, currentStakedLiquidity *u256.Uint) *u256.Uint { + newAcc := self.calculateGlobalRewardRatioAccumulation(currentHeight, currentStakedLiquidity) + + self.globalRewardRatioAccumulation.Set(currentHeight, newAcc) + return newAcc +} + +// RewardState is a struct for storing the intermediate state for reward calculation. +type RewardState struct { + pool *Pool + deposit *Deposit + + // accumulated rewards for each warmup + rewards []uint64 + penalties []uint64 +} + +// RewardStateOf initializes a new RewardState for the given deposit. +func (self *Pool) RewardStateOf(deposit *Deposit) *RewardState { + result := &RewardState{ + pool: self, + deposit: deposit, + rewards: make([]uint64, len(deposit.warmups)), + penalties: make([]uint64, len(deposit.warmups)), + } + + for i := range result.rewards { + result.rewards[i] = 0 + result.penalties[i] = 0 + } + + return result +} + +// CalculateInternalReward calculates the internal reward for the deposit. +// It calls RewardPerWarmup for each rewardCache interval, applies warmup, and returns the rewards and penalties. +func (self *RewardState) CalculateInternalReward(startHeight, endHeight int64) ([]uint64, []uint64) { + currentReward := self.pool.CurrentReward(startHeight) + self.pool.rewardCache.Iterate(startHeight, endHeight, func(key int64, value interface{}) bool { + // we calculate per-position reward + self.RewardPerWarmup(startHeight, int64(key), currentReward) + currentReward = value.(uint64) + startHeight = int64(key) + return false + }) + + if startHeight < endHeight { + self.RewardPerWarmup(startHeight, endHeight, currentReward) + } + + self.ApplyWarmup() + + return self.rewards, self.penalties +} + +// CalculateExternalReward calculates the external reward for the deposit. +// It calls RewardPerWarmup for startHeight to endHeight(clamped to the incentive period), applies warmup and returns the rewards and penalties. +func (self *RewardState) CalculateExternalReward(startHeight, endHeight int64, incentive *ExternalIncentive) ([]uint64, []uint64) { + if startHeight < int64(self.deposit.lastCollectHeight) { + // This must not happen, but adding some guards just in case. + startHeight = int64(self.deposit.lastCollectHeight) + } + + if endHeight < incentive.startHeight { + return nil, nil // Not started yet + } + + if startHeight < incentive.startHeight { + startHeight = incentive.startHeight + } + + if endHeight > incentive.endHeight { + endHeight = incentive.endHeight + } + + if startHeight > incentive.endHeight { + return nil, nil // Already ended + } + + rewardPerBlock := incentive.rewardPerBlock + + self.RewardPerWarmup(startHeight, endHeight, rewardPerBlock) + + self.ApplyWarmup() + + return self.rewards, self.penalties +} + +// ApplyWarmup applies the warmup to the rewards and calculate penalties. +func (self *RewardState) ApplyWarmup() { + for i, warmup := range self.deposit.warmups { + refactorReward := self.rewards[i] + self.rewards[i] = refactorReward * warmup.WarmupRatio / 100 + self.penalties[i] = refactorReward - self.rewards[i] + } +} + +// RewardPerWarmup calculates the reward for each warmup, adds to the RewardState's rewards array. +func (self *RewardState) RewardPerWarmup(startHeight, endHeight int64, rewardPerBlock uint64) { + for i, warmup := range self.deposit.warmups { + if startHeight >= warmup.NextWarmupHeight { + // passed the warmup + continue + } + + if endHeight < warmup.NextWarmupHeight { + rewardAcc := self.pool.CalculateRewardForPosition(startHeight, self.pool.CurrentTick(startHeight), endHeight, self.pool.CurrentTick(endHeight), self.deposit) + + rewardAcc = rewardAcc.Mul(rewardAcc, self.deposit.liquidity) + rewardAcc = rewardAcc.Mul(rewardAcc, u256.NewUint(rewardPerBlock)) + rewardAcc = rewardAcc.Div(rewardAcc, q128) + self.rewards[i] += rewardAcc.Uint64() + + // done + break + } + + rewardAcc := self.pool.CalculateRewardForPosition(startHeight, self.pool.CurrentTick(startHeight), warmup.NextWarmupHeight, self.pool.CurrentTick(warmup.NextWarmupHeight), self.deposit) + + rewardAcc = rewardAcc.Mul(rewardAcc, self.deposit.liquidity) + rewardAcc = rewardAcc.Mul(rewardAcc, u256.NewUint(rewardPerBlock)) + rewardAcc = rewardAcc.Div(rewardAcc, q128) + self.rewards[i] += rewardAcc.Uint64() + + startHeight = warmup.NextWarmupHeight + } +} + +// modifyDeposit updates the pool's staked liquidity and returns the new staked liquidity. +// updates when there is a change in the staked liquidity(tick cross, stake, unstake) +func (self *Pool) modifyDeposit(delta *i256.Int, currentHeight int64, nextTick int32) *u256.Uint { + // update staker side pool info + lastStakedLiquidity := self.CurrentStakedLiquidity(currentHeight) + deltaApplied := liquidityMathAddDelta(lastStakedLiquidity, delta) + result := self.updateGlobalRewardRatioAccumulation(currentHeight, lastStakedLiquidity) + + // historical tick does NOT actually reflect the tick at the blockNumber, but it provides correct ordering for the staked positions + // because TickCrossHook is assured to be called for the staked-initialized ticks + self.historicalTick.Set(currentHeight, nextTick) + + switch deltaApplied.Sign() { + case -1: + panic("stakedLiquidity is less than 0, should not happen") + case 0: + if lastStakedLiquidity.Sign() == 1 { + // StakedLiquidity moved from positive to zero, start unclaimable period + self.startUnclaimablePeriod(currentHeight) + self.incentives.startUnclaimablePeriod(currentHeight) + } + case 1: + if lastStakedLiquidity.Sign() == 0 { + // StakedLiquidity moved from zero to positive, end unclaimable period + self.endUnclaimablePeriod(currentHeight) + self.incentives.endUnclaimablePeriod(currentHeight) + } + } + + self.stakedLiquidity.Set(currentHeight, deltaApplied) + + return result +} + +// startUnclaimablePeriod starts the unclaimable period. +func (self *Pool) startUnclaimablePeriod(currentHeight int64) { + if self.lastUnclaimableHeight == 0 { + // We set only if it's the first time entering(0 indicates not set yet) + self.lastUnclaimableHeight = currentHeight + } +} + +// endUnclaimablePeriod ends the unclaimable period. +// Accumulates to unclaimableAcc and resets lastUnclaimableHeight to 0. +func (self *Pool) endUnclaimablePeriod(currentHeight int64) { + if self.lastUnclaimableHeight == 0 { + // This should not happen, but guarding just in case + return + } + unclaimableHeights := currentHeight - self.lastUnclaimableHeight + self.unclaimableAcc += uint64(unclaimableHeights) * self.CurrentReward(self.lastUnclaimableHeight) + self.lastUnclaimableHeight = 0 +} + +// processUnclaimableReward processes the unclaimable reward and returns the accumulated reward. +// It resets unclaimableAcc to 0 and updates lastUnclaimableHeight to endHeight. +func (self *Pool) processUnclaimableReward(poolTier *PoolTier, endHeight int64) uint64 { + internalUnClaimable := self.unclaimableAcc + self.unclaimableAcc = 0 + self.lastUnclaimableHeight = endHeight + return internalUnClaimable +} + +// Calculates reward for a position *without* considering debt or warmup +// It calculates the theoretical total reward for the position if it has been staked since the pool creation +func (self *Pool) CalculateRawRewardForPosition(currentHeight int64, currentTick int32, deposit *Deposit) *u256.Uint { + var rewardAcc *u256.Uint + + globalAcc := self.calculateGlobalRewardRatioAccumulation(currentHeight, self.CurrentStakedLiquidity(currentHeight)) + lowerAcc := self.ticks.Get(deposit.tickLower).CurrentOutsideAccumulation(currentHeight) + upperAcc := self.ticks.Get(deposit.tickUpper).CurrentOutsideAccumulation(currentHeight) + if currentTick < deposit.tickLower { + rewardAcc = u256.Zero().Sub(lowerAcc, upperAcc) + } else if currentTick >= deposit.tickUpper { + rewardAcc = u256.Zero().Sub(upperAcc, lowerAcc) + } else { + rewardAcc = u256.Zero().Sub(globalAcc, lowerAcc) + rewardAcc = rewardAcc.Sub(rewardAcc, upperAcc) + } + + return rewardAcc +} + +// Calculate actual reward in [startHeight, endHeight) for a position by +// subtracting the startHeight's raw reward from the endHeight's raw reward +func (self *Pool) CalculateRewardForPosition(startHeight int64, startTick int32, endHeight int64, endTick int32, deposit *Deposit) *u256.Uint { + rewardAcc := self.CalculateRawRewardForPosition(endHeight, endTick, deposit) + + debtAcc := self.CalculateRawRewardForPosition(startHeight, startTick, deposit) + + rewardAcc = rewardAcc.Sub(rewardAcc, debtAcc) + + return rewardAcc +} diff --git a/contract/r/gnoswap/staker/reward_calculation_pool_tier.gno b/contract/r/gnoswap/staker/reward_calculation_pool_tier.gno new file mode 100644 index 000000000..2a7041a83 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_pool_tier.gno @@ -0,0 +1,279 @@ +package staker + +import ( + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" +) + +const ( + AllTierCount = 4 // 0, 1, 2, 3 + Tier1 = 1 + Tier2 = 2 + Tier3 = 3 +) + +// 100%, 0%, 0% if no tier2 and tier3 +// 80%, 0%, 20% if no tier2 +// 70%, 30%, 0% if no tier3 +// 50%, 30%, 20% if has tier2 and tier3 +type TierRatio struct { + Tier1 uint64 + Tier2 uint64 + Tier3 uint64 +} + +// TierRatioFromCounts calculates the ratio distribution for each tier based on pool counts. +// +// Parameters: +// - tier1Count (uint64): Number of pools in tier 1. +// - tier2Count (uint64): Number of pools in tier 2. +// - tier3Count (uint64): Number of pools in tier 3. +// +// Returns: +// - TierRatio: The ratio distribution across tier 1, 2, and 3, scaled up by 100. +func TierRatioFromCounts(tier1Count, tier2Count, tier3Count uint64) TierRatio { + // tier1 always exists + if tier2Count == 0 && tier3Count == 0 { + return TierRatio{ + Tier1: 100, + Tier2: 0, + Tier3: 0, + } + } + if tier2Count == 0 { + return TierRatio{ + Tier1: 80, + Tier2: 0, + Tier3: 20, + } + } + if tier3Count == 0 { + return TierRatio{ + Tier1: 70, + Tier2: 30, + Tier3: 0, + } + } + return TierRatio{ + Tier1: 50, + Tier2: 30, + Tier3: 20, + } +} + +// Get returns the ratio(scaled up by 100) for the given tier. +func (self *TierRatio) Get(tier uint64) uint64 { + switch tier { + case Tier1: + return self.Tier1 + case Tier2: + return self.Tier2 + case Tier3: + return self.Tier3 + default: + panic(addDetailToError( + errInvalidPoolTier, ufmt.Sprintf("unsupported tier(%d)", tier))) + } +} + +// PoolTier manages pool counts, ratios, and rewards for different tiers. +// +// Fields: +// - membership: Tracks which tier a pool belongs to (poolPath -> blockNumber -> tier). +// +// Methods: +// - CurrentCount: Returns the current count of pools in a tier at a specific height. +// - CurrentRatio: Returns the current ratio for a tier at a specific height. +// - CurrentTier: Returns the tier of a specific pool at a given height. +// - CurrentReward: Retrieves the reward for a tier at a specific height. +// - changeTier: Updates the tier of a pool and recalculates ratios. +type PoolTier struct { + membership *avl.Tree // poolPath -> tier(1, 2, 3) + + tierRatio TierRatio + + lastRewardCacheHeight int64 + + currentEmission uint64 + + // returns current emission. + getEmission func() uint64 + + // returns a list of halving blocks within the interval [start, end) in ascending order + // there MUST NOT be any emission amount change between start and end - those had to be handled by the CallbackStakerEmissionChange. + getHalvingBlocksInRange func(start, end int64) ([]int64, []uint64) +} + +// NewPoolTier creates a new PoolTier instance with single initial 1 tier pool. +// +// Parameters: +// - pools: The pool collection. +// - currentHeight: The current block height. +// - initialPoolPath: The path of the initial pool. +// - getEmission: A function that returns the current emission to the staker contract. +// - getHalvingBlocksInRange: A function that returns a list of halving blocks within the interval [start, end) in ascending order. +// +// Returns: +// - *PoolTier: The new PoolTier instance. +func NewPoolTier(pools *Pools, currentHeight int64, initialPoolPath string, getEmission func() uint64, getHalvingBlocksInRange func(start, end int64) ([]int64, []uint64)) *PoolTier { + result := &PoolTier{ + membership: avl.NewTree(), + tierRatio: TierRatioFromCounts(1, 0, 0), + lastRewardCacheHeight: currentHeight + 1, + getEmission: getEmission, + getHalvingBlocksInRange: getHalvingBlocksInRange, + currentEmission: getEmission(), + } + + pools.Set(initialPoolPath, NewPool(initialPoolPath, currentHeight+1)) + result.changeTier(currentHeight+1, pools, initialPoolPath, 1) + return result +} + +// CurrentReward returns the current per-pool reward for the given tier. +func (self *PoolTier) CurrentReward(tier uint64) uint64 { + currentEmission := self.getEmission() + return currentEmission * self.tierRatio.Get(tier) / uint64(self.CurrentCount(tier)) / 100 +} + +// CurrentCount returns the current count of pools in the given tier. +func (self *PoolTier) CurrentCount(tier uint64) int { + count := 0 + self.membership.Iterate("", "", func(key string, value interface{}) bool { + if value.(uint64) == tier { + count++ + } + return false + }) + return count +} + +// CurrentAllTierCounts returns the current count of pools in each tier. +func (self *PoolTier) CurrentAllTierCounts() []uint64 { + count := make([]uint64, AllTierCount) + self.membership.Iterate("", "", func(key string, value interface{}) bool { + count[value.(uint64)]++ + return false + }) + return count +} + +// CurrentTier returns the tier of the given pool. +func (self *PoolTier) CurrentTier(poolPath string) uint64 { + tier, ok := self.membership.Get(poolPath) + if !ok { + return 0 + } + return tier.(uint64) +} + +// changeTier updates the tier of a pool, recalculates ratios, and applies +// updated per-pool reward to each of the pools. +func (self *PoolTier) changeTier(currentHeight int64, pools *Pools, poolPath string, nextTier uint64) { + self.cacheReward(currentHeight, pools) + // same as prev. no need to update + currentTier := self.CurrentTier(poolPath) + if currentTier == nextTier { + // no change, return + return + } + + if nextTier == 0 { + // removed from the tier + self.membership.Remove(poolPath) + pool, ok := pools.Get(poolPath) + if !ok { + panic("changeTier: pool not found") + } + // caching reward to 0 + pool.cacheReward(currentHeight, 0) + } else { + self.membership.Set(poolPath, nextTier) + } + + counts := self.CurrentAllTierCounts() + self.tierRatio = TierRatioFromCounts(counts[Tier1], counts[Tier2], counts[Tier3]) + + currentEmission := self.getEmission() + + // Cache updated reward for each tiered pool + self.membership.Iterate("", "", func(key string, value interface{}) bool { + pool, ok := pools.Get(key) + if !ok { + panic("changeTier: pool not found") + } + tier := value.(uint64) + poolReward := currentEmission * self.tierRatio.Get(tier) / 100 / counts[tier] + pool.cacheReward(currentHeight, poolReward) + return false + }) + + self.currentEmission = currentEmission +} + +// cacheReward MUST be called before calculating any position reward +// cacheReward updates the reward cache for each pools, accounting for any halving event in between the last cached height and the current height. +func (self *PoolTier) cacheReward(currentHeight int64, pools *Pools) { + lastHeight := self.lastRewardCacheHeight + + if currentHeight <= lastHeight { + // no need to check + return + } + + // find halving blocks in range + halvingBlocks, halvingEmissions := self.getHalvingBlocksInRange(lastHeight, currentHeight) + + if len(halvingBlocks) == 0 { + self.applyCacheToAllPools(pools, currentHeight, self.currentEmission) + return + } + + for i, hvBlock := range halvingBlocks { + // caching: [lastHeight, hvBlock) + self.applyCacheToAllPools(pools, hvBlock, halvingEmissions[i]) + + // halve emissions when halvingBlock is reached + self.currentEmission = halvingEmissions[i] + } + + // remaining range [lastHalvingBlock, currentHeight) + self.applyCacheToAllPools(pools, currentHeight, self.currentEmission) + + // update lastRewardCacheHeight and currentEmission + self.lastRewardCacheHeight = currentHeight +} + +// applyCacheToAllPools applies the cached reward to all tiered pools. +func (self *PoolTier) applyCacheToAllPools(pools *Pools, currentBlock int64, emissionInThisInterval uint64) { + // calculate denominator and number of pools in each tier + counts := self.CurrentAllTierCounts() + + // apply cache to all pools + self.membership.Iterate("", "", func(key string, value interface{}) bool { + tierNum := value.(uint64) + pool, ok := pools.Get(key) + if !ok { + return false + } + + // real reward + reward := emissionInThisInterval * self.tierRatio.Get(tierNum) / 100 / counts[tierNum] + // accumulate the reward for the interval (startBlock to endBlock) in the Pool + pool.cacheInternalReward(currentBlock, reward) + return false + }) +} + +// IsInternallyIncentivizedPool returns true if the pool is in a tier. +func (self *PoolTier) IsInternallyIncentivizedPool(poolPath string) bool { + return self.CurrentTier(poolPath) > 0 +} + +func (self *PoolTier) CurrentRewardPerPool(poolPath string) uint64 { + emission := self.getEmission() + counts := self.CurrentAllTierCounts() + tierNum := self.CurrentTier(poolPath) + reward := emission * self.tierRatio.Get(tierNum) / 100 / counts[tierNum] + return reward +} diff --git a/contract/r/gnoswap/staker/reward_calculation_pool_tier_test.gno b/contract/r/gnoswap/staker/reward_calculation_pool_tier_test.gno new file mode 100644 index 000000000..8714f1f49 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_pool_tier_test.gno @@ -0,0 +1,132 @@ +package staker + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/gnoswap/consts" + pl "gno.land/r/gnoswap/v1/pool" + "gno.land/r/onbloc/bar" + "gno.land/r/onbloc/baz" + "gno.land/r/onbloc/qux" +) + +func TestTierRatioFromCounts(t *testing.T) { + CreateSecondPoolWithoutFee(t) + MakeMintPositionWithoutFee(t) + + user1Addr := testutils.TestAddress("user1") + user1Realm := std.NewUserRealm(user1Addr) + + std.TestSetRealm(user1Realm) + bar.Approve(consts.ROUTER_ADDR, maxApprove) + baz.Approve(consts.ROUTER_ADDR, maxApprove) + qux.Approve(consts.ROUTER_ADDR, maxApprove) + TokenFaucet(t, barPath, user1Addr) + + tests := []struct { + tier1Count uint64 + tier2Count uint64 + tier3Count uint64 + expected TierRatio + }{ + {1, 0, 0, TierRatio{Tier1: 100, Tier2: 0, Tier3: 0}}, + {1, 0, 1, TierRatio{Tier1: 80, Tier2: 0, Tier3: 20}}, + {1, 1, 0, TierRatio{Tier1: 70, Tier2: 30, Tier3: 0}}, + {1, 1, 1, TierRatio{Tier1: 50, Tier2: 30, Tier3: 20}}, + } + + for _, tt := range tests { + result := TierRatioFromCounts(tt.tier1Count, tt.tier2Count, tt.tier3Count) + if result != tt.expected { + t.Errorf("TierRatioFromCounts(%d, %d, %d) = %v; want %v", + tt.tier1Count, tt.tier2Count, tt.tier3Count, result, tt.expected) + } + } +} + +func TestNewPoolTier(t *testing.T) { + currentHeight := int64(100) + mustExistsInTier1 := "testPool" + + poolTier := NewPoolTier(NewPools(), currentHeight, mustExistsInTier1, func() uint64 { return 0 }, func(start, end int64) ([]int64, []uint64) { return nil, nil }) + + // Test initial counts + if count := poolTier.CurrentCount(1); count != 1 { + t.Errorf("Expected tier 1 count to be 1, got %d", count) + } + if count := poolTier.CurrentCount(2); count != 0 { + t.Errorf("Expected tier 2 count to be 0, got %d", count) + } + if count := poolTier.CurrentCount(3); count != 0 { + t.Errorf("Expected tier 3 count to be 0, got %d", count) + } + + // Test membership + if tier := poolTier.CurrentTier(mustExistsInTier1); tier != 1 { + t.Errorf("Expected pool %s to be in tier 1, got %d", mustExistsInTier1, tier) + } +} + +func TestCacheReward(t *testing.T) { + currentHeight := int64(250) + // Simulate emission updates + poolTier := NewPoolTier(NewPools(), currentHeight, "testPool", func() uint64 { return 1000 }, func(startHeight int64, endHeight int64) ([]int64, []uint64) { + return []int64{100, 150, 200}, []uint64{1000, 500, 250} + }) + + // Cache rewards + poolTier.cacheReward(currentHeight, pools) + + // Verify rewards + reward := poolTier.CurrentReward(1) + if reward == 0 { + t.Errorf("Expected reward for tier 1 at height 250, got 0") + } +} + +var test_gnousdc = pl.GetPoolPath("gno.land/r/demo/wugnot", "gno.land/r/gnoswap/v1/gns", 3000) + +func SetupPoolTier(t *testing.T) *PoolTier { + poolTier := NewPoolTier(NewPools(), 1, test_gnousdc, func() uint64 { return 1000 }, func(start, end int64) ([]int64, []uint64) { return nil, nil }) + + poolTier.changeTier(1, pools, test_gnousdc, 1) + return poolTier +} + +/* +func TestPoolTierSimple(t *testing.T) { + poolTier := SetupPoolTier(t) + + currentHeight := uint64(1) + + emissionUpdate := EmissionUpdate { + LastEmissionUpdate: 10000, + EmissionUpdateHeights: []uint64{currentHeight}, + EmissionUpdates: []uint64{10000}, + } + + currentHeight = 2 + + poolTier.cacheReward(currentHeight, pools) + uassert.Equal(t, uint64(1), poolTier.CurrentTier(test_gnousdc, currentHeight)) + uassert.Equal(t, uint64(10000), poolTier.CurrentReward(1, currentHeight)) + + currentHeight = 3 + + emissionUpdate.LastEmissionUpdate = 10000 + emissionUpdate.EmissionUpdateHeights = append(emissionUpdate.EmissionUpdateHeights, currentHeight) + emissionUpdate.EmissionUpdates = append(emissionUpdate.EmissionUpdates, 20000) + + gnousdc2 := pl.GetPoolPath("gno.land/r/demo/wugnot", "gno.land/r/gnoswap/v1/gns", 5000) + poolTier.changeTier(currentHeight, gnousdc2, 2) + + currentHeight = 4 + + poolTier.cacheReward(currentHeight, emissionUpdate) + uassert.Equal(t, uint64(1), poolTier.CurrentTier(test_gnousdc, currentHeight)) + uassert.Equal(t, uint64(14000), poolTier.CurrentReward(1, currentHeight)) + uassert.Equal(t, uint64(6000), poolTier.CurrentReward(2, currentHeight)) +} +*/ diff --git a/contract/r/gnoswap/staker/reward_calculation_tick.gno b/contract/r/gnoswap/staker/reward_calculation_tick.gno new file mode 100644 index 000000000..2fab09994 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_tick.gno @@ -0,0 +1,182 @@ +package staker + +import ( + "std" + + "strconv" + "strings" + + "gno.land/p/demo/avl" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + pl "gno.land/r/gnoswap/v1/pool" +) + +// EncodeInt takes an int32 and returns a zero-padded decimal string +// with up to 10 digits for the absolute value. +// If the number is negative, the '-' sign comes first, followed by zeros, then digits. +func EncodeInt(num int32) string { + // Convert the absolute value to a decimal string. + absValue := int64(num) + isNegative := false + if num < 0 { + isNegative = true + absValue = -absValue // Safely negate into int64 to avoid overflow. + } + + s := strconv.FormatInt(absValue, 10) + + // Zero-pad to a total of 10 digits for the absolute value. + // (The '-' sign will be added later if needed.) + zerosNeeded := 10 - len(s) + if zerosNeeded < 0 { + zerosNeeded = 0 + } + + padded := strings.Repeat("0", zerosNeeded) + s + + // If the original number was negative, prepend '-'. + if isNegative { + return "-" + padded + } + return padded +} + +// Tick mapping for each pool +type Ticks struct { + tree *avl.Tree // int32 tickId -> tick +} + +func NewTicks() Ticks { + return Ticks{ + tree: avl.NewTree(), + } +} + +func (self *Ticks) Get(tickId int32) *Tick { + v, ok := self.tree.Get(EncodeInt(tickId)) + if !ok { + tick := &Tick{ + id: tickId, + stakedLiquidityGross: u256.Zero(), + stakedLiquidityDelta: i256.Zero(), + outsideAccumulation: NewUintTree(), + } + self.tree.Set(EncodeInt(tickId), tick) + return tick + } + return v.(*Tick) +} + +func (self *Ticks) Set(tickId int32, tick *Tick) { + if tick.stakedLiquidityGross.IsZero() { + self.tree.Remove(EncodeInt(tickId)) + return + } + self.tree.Set(EncodeInt(tickId), tick) +} + +func (self *Ticks) Has(tickId int32) bool { + return self.tree.Has(EncodeInt(tickId)) +} + +// Tick represents the state of a specific tick in a pool. +// +// Fields: +// - id (int32): The ID of the tick. +// - stakedLiquidityGross (*u256.Uint): Total gross staked liquidity at this tick. +// - stakedLiquidityDelta (*i256.Int): Net change in staked liquidity at this tick. +// - outsideAccumulation (*UintTree): RewardRatioAccumulation outside the tick. +type Tick struct { + id int32 + + // conceptually equal with Pool.liquidityGross but only for the staked positions + stakedLiquidityGross *u256.Uint + + // conceptually equal with Pool.liquidityNet but only for the staked positions + stakedLiquidityDelta *i256.Int + + // currentOutsideAccumulation is the accumulation of the blockNumber / TotalStake outside the tick. + // It is calculated by subtracting the current tick's currentOutsideAccumulation from the global reward ratio accumulation. + outsideAccumulation *UintTree // blockNumber -> *u256.Uint +} + +// CurrentOutsideAccumulation returns the latest outside accumulation for the tick +func (self *Tick) CurrentOutsideAccumulation(blockNumber int64) *u256.Uint { + var acc *u256.Uint + self.outsideAccumulation.ReverseIterate(0, blockNumber, func(key int64, value interface{}) bool { + acc = value.(*u256.Uint) + return true + }) + if acc == nil { + acc = u256.Zero() + } + return acc +} + +// modifyDepositLower updates the tick's liquidity info by treating the deposit as a lower tick +func (self *Tick) modifyDepositLower(currentHeight int64, currentTick int32, liquidity *i256.Int) { + // update staker side tick info + self.stakedLiquidityGross = liquidityMathAddDelta(self.stakedLiquidityGross, liquidity) + if self.stakedLiquidityGross.Lt(u256.Zero()) { + panic("stakedLiquidityGross is negative") + } + self.stakedLiquidityDelta = i256.Zero().Add(self.stakedLiquidityDelta, liquidity) +} + +// modifyDepositUpper updates the tick's liquidity info by treating the deposit as an upper tick +func (self *Tick) modifyDepositUpper(currentHeight int64, currentTick int32, liquidity *i256.Int) { + self.stakedLiquidityGross = liquidityMathAddDelta(self.stakedLiquidityGross, liquidity) + if self.stakedLiquidityGross.Lt(u256.Zero()) { + panic("stakedLiquidityGross is negative") + } + self.stakedLiquidityDelta = i256.Zero().Sub(self.stakedLiquidityDelta, liquidity) +} + +// updateCurrentOutsideAccumulation updates the tick's outside accumulation +// It "flips" the accumulation's inside/outside by subtracting the current outside accumulation from the global accumulation +func (self *Tick) updateCurrentOutsideAccumulation(blockNumber int64, acc *u256.Uint) { + currentOutsideAccumulation := self.CurrentOutsideAccumulation(blockNumber) + newOutsideAccumulation := u256.Zero().Sub(acc, currentOutsideAccumulation) + self.outsideAccumulation.Set(blockNumber, newOutsideAccumulation) +} + +// TickCrossHook is a hook that is called when a tick is crossed. +// Modifies the tick's outside accumulation and updates the tick's liquidity info +func TickCrossHook(pools *Pools, height func() int64) func(poolPath string, tickId int32, zeroForOne bool) { + return func(poolPath string, tickId int32, zeroForOne bool) { + pool, ok := pools.Get(poolPath) + if !ok { + return + } + + tick := pool.ticks.Get(tickId) + + blockNumber := height() + + var nextTick int32 + if zeroForOne { + nextTick = tickId - 1 + } else { + nextTick = tickId + } + + liquidityInRangeDelta := tick.stakedLiquidityDelta + if liquidityInRangeDelta.Sign() == 0 { + return + } + + if zeroForOne { + liquidityInRangeDelta = i256.Zero().Neg(liquidityInRangeDelta) + } + newAcc := pool.modifyDeposit(liquidityInRangeDelta, blockNumber, nextTick) + tick.updateCurrentOutsideAccumulation(blockNumber, newAcc) + } +} + +func init() { + // Sets tick cross hook for pool contract + pl.SetTickCrossHook(TickCrossHook(pools, std.GetHeight)) +} diff --git a/contract/r/gnoswap/staker/reward_calculation_tick_test.gno b/contract/r/gnoswap/staker/reward_calculation_tick_test.gno new file mode 100644 index 000000000..354349f60 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_tick_test.gno @@ -0,0 +1,132 @@ +package staker + +import ( + "strconv" + "testing" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/p/demo/uassert" +) + +func TestEncodeInt(t *testing.T) { + tests := []struct { + input int32 + expected string + }{ + {123, "0000000123"}, + {-123, "-0000000123"}, + {0, "0000000000"}, + {2147483647, "2147483647"}, // int32 max + {-2147483648, "-2147483648"}, // int32 min + } + + for _, tt := range tests { + t.Run(strconv.Itoa(int(tt.input)), func(t *testing.T) { + uassert.Equal(t, EncodeInt(tt.input), tt.expected) + }) + } +} + +func TestTicks(t *testing.T) { + ticks := NewTicks() + + tick := ticks.Get(100) + if tick == nil || tick.id != 100 { + t.Errorf("Get(100) returned %v; want Tick with ID 100", tick) + } + + tick.stakedLiquidityGross = u256.MustFromDecimal("1") + ticks.Set(100, tick) + uassert.True(t, ticks.Has(100)) + + tick.stakedLiquidityGross = u256.Zero() + ticks.Set(100, tick) + uassert.False(t, ticks.Has(100)) +} + +func TestTicksBasic(t *testing.T) { + ticks := NewTicks() + + tick100 := ticks.Get(100) + uassert.True(t, ticks.Has(100)) + uassert.Equal(t, tick100.id, int32(100)) + + tick100Again := ticks.Get(100) + uassert.Equal(t, int32(tick100Again.id), int32(tick100.id)) + uassert.True(t, tick100Again.stakedLiquidityGross.IsZero()) + uassert.True(t, tick100Again.stakedLiquidityDelta.IsZero()) + + ticks.Set(100, tick100) + uassert.False(t, ticks.Has(100)) +} + +func TestModifyDepositLower(t *testing.T) { + ticks := NewTicks() + tick := ticks.Get(100) + + // initial value must be zero + uassert.True(t, tick.stakedLiquidityGross.IsZero()) + uassert.True(t, tick.stakedLiquidityDelta.IsZero()) + + // deposit +10 + liquidityDelta := i256.NewInt(10) // +10 + tick.modifyDepositLower(50, 95, liquidityDelta) + + // stakedLiquidityGross += +10 => 10 + // stakedLiquidityDelta += +10 => 10 + if tick.stakedLiquidityGross.ToString() != "10" || tick.stakedLiquidityDelta.ToString() != "10" { + t.Errorf("After deposit +10, stakedLiquidityGross=%v, stakedLiquidityDelta=%v; want 10,10", + tick.stakedLiquidityGross, tick.stakedLiquidityDelta) + } + + // deposit another +5 + tick.modifyDepositLower(60, 95, i256.NewInt(5)) + // gross=15, delta=15 + if tick.stakedLiquidityGross.ToString() != "15" || tick.stakedLiquidityDelta.ToString() != "15" { + t.Errorf("After deposit +5, stakedLiquidityGross=%v, stakedLiquidityDelta=%v; want 15,15", + tick.stakedLiquidityGross, tick.stakedLiquidityDelta) + } +} + +func TestModifyDepositUpper(t *testing.T) { + ticks := NewTicks() + tick := ticks.Get(200) + + // deposit +10 => modifyDepositUpper + tick.modifyDepositUpper(70, 195, i256.NewInt(10)) + + // stakedLiquidityGross=10, stakedLiquidityDelta = -10 + // upper => delta = stakedLiquidityDelta - liquidity + if tick.stakedLiquidityGross.ToString() != "10" || tick.stakedLiquidityDelta.ToString() != "-10" { + t.Errorf("After deposit +10(upper), stakedLiquidityGross=%v, stakedLiquidityDelta=%v; want 10,-10", + tick.stakedLiquidityGross, tick.stakedLiquidityDelta) + } +} + +func compareRangeSlices(t *testing.T, a, b [][2]uint64) bool { + t.Helper() + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func compareInt64Slices(t *testing.T, a, b []int64) bool { + t.Helper() + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/contract/r/gnoswap/staker/reward_calculation_types.gno b/contract/r/gnoswap/staker/reward_calculation_types.gno new file mode 100644 index 000000000..14f7cfbb0 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_types.gno @@ -0,0 +1,103 @@ +package staker + +import ( + "strconv" + "strings" + + "gno.land/p/demo/avl" +) + +// EncodeUint converts a uint64 number into a zero-padded 20-character string. +// +// Parameters: +// - num (uint64): The number to encode. +// +// Returns: +// - string: A zero-padded string representation of the number. +// +// Example: +// Input: 12345 +// Output: "00000000000000012345" +func EncodeUint(num uint64) string { + // Convert the value to a decimal string. + s := strconv.FormatUint(num, 10) + + // Zero-pad to a total length of 20 characters. + zerosNeeded := 20 - len(s) + return strings.Repeat("0", zerosNeeded) + s +} + +// DecodeUint converts a zero-padded string back into a uint64 number. +// +// Parameters: +// - s (string): The zero-padded string. +// +// Returns: +// - uint64: The decoded number. +// +// Panics: +// - If the string cannot be parsed into a uint64. +// +// Example: +// Input: "00000000000000012345" +// Output: 12345 +func DecodeUint(s string) uint64 { + num, err := strconv.ParseUint(s, 10, 64) + if err != nil { + panic(err) + } + return num +} + +// UintTree is a wrapper around an AVL tree for storing block heights as strings. +// Since block heights are defined as int64, we take int64 and convert it to uint64 for the tree. +// +// Methods: +// - Get: Retrieves a value associated with a uint64 key. +// - Set: Stores a value with a uint64 key. +// - Has: Checks if a uint64 key exists in the tree. +// - Remove: Removes a uint64 key and its associated value. +// - Iterate: Iterates over keys and values in a range. +// - ReverseIterate: Iterates in reverse order over keys and values in a range. +type UintTree struct { + tree *avl.Tree // blockNumber -> interface{} +} + +// NewUintTree creates a new UintTree instance. +func NewUintTree() *UintTree { + return &UintTree{ + tree: avl.NewTree(), + } +} + +func (self *UintTree) Get(key int64) (interface{}, bool) { + v, ok := self.tree.Get(EncodeUint(uint64(key))) + if !ok { + return nil, false + } + return v, true +} + +func (self *UintTree) Set(key int64, value interface{}) { + self.tree.Set(EncodeUint(uint64(key)), value) +} + +func (self *UintTree) Has(key int64) bool { + return self.tree.Has(EncodeUint(uint64(key))) +} + +func (self *UintTree) Remove(key int64) { + self.tree.Remove(EncodeUint(uint64(key))) +} + +func (self *UintTree) Iterate(start, end int64, fn func(key int64, value interface{}) bool) { + self.tree.Iterate(EncodeUint(uint64(start)), EncodeUint(uint64(end)), func(key string, value interface{}) bool { + return fn(int64(DecodeUint(key)), value) + }) +} + +func (self *UintTree) ReverseIterate(start, end int64, fn func(key int64, value interface{}) bool) { + self.tree.ReverseIterate(EncodeUint(uint64(start)), EncodeUint(uint64(end)), func(key string, value interface{}) bool { + return fn(int64(DecodeUint(key)), value) + }) +} diff --git a/contract/r/gnoswap/staker/reward_calculation_types_test.gno b/contract/r/gnoswap/staker/reward_calculation_types_test.gno new file mode 100644 index 000000000..7f165820d --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_types_test.gno @@ -0,0 +1,122 @@ +package staker + +import "testing" + +func TestEncodeUint(t *testing.T) { + tests := []struct { + input uint64 + expected string + }{ + {0, "00000000000000000000"}, // minimum + {12345, "00000000000000012345"}, // normal value + {18446744073709551615, "18446744073709551615"}, // maximum (uint64 max) + } + + for _, tt := range tests { + result := EncodeUint(tt.input) + if result != tt.expected { + t.Errorf("EncodeUint(%d) = %s; want %s", tt.input, result, tt.expected) + } + } +} + +func TestDecodeUint(t *testing.T) { + tests := []struct { + input string + expected uint64 + shouldPanic bool + }{ + {"00000000000000000000", 0, false}, + {"00000000000000012345", 12345, false}, + {"18446744073709551615", 18446744073709551615, false}, + {"invalid", 0, true}, + {"18446744073709551616", 0, true}, + } + + for _, tt := range tests { + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("DecodeUint(%s) did not panic as expected", tt.input) + } + }() + _ = DecodeUint(tt.input) + } else { + result := DecodeUint(tt.input) + if result != tt.expected { + t.Errorf("DecodeUint(%s) = %d; want %d", tt.input, result, tt.expected) + } + } + } +} +func TestUintTree(t *testing.T) { + tree := NewUintTree() + + // Test Set and Get + tree.Set(12345, "testValue") + value, ok := tree.Get(12345) + if !ok || value != "testValue" { + t.Errorf("UintTree.Get(12345) = %v, %v; want testValue, true", value, ok) + } + + // Test Has + if !tree.Has(12345) { + t.Errorf("UintTree.Has(12345) = false; want true") + } + + // Test Remove + tree.Remove(12345) + if tree.Has(12345) { + t.Errorf("UintTree.Has(12345) after Remove = true; want false") + } + + // Test Iterate + tree.Set(100, "a") + tree.Set(200, "b") + tree.Set(300, "c") + + var keys []uint64 + var values []interface{} + + tree.Iterate(100, 300, func(key int64, value interface{}) bool { + keys = append(keys, uint64(key)) + values = append(values, value) + return false + }) + + // Verify results + expectedKeys := []uint64{100, 200} + expectedValues := []interface{}{"a", "b"} + + if !compareUintSlices(t, keys, expectedKeys) || !compareInterfaces(t, values, expectedValues) { + t.Errorf("UintTree.Iterate() keys = %v, values = %v; want keys = %v, values = %v", keys, values, expectedKeys, expectedValues) + } +} + +// Helper function to compare slices of uint64 +func compareUintSlices(t *testing.T, a, b []uint64) bool { + t.Helper() + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// Helper function to compare slices of interfaces +func compareInterfaces(t *testing.T, a, b []interface{}) bool { + t.Helper() + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/contract/r/gnoswap/staker/reward_calculation_warmup.gno b/contract/r/gnoswap/staker/reward_calculation_warmup.gno new file mode 100644 index 000000000..46203cade --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_warmup.gno @@ -0,0 +1,110 @@ +package staker + +import ( + "math" + + "gno.land/p/demo/ufmt" + u256 "gno.land/p/gnoswap/uint256" + + gns "gno.land/r/gnoswap/v1/gns" +) + +type Warmup struct { + Index int + BlockDuration int64 + NextWarmupHeight int64 // set to 0 for template + WarmupRatio uint64 +} + +var warmupTemplate []Warmup = DefaultWarmupTemplate() + +func DefaultWarmupTemplate() []Warmup { + msInDay := int64(86400000) + blocksInDay := msInDay / int64(gns.GetAvgBlockTimeInMs()) + blocksIn5Days := int64(5 * blocksInDay) + blocksIn10Days := int64(10 * blocksInDay) + blocksIn30Days := int64(30 * blocksInDay) + + // NextWarmupHeights are set to 0 for template. + // It will be set by InstantiateWarmup() + return []Warmup{ + { + Index: 0, + BlockDuration: blocksIn5Days, + //NextWarmupHeight: currentHeight + blocksIn5Days, + WarmupRatio: 30, + }, + { + Index: 1, + BlockDuration: blocksIn10Days, + //NextWarmupHeight: currentHeight + blocksIn10Days, + WarmupRatio: 50, + }, + { + Index: 2, + BlockDuration: blocksIn30Days, + //NextWarmupHeight: currentHeight + blocksIn30Days, + WarmupRatio: 70, + }, + { + Index: 3, + BlockDuration: math.MaxInt64, + //NextWarmupHeight: math.MaxInt64, + WarmupRatio: 100, + }, + } +} + +// expected to be called by governance +func modifyWarmup(index int, blockDuration int64) { + if index >= len(warmupTemplate) { + panic(ufmt.Sprintf("index(%d) is out of range", index)) + } + + warmupTemplate[index].BlockDuration = blockDuration +} + +func InstantiateWarmup(currentHeight int64) []Warmup { + warmups := make([]Warmup, 0) + for _, warmup := range warmupTemplate { + nextWarmupHeight := currentHeight + warmup.BlockDuration + if nextWarmupHeight < 0 { + nextWarmupHeight = math.MaxInt64 + } + + warmups = append(warmups, Warmup{ + Index: warmup.Index, + BlockDuration: warmup.BlockDuration, + NextWarmupHeight: nextWarmupHeight, + WarmupRatio: warmup.WarmupRatio, + }) + currentHeight += warmup.BlockDuration + } + return warmups +} + +func (warmup *Warmup) Apply(poolReward uint64, positionLiquidity, stakedLiquidity *u256.Uint) (uint64, uint64) { + poolRewardUint := u256.NewUint(poolReward) + perPositionReward := u256.Zero().Mul(poolRewardUint, positionLiquidity) + perPositionReward = u256.Zero().Div(perPositionReward, stakedLiquidity) + rewardRatio := u256.NewUint(warmup.WarmupRatio) + penaltyRatio := u256.NewUint(100 - warmup.WarmupRatio) + totalReward := u256.Zero().Mul(perPositionReward, rewardRatio) + totalReward = u256.Zero().Div(totalReward, u256.NewUint(100)) + totalPenalty := u256.Zero().Mul(perPositionReward, penaltyRatio) + totalPenalty = u256.Zero().Div(totalPenalty, u256.NewUint(100)) + return totalReward.Uint64(), totalPenalty.Uint64() +} + +func (self *Deposit) FindWarmup(currentHeight int64) int { + for i, warmup := range self.warmups { + if currentHeight < warmup.NextWarmupHeight { + return i + } + } + return len(self.warmups) - 1 +} + +func (self *Deposit) GetWarmup(index int) Warmup { + return self.warmups[index] +} diff --git a/contract/r/gnoswap/staker/reward_calculation_warmup_test.gno b/contract/r/gnoswap/staker/reward_calculation_warmup_test.gno new file mode 100644 index 000000000..1f26a7619 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_calculation_warmup_test.gno @@ -0,0 +1,419 @@ +package staker + +import ( + "math" + "std" + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + + u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/gns" +) + +func TestDefaultWarmupTemplate(t *testing.T) { + warmups := DefaultWarmupTemplate() + + uassert.Equal(t, 4, len(warmups)) + + uassert.Equal(t, uint64(30), warmups[0].WarmupRatio) + uassert.Equal(t, uint64(50), warmups[1].WarmupRatio) + uassert.Equal(t, uint64(70), warmups[2].WarmupRatio) + uassert.Equal(t, uint64(100), warmups[3].WarmupRatio) + uassert.Equal(t, int64(math.MaxInt64), warmups[3].BlockDuration) +} + +func TestInstantiateWarmup(t *testing.T) { + currentHeight := int64(1000) + warmups := InstantiateWarmup(currentHeight) + + uassert.True(t, warmups[0].NextWarmupHeight > currentHeight) + uassert.True(t, warmups[1].NextWarmupHeight > warmups[0].NextWarmupHeight) +} + +func TestWarmupApply(t *testing.T) { + warmup := Warmup{ + WarmupRatio: 30, + } + + poolReward := uint64(1000) + positionLiquidity := u256.NewUint(100) + stakedLiquidity := u256.NewUint(200) + + reward, penalty := warmup.Apply(poolReward, positionLiquidity, stakedLiquidity) + + uassert.Equal(t, poolReward/2, reward+penalty) + uassert.Equal(t, uint64(150), reward) + uassert.Equal(t, uint64(350), penalty) +} + +func TestWarmupApply2(t *testing.T) { + tests := []struct { + name string + warmupRatio uint64 + poolReward uint64 + position uint64 + staked uint64 + expReward uint64 + expPenalty uint64 + }{ + { + name: "30% ratio", + warmupRatio: 30, + poolReward: 1000, + position: 100, + staked: 200, + expReward: 150, // (1000 * 100/200) * 30% = 150 + expPenalty: 350, // (1000 * 100/200) * 70% = 350 + }, + { + name: "50% ratio", + warmupRatio: 50, + poolReward: 1000, + position: 100, + staked: 200, + expReward: 250, // (1000 * 100/200) * 50% = 250 + expPenalty: 250, // (1000 * 100/200) * 50% = 250 + }, + { + name: "70% ratio", + warmupRatio: 70, + poolReward: 1000, + position: 100, + staked: 200, + expReward: 350, // (1000 * 100/200) * 70% = 350 + expPenalty: 150, // (1000 * 100/200) * 30% = 150 + }, + { + name: "100% ratio", + warmupRatio: 100, + poolReward: 1000, + position: 100, + staked: 200, + expReward: 500, // (1000 * 100/200) * 100% = 500 + expPenalty: 0, // (1000 * 100/200) * 0% = 0 + }, + { + name: "big number", + warmupRatio: 50, + poolReward: 1000000, + position: 1000, + staked: 1000, + expReward: 500000, // (1000000 * 1000/1000) * 50% = 500000 + expPenalty: 500000, // (1000000 * 1000/1000) * 50% = 500000 + }, + } + + for i, tt := range tests { + t.Run(ufmt.Sprintf("Case %d: WarmupRatio %d%%", i, tt.warmupRatio), func(t *testing.T) { + warmup := Warmup{ + WarmupRatio: tt.warmupRatio, + } + + reward, penalty := warmup.Apply( + tt.poolReward, + u256.NewUint(tt.position), + u256.NewUint(tt.staked), + ) + + uassert.Equal(t, tt.expReward, reward) + uassert.Equal(t, tt.expPenalty, penalty) + + expectedTotal := tt.poolReward * tt.position / tt.staked + uassert.Equal(t, expectedTotal, reward+penalty) + }) + } +} + +func TestWarmupBoundaryValues(t *testing.T) { + warmup := Warmup{WarmupRatio: 50} + maxReward := uint64(math.MaxUint64) + minReward := uint64(0) + + position := u256.NewUint(1) + staked := u256.NewUint(1) + + reward, penalty := warmup.Apply(maxReward, position, staked) + uassert.Equal(t, maxReward/2, reward) + + // zero reward + reward, penalty = warmup.Apply(0, position, staked) + uassert.Equal(t, uint64(0), reward) + uassert.Equal(t, uint64(0), penalty) +} + +func TestFindWarmup(t *testing.T) { + deposit := &Deposit{ + warmups: InstantiateWarmup(1000), + } + + // check index of warmup + uassert.Equal(t, 0, deposit.FindWarmup(1001)) + uassert.Equal(t, 3, deposit.FindWarmup(math.MaxInt64)) +} + +func TestLiquidityRatios(t *testing.T) { + warmup := Warmup{WarmupRatio: 50} + + // expected reward is calculated by following formula: + // + // `finalReward = (position / staked) * totalReward * warmupRatio` + tests := []struct { + name string + position uint64 + staked uint64 + reward uint64 + expected uint64 + }{ + // Case 1: position(100) / staked(200) = 0.5 + // 50% of total reward 1000 is the base reward, which is 500 + // Apply WarmupRatio 50% => 500 * 0.5 = 250 + {"50% ratio", 100, 200, 1000, 250}, + + // Case 2: position(200) / staked(100) = 2 + // Total reward 1000 * 2 = 2000 would be the base reward + // However, can only receive up to maximum 1000 + // Apply WarmupRatio 50% => 1000 * 0.5 = 1000 + {"200% ratio", 200, 100, 1000, 1000}, + + // Case 3: position(1) / staked(1000) = 0.001 + // Total reward 1000 * 0.001 = 1 is the base reward + // Apply WarmupRatio 50% to this => 1 * 0.5 = 0 (floor) + {"small ratio", 1, 1000, 1000, 0}, + + // Case 4: position(1000) / staked(1) = 1000 + // Total reward 1000 * 1000 is the base reward + // Apply WarmupRatio 50% to this => 1000000 * 0.5 = 500000 + {"large ratio", 1000, 1, 1000, 500000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reward, _ := warmup.Apply( + tt.reward, + u256.NewUint(tt.position), + u256.NewUint(tt.staked), + ) + uassert.Equal(t, tt.expected, reward) + }) + } +} + +func TestWarmupTransitions(t *testing.T) { + currentHeight := int64(1000) + deposit := &Deposit{ + warmups: InstantiateWarmup(currentHeight), + } + + transitions := []struct { + height int64 + expectedIndex int + }{ + {currentHeight - 1, 0}, + {currentHeight, 0}, + {currentHeight + 1, 0}, + {deposit.warmups[0].NextWarmupHeight - 1, 0}, + {deposit.warmups[0].NextWarmupHeight, 1}, + {deposit.warmups[1].NextWarmupHeight, 2}, + {deposit.warmups[2].NextWarmupHeight, 3}, + } + + for _, tt := range transitions { + uassert.Equal(t, tt.expectedIndex, deposit.FindWarmup(tt.height)) + } +} + +func TestModifyWarmup(t *testing.T) { + originalTemplate := DefaultWarmupTemplate() + + t.Run("modify warmup with valid index", func(t *testing.T) { + modifyWarmup(0, 1000) + uassert.Equal(t, int64(1000), warmupTemplate[0].BlockDuration) + }) + + t.Run("modify warmup with negative block duration", func(t *testing.T) { + modifyWarmup(0, -1000) + instantiated := InstantiateWarmup(0) + uassert.Equal(t, int64(math.MaxInt64), instantiated[0].NextWarmupHeight) + }) + + t.Run("modify warmup with out range index", func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Logf("Recovered from panic: %v", r) + } + }() + modifyWarmup(len(warmupTemplate), 1000) + }) +} + +func TestRewardCalculationPrecision(t *testing.T) { + warmup := Warmup{WarmupRatio: 33} // possible decimal point + + cases := []struct { + name string + reward uint64 + expectedReward uint64 + expectedPenalty uint64 + }{ + // reward: 100 + // Reward calculation: 100 * (33/100) = 33 + // Penalty calculation: 100 * (67/100) = 67 + {"small", 100, 33, 67}, + + // reward: 1000 + // Reward calculation: 1000 * (33/100) = 330 + // Penalty calculation: 1000 * (67/100) = 670 + {"medium", 1000, 330, 670}, + + // Step 1: Per Position Reward calculation + // perPositionReward = poolReward * positionLiquidity / stakedLiquidity + // perPositionReward = 999 * 1 / 1 = 999 + // + // Step 2: Reward calculation with WarmupRatio(33%) + // totalReward = perPositionReward * rewardRatio / 100 + // totalReward = 999 * 33 / 100 = 329.67 (floor => 329) + // + // Step 3: Penalty calculation with PenaltyRatio(67%) + // totalPenalty = perPositionReward * penaltyRatio / 100 + // totalPenalty = 999 * 67 / 100 = 669.33 (floor => 669) + {"division rounding", 999, 329, 669}, + } + + position := u256.NewUint(1) + staked := u256.NewUint(1) + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + reward, penalty := warmup.Apply(tt.reward, position, staked) + uassert.Equal(t, tt.expectedReward, reward) + uassert.Equal(t, tt.expectedPenalty, penalty) + // consider rounding error + uassert.True(t, math.Abs(float64(tt.reward)-float64(reward+penalty)) <= 1) + }) + } +} + +func TestMultipleWarmupPeriods(t *testing.T) { + msInDay := int64(86400000) + // Base reward calculation: poolRewardPerBlock * (position/staked) + // 100 * (1000/2000) = 50 per block + blocksInDay := msInDay / int64(gns.GetAvgBlockTimeInMs()) // 43200 + baseReward := uint64(50) + + t.Run("7 days staking period", func(t *testing.T) { + // Step 1: Per Position Reward calculation + // poolRewardPerBlock * (position/staked) = 100 * (1000/2000) = 50 per block + // + // Step 2: Rewards for first 5 days with WarmupRatio30 (30%) + // baseReward * number of blocks * warmupRatio + // 50 * (5 * blocksInDay) * 30% = expected reward for 5 days => 3240000 + // + // Step 3: Rewards for next 2 days with WarmupRatio50 (50%) + // baseReward * number of blocks * warmupRatio + // 50 * (2 * blocksInDay) * 50% = expected reward for 2 days => 2160000 + startHeight := int64(1000) + std.TestSkipHeights(startHeight) + + deposit := &Deposit{ + warmups: InstantiateWarmup(startHeight), + } + + position := u256.NewUint(1000) + staked := u256.NewUint(2000) + + var totalReward, totalPenalty uint64 + + // WarmupRatio30 (30%) 5 days + blocks5Days := uint64(5 * blocksInDay) + reward5Days, penalty5Days := deposit.warmups[0].Apply( + baseReward*blocks5Days, + position, + staked, + ) + totalReward += reward5Days + totalPenalty += penalty5Days + + // baseReward * blocks * 0.3 + expected5Days := baseReward * blocks5Days * 30 / 100 * 1000 / 2000 + uassert.Equal(t, expected5Days, reward5Days, + ufmt.Sprintf("5 days reward mismatch: expected %d, got %d", expected5Days, reward5Days)) + + // WarmupRatio50 (50%) 2 days + std.TestSkipHeights(5 * blocksInDay) + blocks2Days := uint64(2 * blocksInDay) + reward2Days, penalty2Days := deposit.warmups[1].Apply( + baseReward*blocks2Days, + position, + staked, + ) + totalReward += reward2Days + totalPenalty += penalty2Days + + // baseReward * blocks * 0.5 + expected2Days := baseReward * blocks2Days * 50 / 100 * 1000 / 2000 + uassert.Equal(t, expected2Days, reward2Days, + ufmt.Sprintf("2 days reward mismatch: expected %d, got %d", expected2Days, reward2Days)) + + expectedTotal := expected5Days + expected2Days + uassert.Equal(t, expectedTotal, totalReward, + ufmt.Sprintf("Total reward mismatch: expected %d, got %d", expectedTotal, totalReward)) + }) + + t.Run("40 days staking period", func(t *testing.T) { + // Reward calculation for 40 days period with different warmup ratios + // + // Step 1: Base reward per block + // poolRewardPerBlock * (position/staked) = 100 * (1000/2000) = 50 per block + // + // Step 2: Period-specific rewards + // Days 0-5: 50 * (5 * blocksInDay) * 30% = reward for first 5 days => 2430000 + // Days 6-10: 50 * (5 * blocksInDay) * 50% = reward for next 5 days => 5400000 + // Days 11-30: 50 * (20 * blocksInDay) * 70% = reward for next 20 days => 30240000 + // Days 31-40: 50 * (10 * blocksInDay) * 100% = reward for final 10 days => 21600000 + // + // Step 3: Total reward is sum of all period rewards + startHeight := int64(2000) + std.TestSkipHeights(startHeight) + + deposit := &Deposit{ + warmups: InstantiateWarmup(startHeight), + } + + position := u256.NewUint(1000) + staked := u256.NewUint(2000) + + periods := []struct { + days int64 + warmupRatio uint64 + }{ + {5, 30}, // 0-5 days + {5, 50}, // 6-10 days + {20, 70}, // 11-30 days + {10, 100}, // 31-40 days + } + + var totalReward, totalPenalty uint64 + currentHeight := startHeight + + for i, period := range periods { + blocks := uint64(period.days * blocksInDay) + reward, penalty := deposit.warmups[i].Apply( + baseReward*blocks, + position, + staked, + ) + totalReward += reward + totalPenalty += penalty + + expectedReward := baseReward * blocks * period.warmupRatio / 100 * 1000 / 2000 + uassert.Equal(t, expectedReward, reward, + ufmt.Sprintf("Period %d reward mismatch: expected %d, got %d", + i, expectedReward, reward)) + + std.TestSkipHeights(period.days * blocksInDay) + currentHeight += period.days * blocksInDay + } + }) +} diff --git a/contract/r/gnoswap/staker/reward_pool_store.gno b/contract/r/gnoswap/staker/reward_pool_store.gno new file mode 100644 index 000000000..c4ec439f7 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_pool_store.gno @@ -0,0 +1,186 @@ +package staker + +const ( + TIER1_INDEX = 0 + TIER2_INDEX = 1 + TIER3_INDEX = 2 + TIER_TYPE_NUM = 3 +) + +// RewardPool represents a reward pool. +// A RewardPool type has the following fields: +// - tier: tier of pool +// - rewardAmount: distributable reward amount +// - distributedAmount: distributed reward amount +// - leftAmount: left reward amount +type RewardPool struct { + tier uint64 + rewardAmount uint64 + distributedAmount uint64 + leftAmount uint64 +} + +// NewRewardPool creates a new instance of RewardPool with default values. +// +// Returns: +// - *RewardPool: A pointer to the initialized RewardPool instance. +func NewRewardPool() *RewardPool { + return &RewardPool{ + tier: 0, + rewardAmount: 0, + distributedAmount: 0, + leftAmount: 0, + } +} + +// SetTier : Sets the tier of the pool. +func (r *RewardPool) SetTier(tier uint64) { + r.tier = tier +} + +// SetRewardAmount : Sets the total reward amount for the pool. +func (r *RewardPool) SetRewardAmount(rewardAmount uint64) { + r.rewardAmount = rewardAmount +} + +// SetDistributedAmount : Sets the amount of rewards distributed. +func (r *RewardPool) SetDistributedAmount(distributedAmount uint64) { + r.distributedAmount = distributedAmount +} + +// SetLeftAmount : Sets the remaining rewards in the pool. +func (r *RewardPool) SetLeftAmount(leftAmount uint64) { + r.leftAmount = leftAmount +} + +// GetTier : Returns the tier of the pool. +func (r *RewardPool) GetTier() uint64 { + return r.tier +} + +// GetRewardAmount : Returns the total reward amount. +func (r *RewardPool) GetRewardAmount() uint64 { + return r.rewardAmount +} + +// GetDistributedAmount : Returns the distributed reward amount. +func (r *RewardPool) GetDistributedAmount() uint64 { + return r.distributedAmount +} + +// GetLeftAmount : Returns the remaining reward amount. +func (r *RewardPool) GetLeftAmount() uint64 { + return r.leftAmount +} + +// RewardPoolMap manages multiple reward pools and tier-specific reward data. +// +// Fields: +// - rewardPools (map[string]*RewardPool): Maps pool paths to RewardPool instances. +// - rewardAmountForTier ([TIER_TYPE_NUM]uint64): Total reward amount for each tier. +// - rewardAmountForTierEachPool ([TIER_TYPE_NUM]uint64): Reward amount for each pool in each tier. +// - leftAmountForTier ([TIER_TYPE_NUM]uint64): Remaining reward amount for each tier. +type RewardPoolMap struct { + rewardPools map[string]*RewardPool // poolPath -> RewardPool + rewardAmountForTier [TIER_TYPE_NUM]uint64 // total reward amount for each tier + rewardAmountForTierEachPool [TIER_TYPE_NUM]uint64 // reward amount for each pool in each tier + leftAmountForTier [TIER_TYPE_NUM]uint64 // left reward amount for each tier +} + +// NewRewardPoolMap creates a new instance of RewardPoolMap with default values. +// +// Returns: +// - *RewardPoolMap: A pointer to the initialized RewardPoolMap instance. +func NewRewardPoolMap() *RewardPoolMap { + return &RewardPoolMap{ + rewardPools: make(map[string]*RewardPool), + rewardAmountForTier: [TIER_TYPE_NUM]uint64{}, + rewardAmountForTierEachPool: [TIER_TYPE_NUM]uint64{}, + leftAmountForTier: [TIER_TYPE_NUM]uint64{}, + } +} + +// SetRewardPoolMap : Replaces the current reward pool map. +func (r *RewardPoolMap) SetRewardPoolMap(rewardPools map[string]*RewardPool) { + r.rewardPools = rewardPools +} + +// SetRewardPool : Adds or updates a reward pool for a specific pool path. +func (r *RewardPoolMap) SetRewardPool(poolPath string, rewardPool *RewardPool) { + r.rewardPools[poolPath] = rewardPool +} + +// SetPoolTier : Sets the tier of a specific pool. +func (r *RewardPoolMap) SetPoolTier(poolPath string, tier uint64) { + if _, exist := r.rewardPools[poolPath]; !exist { + r.rewardPools[poolPath] = NewRewardPool() + } + r.rewardPools[poolPath].SetTier(tier) +} + +// SetPoolRewardAmount : Sets the reward amount for a specific pool. +func (r *RewardPoolMap) SetPoolRewardAmount(poolPath string, rewardAmount uint64) { + if _, exist := r.rewardPools[poolPath]; !exist { + r.rewardPools[poolPath] = NewRewardPool() + } + r.rewardPools[poolPath].SetRewardAmount(rewardAmount) +} + +// SetRewardAmountForTier : Sets the total reward amount for a tier. +func (r *RewardPoolMap) SetRewardAmountForTier(tierIndex int, amount uint64) { + r.rewardAmountForTier[tierIndex] = amount +} + +// SetRewardAmountForTierEachPool : Sets the reward amount per pool for a tier. +func (r *RewardPoolMap) SetRewardAmountForTierEachPool(tierIndex int, amount uint64) { + r.rewardAmountForTierEachPool[tierIndex] = amount +} + +// SetLeftAmountForTier : Sets the remaining reward amount for a tier. +func (r *RewardPoolMap) SetLeftAmountForTier(tierIndex int, amount uint64) { + r.leftAmountForTier[tierIndex] = amount +} + +// GetRewardPools : Returns the entire reward pool map. +func (r *RewardPoolMap) GetRewardPools() map[string]*RewardPool { + return r.rewardPools +} + +// GetRewardPoolByPoolPath : Returns the RewardPool for a specific pool path. +func (r *RewardPoolMap) GetRewardPoolByPoolPath(poolPath string) *RewardPool { + if _, exist := r.rewardPools[poolPath]; !exist { + r.rewardPools[poolPath] = NewRewardPool() + } + return r.rewardPools[poolPath] +} + +// GetPoolTier : Returns the tier of a specific pool. +func (r *RewardPoolMap) GetPoolTier(poolPath string) uint64 { + if _, exist := r.rewardPools[poolPath]; !exist { + r.rewardPools[poolPath] = NewRewardPool() + } + return r.rewardPools[poolPath].tier +} + +// GetPoolRewardAmount : Returns the reward amount for a specific pool. +func (r *RewardPoolMap) GetPoolRewardAmount(poolPath string) uint64 { + if _, exist := r.rewardPools[poolPath]; !exist { + r.rewardPools[poolPath] = NewRewardPool() + } + return r.rewardPools[poolPath].rewardAmount +} + +// GetRewardAmountForTier : Returns the total reward amount for a tier. +func (r *RewardPoolMap) GetRewardAmountForTier(tierIndex int) uint64 { + return r.rewardAmountForTier[tierIndex] +} + +// GetRewardAmountForTierEachPool : Returns the reward amount per pool for a tier. +func (r *RewardPoolMap) GetRewardAmountForTierEachPool(tierIndex int) uint64 { + return r.rewardAmountForTierEachPool[tierIndex] +} + +// GetLeftAmountForTier : Returns the remaining reward amount for a tier. +func (r *RewardPoolMap) GetLeftAmountForTier(tierIndex int) uint64 { + return r.leftAmountForTier[tierIndex] +} diff --git a/contract/r/gnoswap/staker/reward_pool_store_test.gno b/contract/r/gnoswap/staker/reward_pool_store_test.gno new file mode 100644 index 000000000..b371d5a68 --- /dev/null +++ b/contract/r/gnoswap/staker/reward_pool_store_test.gno @@ -0,0 +1,78 @@ +package staker + +import "testing" + +func TestRewardPool(t *testing.T) { + pool := NewRewardPool() + + // Test default values + if pool.GetTier() != 0 { + t.Errorf("Expected default tier to be 0, got %d", pool.GetTier()) + } + if pool.GetRewardAmount() != 0 { + t.Errorf("Expected default rewardAmount to be 0, got %d", pool.GetRewardAmount()) + } + if pool.GetDistributedAmount() != 0 { + t.Errorf("Expected default distributedAmount to be 0, got %d", pool.GetDistributedAmount()) + } + if pool.GetLeftAmount() != 0 { + t.Errorf("Expected default leftAmount to be 0, got %d", pool.GetLeftAmount()) + } + + // Test setters + pool.SetTier(1) + pool.SetRewardAmount(1000) + pool.SetDistributedAmount(500) + pool.SetLeftAmount(500) + + // Test getters + if pool.GetTier() != 1 { + t.Errorf("Expected tier to be 1, got %d", pool.GetTier()) + } + if pool.GetRewardAmount() != 1000 { + t.Errorf("Expected rewardAmount to be 1000, got %d", pool.GetRewardAmount()) + } + if pool.GetDistributedAmount() != 500 { + t.Errorf("Expected distributedAmount to be 500, got %d", pool.GetDistributedAmount()) + } + if pool.GetLeftAmount() != 500 { + t.Errorf("Expected leftAmount to be 500, got %d", pool.GetLeftAmount()) + } +} + +func TestRewardPoolMap(t *testing.T) { + rewardMap := NewRewardPoolMap() + + // Test default values + if len(rewardMap.GetRewardPools()) != 0 { + t.Errorf("Expected empty rewardPools map, got %d entries", len(rewardMap.GetRewardPools())) + } + + // Test adding a new pool + rewardMap.SetRewardPool("pool1", &RewardPool{tier: 1, rewardAmount: 1000}) + rewardPool := rewardMap.GetRewardPoolByPoolPath("pool1") + if rewardPool.GetTier() != 1 { + t.Errorf("Expected tier of pool1 to be 1, got %d", rewardPool.GetTier()) + } + if rewardPool.GetRewardAmount() != 1000 { + t.Errorf("Expected rewardAmount of pool1 to be 1000, got %d", rewardPool.GetRewardAmount()) + } + + // Test updating pool tier + rewardMap.SetPoolTier("pool1", 2) + if rewardMap.GetPoolTier("pool1") != 2 { + t.Errorf("Expected updated tier of pool1 to be 2, got %d", rewardMap.GetPoolTier("pool1")) + } + + // Test setting and getting reward amount for tiers + rewardMap.SetRewardAmountForTier(TIER1_INDEX, 5000) + if rewardMap.GetRewardAmountForTier(TIER1_INDEX) != 5000 { + t.Errorf("Expected rewardAmountForTier[TIER1_INDEX] to be 5000, got %d", rewardMap.GetRewardAmountForTier(TIER1_INDEX)) + } + + // Test setting and getting left amount for tiers + rewardMap.SetLeftAmountForTier(TIER2_INDEX, 3000) + if rewardMap.GetLeftAmountForTier(TIER2_INDEX) != 3000 { + t.Errorf("Expected leftAmountForTier[TIER2_INDEX] to be 3000, got %d", rewardMap.GetLeftAmountForTier(TIER2_INDEX)) + } +} diff --git a/contract/r/gnoswap/staker/staker.gno b/contract/r/gnoswap/staker/staker.gno new file mode 100644 index 000000000..89bafe359 --- /dev/null +++ b/contract/r/gnoswap/staker/staker.gno @@ -0,0 +1,704 @@ +package staker + +import ( + "std" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/gnoswap/v1/gnft" + "gno.land/r/gnoswap/v1/gns" + + en "gno.land/r/gnoswap/v1/emission" + pl "gno.land/r/gnoswap/v1/pool" + pn "gno.land/r/gnoswap/v1/position" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +type Deposits struct { + tree *avl.Tree +} + +func NewDeposits() *Deposits { + return &Deposits{ + tree: avl.NewTree(), // tokenId -> *Deposit + } +} + +func (self *Deposits) Get(tokenId uint64) *Deposit { + depositI, ok := self.tree.Get(EncodeUint(tokenId)) + if !ok { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("tokenId(%d) not found", tokenId), + )) + } + deposit := depositI.(*Deposit) + return deposit +} + +func (self *Deposits) Set(tokenId uint64, deposit *Deposit) { + self.tree.Set(EncodeUint(tokenId), deposit) +} + +func (self *Deposits) Has(tokenId uint64) bool { + return self.tree.Has(EncodeUint(tokenId)) +} + +func (self *Deposits) Remove(tokenId uint64) { + self.tree.Remove(EncodeUint(tokenId)) +} + +func (self *Deposits) Iterate(start uint64, end uint64, fn func(tokenId uint64, deposit *Deposit) bool) { + self.tree.Iterate(EncodeUint(start), EncodeUint(end), func(tokenId string, depositI interface{}) bool { + return fn(DecodeUint(tokenId), depositI.(*Deposit)) + }) +} + +func (self *Deposits) Size() int { + return self.tree.Size() +} + +type ExternalIncentives struct { + tree *avl.Tree +} + +func NewExternalIncentives() *ExternalIncentives { + return &ExternalIncentives{ + tree: avl.NewTree(), + } +} + +func (self *ExternalIncentives) Get(incentiveId string) *ExternalIncentive { + incentiveI, ok := self.tree.Get(incentiveId) + if !ok { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("incentiveId(%s) not found", incentiveId), + )) + } + + incentive := incentiveI.(*ExternalIncentive) + return incentive +} + +func (self *ExternalIncentives) Set(incentiveId string, incentive *ExternalIncentive) { + self.tree.Set(incentiveId, incentive) +} + +func (self *ExternalIncentives) Has(incentiveId string) bool { + return self.tree.Has(incentiveId) +} + +func (self *ExternalIncentives) Remove(incentiveId string) { + self.tree.Remove(incentiveId) +} + +func (self *ExternalIncentives) Size() int { + return self.tree.Size() +} + +type Stakers struct { + tree *avl.Tree // address -> depositId -> *Deposit +} + +func NewStakers() *Stakers { + return &Stakers{ + tree: avl.NewTree(), + } +} + +func (self *Stakers) IterateAll(address std.Address, fn func(depositId uint64, deposit *Deposit) bool) { + depositTreeI, ok := self.tree.Get(address.String()) + if !ok { + return + } + depositTree := depositTreeI.(*avl.Tree) + depositTree.Iterate("", "", func(depositId string, depositI interface{}) bool { + deposit := depositI.(*Deposit) + return fn(DecodeUint(depositId), deposit) + }) +} + +func (self *Stakers) AddDeposit(address std.Address, depositId uint64, deposit *Deposit) { + depositTreeI, ok := self.tree.Get(address.String()) + if !ok { + depositTree := avl.NewTree() + self.tree.Set(address.String(), depositTree) + depositTreeI = depositTree + } + depositTree := depositTreeI.(*avl.Tree) + depositTree.Set(EncodeUint(depositId), deposit) +} + +func (self *Stakers) RemoveDeposit(address std.Address, depositId uint64) { + depositTreeI, ok := self.tree.Get(address.String()) + if !ok { + return + } + depositTree := depositTreeI.(*avl.Tree) + depositTree.Remove(EncodeUint(depositId)) +} + +var ( + // deposits stores deposit information for each tokenId + deposits *Deposits = NewDeposits() + + // externalIncentives stores external incentive information for each incentiveId + externalIncentives *ExternalIncentives = NewExternalIncentives() + + // stakers stores staker information for each address + stakers *Stakers = NewStakers() + + // poolTier stores pool tier information + poolTier *PoolTier + + // totalEmissionSent is the total amount of GNS emission sent from staker to user(and community pool if penalty exists) + // which includes following + // 1. reward sent to user (which also includes protocol_fee) + // 2. penalty sent to community pool + // 3. unclaimable reward + totalEmissionSent uint64 +) + +const ( + TIMESTAMP_90DAYS = 7776000 + TIMESTAMP_180DAYS = 15552000 + TIMESTAMP_365DAYS = 31536000 + + MAX_UNIX_EPOCH_TIME = 253402300799 // 9999-12-31 23:59:59 + + MUST_EXISTS_IN_TIER_1 = "gno.land/r/demo/wugnot:gno.land/r/gnoswap/v1/gns:3000" + + INTERNAL = true + EXTERNAL = false +) + +func init() { + // init pool tiers + // tier 1 + // ONLY GNOT:GNS 0.3% + + pools.GetOrCreate(MUST_EXISTS_IN_TIER_1) + poolTier = NewPoolTier(pools, std.GetHeight(), MUST_EXISTS_IN_TIER_1, en.GetEmission, en.GetHalvingBlocksInRange) + en.SetCallbackStakerEmissionChange(callbackStakerEmissionChange) +} + +// StakeToken stakes an LP token into the staker contract. It transfer the LP token +// ownership to the staker contract. +// +// State Transition: +// 1. Token ownership transfers from user -> staker contract +// 2. Position operator changes to caller +// 3. Deposit record is created and stored +// 4. Internal warm up amount is set to 0 +// +// Requirements: +// 1. Token must have non-zero liquidity +// 2. Pool must have either internal or external incentives +// 3. Caller must be token owner or approved operator +// +// Parameters: +// - tokenId (uint64): The ID of the LP token to stake +// +// Returns: +// - poolPath (string): The path of the pool to which the LP token is staked +// - token0Amount (string): The amount of token0 in the LP token +// - token1Amount (string): The amount of token1 in the LP token +// +// ref: https://docs.gnoswap.io/contracts/staker/staker.gno#staketoken +func StakeToken(tokenId uint64) (string, string, string) { + assertOnlyNotHalted() + assertOnlyNotStaked(tokenId) + + en.MintAndDistributeGns() + + owner := gnft.MustOwnerOf(tid(tokenId)) + caller := getPrevAddr() + + token0Amount, token1Amount, err := calculateStakeTokenAmount(tokenId, owner, caller) + if err != nil { + panic(err.Error()) + } + + // check pool path from tokenId + poolPath := pn.PositionGetPositionPoolKey(tokenId) + pool, ok := pools.Get(poolPath) + if !ok { + panic(addDetailToError( + errNonIncentivizedPool, + ufmt.Sprintf("can not stake position to non existing pool(%s)", poolPath), + )) + } + currentHeight := std.GetHeight() + liquidity := getLiquidity(tokenId) + + tickLower, tickUpper := getTickOf(tokenId) + + // staked status + deposit := &Deposit{ + owner: caller, + stakeTimestamp: time.Now().Unix(), + stakeHeight: currentHeight, + targetPoolPath: poolPath, + tickLower: tickLower, + tickUpper: tickUpper, + liquidity: liquidity, + lastCollectHeight: currentHeight, + warmups: InstantiateWarmup(currentHeight), + } + + currentTick := pl.PoolGetSlot0Tick(poolPath) + + deposits.Set(tokenId, deposit) + stakers.AddDeposit(caller, tokenId, deposit) + + if caller == owner { // if caller is owner, transfer NFT ownership to staker contract + if err := transferDeposit(tokenId, owner, caller, consts.STAKER_ADDR); err != nil { + panic(err.Error()) + } + } + + // after transfer, set caller(user) as position operator (to collect fee and reward) + pn.SetPositionOperator(tokenId, caller) + + signedLiquidity := i256.FromUint256(liquidity) + isInRange := false + + poolTier.cacheReward(currentHeight, pools) + + if pn.PositionIsInRange(tokenId) { + isInRange = true + pool.modifyDeposit(signedLiquidity, currentHeight, currentTick) + } + // historical tick must be set regardless of the deposit's range + pool.historicalTick.Set(currentHeight, currentTick) + + // this could happen because of how position stores the ticks. + // ticks are negated if the token1 < token0 + upperTick := pool.ticks.Get(tickUpper) + lowerTick := pool.ticks.Get(tickLower) + + upperTick.modifyDepositUpper(currentHeight, currentTick, signedLiquidity) + lowerTick.modifyDepositLower(currentHeight, currentTick, signedLiquidity) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "StakeToken", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "positionId", formatUint(tokenId), + "poolPath", poolPath, + "amount0", token0Amount, + "amount1", token1Amount, + "liquidity", liquidity.ToString(), + "positionUpperTick", formatInt(tickUpper), + "positionLowerTick", formatInt(tickLower), + "currentTick", formatInt(currentTick), + "isInRange", formatBool(isInRange), + ) + + // positionsInternalWarmUpAmount[tokenId] = warmUpAmount{} + return poolPath, token0Amount, token1Amount +} + +// calculateStakeData validates staking requirements and prepares staking data. +// +// It checks if the token is already staked, verifies ownership, and ensures the pool has incentives. +// If successful, it returns the staking data; otherwise, it returns an error. +// +// Parameters: +// - tokenId: The ID of the LP token to stake +// - owner: The owner of the LP token +// - caller: The caller of the staking operation +// +// Returns: +// - *stakeResult: The staking data if successful +// - error: An error if any validation fails +func calculateStakeTokenAmount(tokenId uint64, owner, caller std.Address) (string, string, error) { + exist := deposits.Has(tokenId) + if exist { + return "", "", errAlreadyStaked + } + + if err := requireTokenOwnership(owner, caller); err != nil { + return "", "", err + } + + if err := tokenHasLiquidity(tokenId); err != nil { + return "", "", err + } + + poolPath := pn.PositionGetPositionPoolKey(tokenId) + if err := poolHasIncentives(poolPath); err != nil { + return "", "", err + } + + token0Amount, token1Amount := getTokenPairBalanceFromPosition(poolPath, tokenId) + + return token0Amount, token1Amount, nil +} + +// transferDeposit transfers the ownership of a deposit token (NFT) to a new owner. +// +// This function ensures that the caller is not the same as the recipient (`to`). +// If the caller is not the owner or the recipient, it attempts to transfer the NFT +// ownership from the current owner to the recipient. +// +// Parameters: +// - tokenId (uint64): The unique identifier of the token (NFT) to transfer. +// - owner (std.Address): The current owner of the token. +// - caller (std.Address): The address attempting to initiate the transfer. +// - to (std.Address): The address to which the token will be transferred. +// +// Returns: +// - error: Returns an error if the caller is the same as the recipient (`to`). +// Otherwise, it delegates the transfer to `gnft.TransferFrom` and returns any error +// that may occur during the transfer. +func transferDeposit(tokenId uint64, owner, caller, to std.Address) error { + if caller == to { + return ufmt.Errorf( + "%v: only owner(%s) can transfer tokenId(%d), called from %s", + errNoPermission, owner, tokenId, caller, + ) + } + // transfer NFT ownership + return gnft.TransferFrom(owner, to, tid(tokenId)) +} + +//////////////////////////////////////////////////////////// + +// CollectReward harvests accumulated rewards for a staked position. This includes both +// internal GNS emission and external incentive rewards. +// +// State Transition: +// 1. Warm-up amounts are clears for both internal and external rewards +// 2. Reward tokens are transferred to the owner +// 3. Penalty fees are transferred to protocol/community addresses +// 4. GNS balance is recalculated +// +// Requirements: +// - Contract must not be halted +// - Caller must be the position owner +// - Position must be staked (have a deposit record) +// +// Parameters: +// - tokenId (uint64): The ID of the LP token to collect rewards from +// - unwrapResult (bool): Whether to unwrap WUGNOT to GNOT +// +// Returns: +// - poolPath (string): The path of the pool to which the LP token is staked +// +// ref: https://docs.gnoswap.io/contracts/staker/staker.gno#collectreward +func CollectReward(tokenId uint64, unwrapResult bool) (string, string) { + assertOnlyNotHalted() + + deposit := deposits.Get(tokenId) + caller := getPrevAddr() + if err := common.SatisfyCond(caller == deposit.owner); err != nil { + panic(addDetailToError(errNoPermission, ufmt.Sprintf("caller is not owner of tokenId(%d)", tokenId))) + } + + en.MintAndDistributeGns() + + currentHeight := std.GetHeight() + // get all internal and external rewards + reward := calcPositionReward(currentHeight, tokenId) + + // update lastCollectHeight to current height + deposit.lastCollectHeight = currentHeight + + prevAddr, prevPkgPath := getPrev() + + // transfer external rewards to user + externalReward := reward.External + for incentiveId, amount := range externalReward { + incentive := externalIncentives.Get(incentiveId) + rewardToken := incentive.rewardToken + if incentive.rewardAmount < amount { + panic(addDetailToError( + errInsufficientReward, + ufmt.Sprintf("incentiveId(%s) has insufficient reward(%d)", incentiveId, amount), + )) + } + incentive.rewardAmount -= amount + externalIncentives.Set(incentiveId, incentive) + + toUser := handleUnStakingFee(rewardToken, amount, false, tokenId, incentive.targetPoolPath) + + teller := common.GetTokenTeller(rewardToken) + teller.Transfer(deposit.owner, toUser) + + externalPenalty := reward.ExternalPenalty[incentiveId] + incentive.rewardLeft += externalPenalty + + // unwrap if necessary + if unwrapResult && rewardToken == consts.WUGNOT_PATH { + unwrap(toUser) + } + std.Emit( + "CollectReward", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "positionId", formatUint(tokenId), + "poolPath", deposit.targetPoolPath, + "recipient", deposit.owner.String(), + "incentiveId", incentiveId, + "rewardToken", rewardToken, + "rewardAmount", formatUint(amount), + "rewardToUser", formatUint(toUser), + "rewardToFee", formatUint(amount-toUser), + "rewardPenalty", formatUint(externalPenalty), + "isRequestUnwrap", formatBool(unwrapResult), + ) + } + + // internal reward to user + toUser := handleUnStakingFee(consts.GNS_PATH, reward.Internal, true, tokenId, deposit.targetPoolPath) + + if toUser > 0 { + // internal reward to user + totalEmissionSent += toUser + gns.Transfer(deposit.owner, toUser) + + // internal penalty to community pool + totalEmissionSent += reward.InternalPenalty + gns.Transfer(consts.COMMUNITY_POOL_ADDR, reward.InternalPenalty) + } + + unClaimableInternal := ProcessUnClaimableReward(deposit.targetPoolPath, currentHeight) + if unClaimableInternal > 0 { + // internal unClaimable to community pool + totalEmissionSent += unClaimableInternal + gns.Transfer(consts.COMMUNITY_POOL_ADDR, unClaimableInternal) + } + + std.Emit( + "CollectReward", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "positionId", formatUint(tokenId), + "poolPath", deposit.targetPoolPath, + "recipient", deposit.owner.String(), + "rewardToken", consts.GNS_PATH, + "rewardAmount", formatUint(reward.Internal), + "rewardToUser", formatUint(toUser), + "rewardToFee", formatUint(reward.Internal-toUser), + "rewardPenalty", formatUint(reward.InternalPenalty), + "rewardUnClaimableAmount", formatUint(unClaimableInternal), + ) + + return formatUint(toUser), formatUint(reward.InternalPenalty) +} + +// UnStakeToken withdraws an LP token from staking, collecting all pending rewards +// and returning the token to its original owner. +// +// State transitions: +// 1. All pending rewards are collected (calls CollectReward) +// 2. Token ownership transfers back to original owner +// 3. Position operator is cleared +// 4. All staking state is cleaned up: +// - Deposit record removed +// - Position GNS balances cleared +// - Warm-up amounts cleared +// - Position removed from reward tracking +// +// Requirements: +// - Contract must not be halted +// - Position must be staked (have deposit record) +// - Rewards are automatically collected before unStaking +// +// Params: +// - tokenId (uint64): ID of the staked LP token +// - unwrapResult (bool): If true, unwraps any WUGNOT rewards to GNOT +// +// Returns: +// - poolPath (string): The pool path associated with the unstaked position +// - token0Amount (string): Final amount of token0 in the position +// - token1Amount (string): Final amount of token1 in the position +// +// ref: https://docs.gnoswap.io/contracts/staker/staker.gno#unstaketoken +func UnStakeToken(tokenId uint64, unwrapResult bool) (string, string, string) { // poolPath, token0Amount, token1Amount + assertOnlyNotHalted() + + // unStaked status + deposit := deposits.Get(tokenId) + poolPath := deposit.targetPoolPath + + // Claim All Rewards + CollectReward(tokenId, unwrapResult) + token0Amount, token1Amount := getTokenPairBalanceFromPosition(poolPath, tokenId) + + applyUnStake(tokenId) + + // transfer NFT ownership to origin owner + gnft.TransferFrom(consts.STAKER_ADDR, deposit.owner, tid(tokenId)) + pn.SetPositionOperator(tokenId, consts.ZERO_ADDRESS) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "UnStakeToken", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "positionId", formatUint(tokenId), + "poolPath", poolPath, + "isRequestUnwrap", formatBool(unwrapResult), + "from", consts.STAKER_ADDR.String(), + "to", deposit.owner.String(), + "amount0", token0Amount, + "amount1", token1Amount, + ) + + return poolPath, token0Amount, token1Amount +} + +func applyUnStake(tokenId uint64) { + deposit := deposits.Get(tokenId) + pool, ok := pools.Get(deposit.targetPoolPath) + if !ok { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("pool(%s) does not exist", deposit.targetPoolPath), + )) + } + + currentHeight := std.GetHeight() + currentTick := pl.PoolGetSlot0Tick(deposit.targetPoolPath) + signedLiquidity := i256.FromUint256(deposit.liquidity) + signedLiquidity = signedLiquidity.Neg(signedLiquidity) + if pn.PositionIsInRange(tokenId) { + pool.modifyDeposit(signedLiquidity, currentHeight, currentTick) + } + + upperTick := pool.ticks.Get(deposit.tickUpper) + lowerTick := pool.ticks.Get(deposit.tickLower) + upperTick.modifyDepositUpper(currentHeight, currentTick, signedLiquidity) + lowerTick.modifyDepositLower(currentHeight, currentTick, signedLiquidity) + + deposits.Remove(tokenId) + stakers.RemoveDeposit(deposit.owner, tokenId) + + owner := gnft.MustOwnerOf(tid(tokenId)) + caller := getPrevAddr() + + token0Amount, token1Amount, err := calculateStakeTokenAmount(tokenId, owner, caller) + if err != nil { + panic(err.Error()) + } + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "UnStakeToken", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "positionId", formatUint(tokenId), + "poolPath", deposit.targetPoolPath, + "from", GetOrigPkgAddr().String(), + "to", deposit.owner.String(), + "amount0", token0Amount, + "amount1", token1Amount, + "liquidity", deposit.liquidity.ToString(), + "positionUpperTick", formatInt(upperTick), + "positionLowerTick", formatInt(lowerTick), + "currentTick", formatInt(currentTick), + ) +} + +// requireTokenOwnership validates that the caller has permission to operate the token. +func requireTokenOwnership(owner, caller std.Address) error { + callerIsOwner := owner == caller + stakerIsOwner := owner == consts.STAKER_ADDR + + if err := common.SatisfyCond(callerIsOwner || stakerIsOwner); err != nil { + return errNoPermission + } + + return nil +} + +// poolHasIncentives checks if the pool has any active incentives (internal or external). +func poolHasIncentives(poolPath string) error { + pool, ok := pools.Get(poolPath) + if !ok { + return ufmt.Errorf( + "%v: can not stake position to non existent pool(%s)", + errNonIncentivizedPool, poolPath, + ) + } + hasInternal := poolTier.IsInternallyIncentivizedPool(poolPath) + hasExternal := pool.IsExternallyIncentivizedPool() + if hasInternal == false && hasExternal == false { + return ufmt.Errorf( + "%v: can not stake position to non incentivized pool(%s)", + errNonIncentivizedPool, poolPath, + ) + } + return nil +} + +// tokenHasLiquidity checks if the target tokenId has non-zero liquidity +func tokenHasLiquidity(tokenId uint64) error { + liquidity := getLiquidity(tokenId) + + if liquidity.Lte(u256.Zero()) { + return ufmt.Errorf( + "%v: tokenId(%d) has no liquidity", + errZeroLiquidity, tokenId, + ) + } + return nil +} + +func getLiquidity(tokenId uint64) *u256.Uint { + liq := pn.PositionGetPositionLiquidityStr(tokenId) + return u256.MustFromDecimal(liq) +} + +func assertOnlyNotStaked(tokenId uint64) { + if deposits.Has(tokenId) { + panic(addDetailToError( + errAlreadyStaked, + ufmt.Sprintf("tokenId(%d) already staked", tokenId), + )) + } +} + +func getTokenPairBalanceFromPosition(poolPath string, tokenId uint64) (string, string) { + pool := pl.GetPoolFromPoolPath(poolPath) + + currentX96 := pool.Slot0SqrtPriceX96() + lowerX96 := common.TickMathGetSqrtRatioAtTick(pn.PositionGetPositionTickLower(tokenId)) + upperX96 := common.TickMathGetSqrtRatioAtTick(pn.PositionGetPositionTickUpper(tokenId)) + + token0Balance, token1Balance := common.GetAmountsForLiquidity( + currentX96, + lowerX96, + upperX96, + u256.MustFromDecimal(pn.PositionGetPositionLiquidityStr(tokenId)), + ) + + if token0Balance == "" { + token0Balance = "0" + } + if token1Balance == "" { + token1Balance = "0" + } + return token0Balance, token1Balance +} + +func getTickOf(tokenId uint64) (int32, int32) { + tickLower := pn.PositionGetPositionTickLower(tokenId) + tickUpper := pn.PositionGetPositionTickUpper(tokenId) + if tickUpper < tickLower { + panic(ufmt.Sprintf("tickUpper(%d) is less than tickLower(%d)", tickUpper, tickLower)) + } + return tickLower, tickUpper +} diff --git a/contract/r/gnoswap/staker/staker_external_incentive.gno b/contract/r/gnoswap/staker/staker_external_incentive.gno new file mode 100644 index 000000000..6d103baa1 --- /dev/null +++ b/contract/r/gnoswap/staker/staker_external_incentive.gno @@ -0,0 +1,300 @@ +package staker + +import ( + "std" + "strconv" + "time" + + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" + + "gno.land/r/gnoswap/v1/gns" + + en "gno.land/r/gnoswap/v1/emission" + pl "gno.land/r/gnoswap/v1/pool" +) + +// CreateExternalIncentive creates an incentive program for a pool. +// ref: https://docs.gnoswap.io/contracts/staker/staker.gno#createexternalincentive +func CreateExternalIncentive( + targetPoolPath string, + rewardToken string, // token path should be registered + rewardAmount uint64, + startTimestamp int64, + endTimestamp int64, +) { + assertOnlyNotHalted() + + if common.GetLimitCaller() { + prev := std.PrevRealm() + if err := common.UserOnly(prev); err != nil { + panic(ufmt.Sprintf("%v: %v", errNoPermission, err)) + } + } + + // panic if pool does not exist + if !pl.DoesPoolPathExist(targetPoolPath) { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("targetPoolPath(%s) does not exist", targetPoolPath), + )) + } + + en.MintAndDistributeGns() + + // check token can be used as reward + if err := isAllowedForExternalReward(targetPoolPath, rewardToken); err != nil { + panic(err.Error()) + } + + // native ugnot check + if rewardToken == consts.GNOT { + sent := std.GetOrigSend() + ugnotSent := uint64(sent.AmountOf("ugnot")) + + if ugnotSent != rewardAmount { + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("user(%s) sent ugnot(%d) amount not equal to rewardAmount(%d)", getPrevAddr(), ugnotSent, rewardAmount), + )) + } + + wrap(ugnotSent) + + rewardToken = consts.WUGNOT_PATH + } + + // must be in seconds format, not milliseconds + // must be at least +1 day midnight + // must be midnight of the day + if err := checkStartTime(startTimestamp); err != nil { + panic(err.Error()) + } + + // endTimestamp cannot be later than 253402300799 (9999-12-31 23:59:59) + if endTimestamp >= MAX_UNIX_EPOCH_TIME { + panic(addDetailToError( + errInvalidIncentiveEndTime, + ufmt.Sprintf("staker.gno__CreateExternalIncentive() || endTimestamp(%d) cannot be later than 253402300799 (9999-12-31 23:59:59)", endTimestamp), + )) + } + + caller := getPrevAddr() + + // incentiveId := incentiveIdCompute(std.PrevRealm().Addr(), targetPoolPath, rewardToken, startTimestamp, endTimestamp, std.GetHeight()) + incentiveId := incentiveIdByTime(startTimestamp, endTimestamp, caller, rewardToken) + + externalDuration := uint64(endTimestamp - startTimestamp) + if err := isValidIncentiveDuration(externalDuration); err != nil { + panic(err.Error()) + } + + pool := pools.GetOrCreate(targetPoolPath) + + incentive := NewExternalIncentive( + incentiveId, + targetPoolPath, + rewardToken, + rewardAmount, + startTimestamp, + endTimestamp, + caller, + std.GetHeight(), + depositGnsAmount, + time.Now().Unix(), + gns.GetAvgBlockTimeInMs(), + ) + + if externalIncentives.Has(incentiveId) { + panic(addDetailToError( + errIncentiveAlreadyExists, + ufmt.Sprintf("incentiveId(%s)", incentiveId), + )) + } + // store external incentive information for each incentiveId + externalIncentives.Set(incentiveId, incentive) + + // deposit gns amount + gns.TransferFrom(caller, consts.STAKER_ADDR, depositGnsAmount) + + // transfer reward token from user to staker + teller := common.GetTokenTeller(rewardToken) + teller.TransferFrom(caller, consts.STAKER_ADDR, rewardAmount) + + pool.incentives.create(caller, incentive) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "CreateExternalIncentive", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "incentiveId", incentiveId, + "targetPoolPath", targetPoolPath, + "rewardToken", rewardToken, + "rewardAmount", formatUint(rewardAmount), + "startTimestamp", formatInt(startTimestamp), + "endTimestamp", formatInt(endTimestamp), + "depositGnsAmount", formatUint(depositGnsAmount), + "currentHeight", formatInt(std.GetHeight()), + "currentTime", formatInt(time.Now().Unix()), + "avgBlockTimeInMs", formatUint(gns.GetAvgBlockTimeInMs()), + ) +} + +// EndExternalIncentive ends the external incentive and refunds the remaining reward +// ref: https://docs.gnoswap.io/contracts/staker/staker.gno#endexternalincentive +func EndExternalIncentive(refundee std.Address, targetPoolPath, rewardToken string, startTimestamp, endTimestamp, height int64) { + assertOnlyNotHalted() + + pool, exists := pools.Get(targetPoolPath) + if !exists { + panic(addDetailToError( + errDataNotFound, + ufmt.Sprintf("staker.gno__EndExternalIncentive() || targetPoolPath(%s) does not exist", targetPoolPath), + )) + } + + ictv, exists := pool.incentives.Get(startTimestamp, endTimestamp, refundee, rewardToken) + if !exists { + panic(addDetailToError( + errCannotEndIncentive, + ufmt.Sprintf("cannot end non existent incentive(%d:%d:%s:%s)", startTimestamp, endTimestamp, refundee, rewardToken), + )) + } + + currentHeight := std.GetHeight() + if currentHeight < ictv.endHeight { + panic(addDetailToError( + errCannotEndIncentive, + ufmt.Sprintf("cannot end incentive before endHeight(%d), current(%d)", ictv.endHeight, currentHeight), + )) + } + + // when incentive end time is over + // admin or refundee can end incentive ( left amount will be refunded ) + caller := getPrevAddr() + if caller != consts.ADMIN && caller != refundee { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only refundee(%s) or admin(%s) can end incentive, but called from %s", refundee, consts.ADMIN, caller), + )) + } + + // when incentive ended, refund remaining reward + refund := ictv.rewardLeft + + if !ictv.unclaimableRefunded { + refund += pool.incentives.calculateUnclaimableReward(ictv.incentiveId) + ictv.unclaimableRefunded = true + } + + poolLeftExternalRewardAmount := common.BalanceOf(ictv.rewardToken, consts.STAKER_ADDR) + if poolLeftExternalRewardAmount < refund { + refund = poolLeftExternalRewardAmount + } + + teller := common.GetTokenTeller(ictv.rewardToken) + teller.Transfer(ictv.refundee, refund) + + // unwrap if wugnot + if ictv.rewardToken == consts.WUGNOT_PATH { + unwrap(refund) + } + + // also refund deposit gns amount + gns.Transfer(ictv.refundee, ictv.depositGnsAmount) + + prevAddr, prevPkgPath := getPrev() + std.Emit( + "EndExternalIncentive", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "incentiveId", ictv.incentiveId, + "targetPoolPath", targetPoolPath, + "refundee", refundee.String(), + "refundToken", rewardToken, + "refundAmount", formatUint(refund), + "refundGnsAmount", formatUint(ictv.depositGnsAmount), + "isRequestUnwrap", ictv.rewardToken == consts.WUGNOT_PATH, + "externalIncentiveEndBy", getPrevAddr(), + ) +} + +func isValidIncentiveDuration(dur uint64) error { + switch dur { + case TIMESTAMP_90DAYS, TIMESTAMP_180DAYS, TIMESTAMP_365DAYS: + return nil + } + + return ufmt.Errorf( + "%v: externalDuration(%d) must be 90, 180, 365 days", + errInvalidIncentiveDuration, dur, + ) +} + +// checkStartTime checks whether the current time meets the conditions for generating an +// external rewards. Since the earliest time this reward can be generated is from +// midnight of the next day, it uses a timestamp to verify this timing. +func checkStartTime(startTimestamp int64) error { + // must be in seconds format, not milliseconds + // REF: https://stackoverflow.com/a/23982005 + numStr := strconv.Itoa(int(startTimestamp)) + if len(numStr) >= 13 { + return ufmt.Errorf( + "%v: startTimestamp(%d) must be in seconds format, not milliseconds", + errInvalidIncentiveStartTime, startTimestamp, + ) + } + + // must be at least +1 day midnight + tomorrowMidnight := time.Now().AddDate(0, 0, 1).Truncate(24 * time.Hour).Unix() + if startTimestamp < tomorrowMidnight { + return ufmt.Errorf( + "%v: startTimestamp(%d) must be at least +1 day midnight(%d)", + errInvalidIncentiveStartTime, startTimestamp, tomorrowMidnight, + ) + } + + // must be midnight of the day + startTime := time.Unix(startTimestamp, 0) + if !isMidnight(startTime) { + return ufmt.Errorf( + "%v: startTime(%d = %s) must be midnight of the day", + errInvalidIncentiveStartTime, startTimestamp, startTime.String(), + ) + } + + return nil +} + +func isMidnight(startTime time.Time) bool { + hour := startTime.Hour() + minute := startTime.Minute() + second := startTime.Second() + + return hour == 0 && minute == 0 && second == 0 +} + +func gnsBalance(addr std.Address) uint64 { + return gns.BalanceOf(addr) +} + +func isAllowedForExternalReward(poolPath, tokenPath string) error { + token0, token1, _ := poolPathDivide(poolPath) + + if tokenPath == token0 || tokenPath == token1 { + return nil + } + + allowed := contains(allowedTokens, tokenPath) + if allowed { + return nil + } + + return ufmt.Errorf( + "%v: tokenPath(%s) is not allowed for external reward for poolPath(%s)", + errNotAllowedForExternalReward, tokenPath, poolPath, + ) +} diff --git a/contract/r/gnoswap/staker/type.gno b/contract/r/gnoswap/staker/type.gno new file mode 100644 index 000000000..c6adb6432 --- /dev/null +++ b/contract/r/gnoswap/staker/type.gno @@ -0,0 +1,127 @@ +package staker + +import ( + "std" + + u256 "gno.land/p/gnoswap/uint256" +) + +// ExternalIncentive is a struct for storing external incentive information. +type ExternalIncentive struct { + incentiveId string // incentive id + startTimestamp int64 // start time for external reward + endTimestamp int64 // end time for external reward + createdHeight int64 // block height when the incentive was created + depositGnsAmount uint64 // deposited gns amount + targetPoolPath string // external reward target pool path + rewardToken string // external reward token path + rewardAmount uint64 // total reward amount + rewardLeft uint64 // remaining reward amount + startHeight int64 // start height for external reward + endHeight int64 // end height for external reward + rewardPerBlock uint64 // reward per block + refundee std.Address // refundee address + + unclaimableRefunded bool // whether unclaimable reward is refunded +} + +func (e ExternalIncentive) StartTimestamp() int64 { + return e.startTimestamp +} + +func (e ExternalIncentive) EndTimestamp() int64 { + return e.endTimestamp +} + +func (e ExternalIncentive) RewardToken() string { + return e.rewardToken +} + +func (e ExternalIncentive) RewardAmount() uint64 { + return e.rewardAmount +} + +func (self *ExternalIncentive) RewardSpent(currentHeight uint64) uint64 { + if currentHeight < uint64(self.startHeight) { + return 0 + } + + if currentHeight > uint64(self.endHeight) { + return self.rewardAmount + } + + blocks := currentHeight - uint64(self.startHeight) + rewardSpent := blocks * self.rewardPerBlock + return rewardSpent +} + +func (self *ExternalIncentive) RewardLeft(currentHeight uint64) uint64 { + if currentHeight <= uint64(self.startHeight) { + return self.rewardAmount + } + + if currentHeight > uint64(self.endHeight) { + return 0 + } + + if currentHeight == uint64(self.endHeight) { + return self.rewardPerBlock + } + + blocks := uint64(self.endHeight) - currentHeight + rewardLeft := blocks * self.rewardPerBlock + return rewardLeft +} + +func NewExternalIncentive( + incentiveId string, + targetPoolPath string, + rewardToken string, + rewardAmount uint64, + startTimestamp int64, // timestamp is in unix time(seconds) + endTimestamp int64, + refundee std.Address, + createdHeight int64, + depositGnsAmount uint64, + currentTime int64, // current time in unix time(seconds) + msPerBlock int64, // msPerBlock is in milliseconds +) *ExternalIncentive { + incentiveDuration := endTimestamp - startTimestamp + incentiveBlock := incentiveDuration * 1000 / msPerBlock + rewardPerBlock := rewardAmount / uint64(incentiveBlock) + + blocksLeftUntilStartHeight := (startTimestamp - currentTime) * 1000 / msPerBlock + blocksLeftUntilEndHeight := (endTimestamp - currentTime) * 1000 / msPerBlock + + startHeight := std.GetHeight() + blocksLeftUntilStartHeight + endHeight := std.GetHeight() + blocksLeftUntilEndHeight + + return &ExternalIncentive{ + incentiveId: incentiveId, + targetPoolPath: targetPoolPath, + rewardToken: rewardToken, + rewardAmount: rewardAmount, + startTimestamp: startTimestamp, + endTimestamp: endTimestamp, + startHeight: startHeight, + endHeight: endHeight, + rewardPerBlock: rewardPerBlock, + refundee: refundee, + createdHeight: createdHeight, + depositGnsAmount: depositGnsAmount, + unclaimableRefunded: false, + } +} + +type Deposit struct { + owner std.Address // owner address + stakeTimestamp int64 // staked time + stakeHeight int64 // staked block height + targetPoolPath string // staked position's pool path + tickLower int32 // tick lower + tickUpper int32 // tick upper + liquidity *u256.Uint // liquidity + lastCollectHeight int64 // last collect block height + warmups []Warmup // warmup information +} + diff --git a/contract/r/gnoswap/staker/utils.gno b/contract/r/gnoswap/staker/utils.gno new file mode 100644 index 000000000..c03b4ed55 --- /dev/null +++ b/contract/r/gnoswap/staker/utils.gno @@ -0,0 +1,205 @@ +package staker + +import ( + "std" + "strconv" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/ufmt" + + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +// GetOrigPkgAddr returns the original package address. +// In staker contract, original package address is the staker address. +func GetOrigPkgAddr() std.Address { + return consts.STAKER_ADDR +} + +// poolPathAlign ensures that a pool path is formatted with tokens in lexicographical order. +// +// This function takes a pool path string and splits it into three components: +// - Token0 address +// - Token1 address +// - Fee tier +// It ensures that the tokens are ordered lexicographically (Token0 < Token1) and reconstructs +// the pool path in the correct order. +// +// Parameters: +// - poolPath (string): The input pool path string in the format "Token0:Token1:Fee". +// +// Returns: +// - string: A lexicographically ordered pool path string. +// +// Panics: +// - If the input `poolPath` is invalid or cannot be split into exactly three parts. +func poolPathAlign(poolPath string) string { + res, err := common.Split(poolPath, ":", 3) + if err != nil { + panic(addDetailToError( + errInvalidPoolPath, + ufmt.Sprintf("invalid poolPath(%s)", poolPath), + )) + } + + pToken0, pToken1, fee := res[0], res[1], res[2] + + if pToken0 < pToken1 { + return ufmt.Sprintf("%s:%s:%s", pToken0, pToken1, fee) + } + + return ufmt.Sprintf("%s:%s:%s", pToken1, pToken0, fee) +} + +// poolPathDivide splits a pool path string into its components. +// +// The function takes a pool path string in the format "Token0:Token1:Fee", +// splits it using the `:` separator, and returns the three components: +// - `pToken0`: The first token address. +// - `pToken1`: The second token address. +// - `fee`: The fee tier. +// +// Parameters: +// - poolPath (string): The input pool path string. +// +// Returns: +// - string: `pToken0` - The first token address. +// - string: `pToken1` - The second token address. +// - string: `fee` - The fee tier. +// +// Panics: +// - If the `poolPath` cannot be split into exactly three components. +func poolPathDivide(poolPath string) (string, string, string) { + res, err := common.Split(poolPath, ":", 3) + if err != nil { + panic(errInvalidPoolPath) + } + + pToken0, pToken1, fee := res[0], res[1], res[2] + return pToken0, pToken1, fee +} + +// tid converts uint64 to grc721.TokenID. +// +// Input: +// - id: the uint64 to convert +// +// Output: +// - grc721.TokenID: the converted token ID +func tid(tokenId interface{}) grc721.TokenID { + if tokenId == nil { + panic(addDetailToError( + errDataNotFound, + "tokenId is nil", + )) + } + + switch tokenId.(type) { + case string: + return grc721.TokenID(tokenId.(string)) + case int: + return grc721.TokenID(strconv.Itoa(tokenId.(int))) + case uint64: + return grc721.TokenID(strconv.Itoa(int(tokenId.(uint64)))) + case grc721.TokenID: + return tokenId.(grc721.TokenID) + default: + panic(addDetailToError( + errInvalidInput, + ufmt.Sprintf("unsupported tokenId type(%T)", tokenId), + )) + } +} + +// max returns the larger of x or y. +func max(x, y int64) int64 { + if x > y { + return x + } + return y +} + +// min returns the smaller of x or y. +func min(x, y uint64) uint64 { + if x < y { + return x + } + return y +} + +// contains checks if a string is present in a slice of strings. +func contains(slice []string, item string) bool { + for _, element := range slice { + if element == item { + return true + } + } + return false +} + +// derivePkgAddr derives the Realm address from it's pkgpath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) +} + +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrevPkgPath returns the package path of the previous realm. +func getPrevPkgPath() string { + return std.PrevRealm().PkgPath() +} + +// getPrev returns the address and package path of the previous realm. +func getPrev() (string, string) { + prev := getPrevRealm() + return prev.Addr().String(), prev.PkgPath() +} + +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + +func formatUint(v interface{}) string { + switch v := v.(type) { + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatInt(v interface{}) string { + switch v := v.(type) { + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case int: + return strconv.Itoa(v) + default: + panic(ufmt.Sprintf("invalid type: %T", v)) + } +} + +func formatBool(v bool) string { + return strconv.FormatBool(v) +} diff --git a/contract/r/gnoswap/staker/utils_test.gno b/contract/r/gnoswap/staker/utils_test.gno new file mode 100644 index 000000000..b1c82a9cf --- /dev/null +++ b/contract/r/gnoswap/staker/utils_test.gno @@ -0,0 +1,274 @@ +package staker + +import ( + "std" + "testing" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/uassert" + + "gno.land/p/gnoswap/consts" +) + +func TestGetOrigPkgAddr(t *testing.T) { + std.TestSetOrigCaller(consts.STAKER_ADDR) + origPkgAddr := GetOrigPkgAddr() + if origPkgAddr != std.GetOrigCaller() { + t.Errorf("Expected %v, got %v", std.GetOrigCaller(), origPkgAddr) + } +} + +func TestPoolPathAlign(t *testing.T) { + tests := []struct { + input string + expected string + shouldPanic bool + }{ + // Valid cases + {"baz:bar:500", "bar:baz:500", false}, + {"bar:baz:500", "bar:baz:500", false}, + {"foo:bar:300", "bar:foo:300", false}, + {"bar:foo:300", "bar:foo:300", false}, + + // Invalid cases + {"invalid:path", "", true}, // Missing fee + {"bar:baz", "", true}, // Too few components + {"", "", true}, // Empty string + } + + for _, tt := range tests { + if tt.shouldPanic { + // Test for panic + defer func() { + if r := recover(); r == nil { + t.Errorf("poolPathAlign(%s) did not panic as expected", tt.input) + } + }() + _ = poolPathAlign(tt.input) + } else { + // Test normal cases + result := poolPathAlign(tt.input) + if result != tt.expected { + t.Errorf("poolPathAlign(%s) = %s; want %s", tt.input, result, tt.expected) + } + } + } +} + +func TestPoolPathDivide(t *testing.T) { + tests := []struct { + name string + input string + expected0 string + expected1 string + expectedFee string + shouldPanic bool + }{ + // Valid cases + {"Valid pool path", "bar:baz:500", "bar", "baz", "500", false}, + {"Another valid pool path", "foo:bar:300", "foo", "bar", "300", false}, + + // Invalid cases + {"Missing fee", "bar:baz", "", "", "", true}, + {"Too few components", "bar", "", "", "", true}, + {"Empty string", "", "", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + // Test for panic cases + defer func() { + if r := recover(); r == nil { + t.Errorf("poolPathDivide(%s) did not panic as expected", tt.input) + } + }() + _, _, _ = poolPathDivide(tt.input) + } else { + // Test for normal cases + pToken0, pToken1, fee := poolPathDivide(tt.input) + if pToken0 != tt.expected0 || pToken1 != tt.expected1 || fee != tt.expectedFee { + t.Errorf( + "poolPathDivide(%s) = (%s, %s, %s); want (%s, %s, %s)", + tt.input, pToken0, pToken1, fee, tt.expected0, tt.expected1, tt.expectedFee, + ) + } + } + }) + } +} + +func TestTid(t *testing.T) { + tests := []struct { + name string + input interface{} + expected string + shouldPanic bool + }{ + { + name: "Panic - nil", + input: nil, + expected: "[GNOSWAP-STAKER-022] requested data not found || tokenId is nil", + shouldPanic: true, + }, + { + name: "Panic - unsupported type", + input: float64(1), + expected: "[GNOSWAP-STAKER-007] invalid input data || unsupported tokenId type(unknown)", + shouldPanic: true, + }, + { + name: "Success - string", + input: "1", + expected: "1", + shouldPanic: false, + }, + { + name: "Success - int", + input: int(1), + expected: "1", + shouldPanic: false, + }, + { + name: "Success - uint64", + input: uint64(1), + expected: "1", + shouldPanic: false, + }, + { + name: "Success - grc721.TokenID", + input: grc721.TokenID("1"), + expected: "1", + shouldPanic: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + got := tid(tc.input) + uassert.Equal(t, tc.expected, string(got)) + } else { + tid(tc.input) + } + }) + } +} + +func TestGetPrevRealm(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prevRealm is User", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevRealm() + uassert.Equal(t, got.Addr().String(), tc.expected[0]) + uassert.Equal(t, got.PkgPath(), tc.expected[1]) + }) + } +} + +func TestGetPrevAddr(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected std.Address + }{ + { + name: "Success - prev Address is User", + originCaller: consts.ADMIN, + expected: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevAddr() + uassert.Equal(t, got.String(), tc.expected.String()) + }) + } +} + +func TestGetPrevAsString(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prev Realm of user info as string", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got1, got2 := getPrev() + uassert.Equal(t, got1, tc.expected[0]) + uassert.Equal(t, got2, tc.expected[1]) + }) + } +} + +func TestIsUserCall(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + originPkgPath string + expected bool + }{ + { + name: "Success - User Call", + originCaller: consts.ADMIN, + expected: true, + }, + { + name: "Failure - Not User Call", + originCaller: consts.ROUTER_ADDR, + originPkgPath: consts.ROUTER_PATH, + expected: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(tc.originCaller) + if !tc.expected { + std.TestSetRealm(std.NewCodeRealm(tc.originPkgPath)) + } + got := isUserCall() + uassert.Equal(t, got, tc.expected) + }) + } +} diff --git a/contract/r/gnoswap/staker/warp_unwrap_test.gno b/contract/r/gnoswap/staker/warp_unwrap_test.gno new file mode 100644 index 000000000..0a78c0d0f --- /dev/null +++ b/contract/r/gnoswap/staker/warp_unwrap_test.gno @@ -0,0 +1,160 @@ +package staker + +import ( + "std" + "strconv" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/gnoswap/consts" +) + +func TestWrap(t *testing.T) { + user1Addr := testutils.TestAddress("user1") + tests := []struct { + name string + action func() + verify func() uint64 + expected string + shouldPanic bool + }{ + { + name: "Failure - Amount less than minimum", + action: func() { + wrap(999) + }, + verify: nil, + expected: "[GNOSWAP-STAKER-006] can not wrapless than minimum amount || amount(999) < minimum(1000)", + shouldPanic: true, + }, + { + name: "Failure - Zero amount", + action: func() { + wrap(0) + }, + verify: nil, + expected: "[GNOSWAP-STAKER-005] wrap, unwrap failed || cannot wrap 0 ugnot", + shouldPanic: true, + }, + { + name: "Success - Valid amount", + action: func() { + std.TestSetRealm(std.NewUserRealm(user1Addr)) + ugnotFaucet(t, user1Addr, 1_000) + ugnotFaucet(t, consts.STAKER_ADDR, 1_000) + std.TestSetRealm(std.NewUserRealm(user1Addr)) + wrap(1_000) + }, + verify: func() uint64 { + return TokenBalance(t, wugnotPath, user1Addr) + }, + expected: "1000", + shouldPanic: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + switch v := r.(type) { + case string: + if v != tc.expected { + t.Errorf("Expected panic: %s, got: %s", tc.expected, v) + } + case error: + if v.Error() != tc.expected { + t.Errorf("Expected panic: %s, got: %s", tc.expected, v.Error()) + } + default: + t.Errorf("Unexpected panic type: %v", r) + } + } + }() + + if tc.shouldPanic { + tc.action() + } else { + tc.action() + if tc.verify != nil { + balance := tc.verify() + if strconv.FormatUint(balance, 10) != tc.expected { + t.Errorf("Expected balance: %s, got: %d", tc.expected, balance) + } + } + } + }) + } +} + +func TestUnwrap(t *testing.T) { + user2Addr := testutils.TestAddress("user2") + tests := []struct { + name string + action func() + verify func() uint64 + expected string + shouldPanic bool + }{ + { + name: "Failure - Zero amount", + action: func() { + unwrap(0) + }, + verify: nil, + expected: "", + shouldPanic: false, // No panic as zero amount is ignored + }, + { + name: "Success - Valid amount", + action: func() { + std.TestSetRealm(std.NewUserRealm(user2Addr)) + ugnotFaucet(t, user2Addr, 1_000) + ugnotFaucet(t, consts.STAKER_ADDR, 1_000) + std.TestSetRealm(std.NewUserRealm(user2Addr)) + wrap(1_000) + std.TestSetRealm(std.NewUserRealm(user2Addr)) + wugnotApprove(t, user2Addr, consts.STAKER_ADDR, 1_000) + unwrap(1_000) + }, + verify: func() uint64 { + return TokenBalance(t, wugnotPath, user2Addr) + }, + expected: "0", + shouldPanic: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + switch v := r.(type) { + case string: + if v != tc.expected { + t.Errorf("Expected panic: %s, got: %s", tc.expected, v) + } + case error: + if v.Error() != tc.expected { + t.Errorf("Expected panic: %s, got: %s", tc.expected, v.Error()) + } + default: + t.Errorf("Unexpected panic type: %v", r) + } + } + }() + + if tc.shouldPanic { + tc.action() + } else { + tc.action() + if tc.verify != nil { + balance := tc.verify() + if strconv.FormatUint(balance, 10) != tc.expected { + t.Errorf("Expected balance: %s, got: %d", tc.expected, balance) + } + } + } + }) + } +} diff --git a/contract/r/gnoswap/staker/wrap_unwrap.gno b/contract/r/gnoswap/staker/wrap_unwrap.gno new file mode 100644 index 000000000..c90a55f59 --- /dev/null +++ b/contract/r/gnoswap/staker/wrap_unwrap.gno @@ -0,0 +1,66 @@ +package staker + +import ( + "std" + + "gno.land/r/demo/wugnot" + + "gno.land/p/demo/ufmt" + "gno.land/p/gnoswap/consts" + "gno.land/r/gnoswap/v1/common" +) + +// wrap converts `ugnot` tokens into `wugnot` tokens. +// +// Parameters: +// - ugnotAmount (uint64): The amount of `ugnot` to wrap. +// +// Panics: +// - If `ugnotAmount` is less than or equal to 0. +// - If `ugnotAmount` is less than the minimum deposit required (`consts.UGNOT_MIN_DEPOSIT_TO_WRAP`). +func wrap(ugnotAmount uint64) { + if ugnotAmount <= 0 { + panic(addDetailToError( + errWrapUnwrap, + "cannot wrap 0 ugnot", + )) + } + + if ugnotAmount < consts.UGNOT_MIN_DEPOSIT_TO_WRAP { + panic(addDetailToError( + errWugnotMinimum, + ufmt.Sprintf("amount(%d) < minimum(%d)", ugnotAmount, consts.UGNOT_MIN_DEPOSIT_TO_WRAP), + )) + } + + // WRAP IT + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(consts.STAKER_ADDR, consts.WUGNOT_ADDR, std.Coins{{Denom: "ugnot", Amount: int64(ugnotAmount)}}) + wugnot.Deposit() // STAKER HAS WUGNOT + + // SEND WUGNOT: STAKER -> USER + wugnot.Transfer(common.AddrToUser(std.PrevRealm().Addr()), ugnotAmount) +} + +// unwrap converts `wugnot` tokens back into `ugnot` tokens. +// +// Parameters: +// - wugnotAmount (uint64): The amount of `wugnot` to unwrap. +// +// Note: +// - If `wugnotAmount` is 0, the function simply returns without executing further steps. +func unwrap(wugnotAmount uint64) { + if wugnotAmount == 0 { + return + } + + // SEND WUGNOT: USER -> STAKER + wugnot.TransferFrom(common.AddrToUser(std.PrevRealm().Addr()), common.AddrToUser(consts.STAKER_ADDR), wugnotAmount) + + // UNWRAP IT + wugnot.Withdraw(wugnotAmount) + + // SEND GNOT: STAKER -> USER + banker := std.GetBanker(std.BankerTypeRealmSend) + banker.SendCoins(consts.STAKER_ADDR, std.PrevRealm().Addr(), std.Coins{{Denom: "ugnot", Amount: int64(wugnotAmount)}}) +} diff --git a/contract/r/gnoswap/test_token/bar/bar.gno b/contract/r/gnoswap/test_token/bar/bar.gno new file mode 100644 index 000000000..cca7288c0 --- /dev/null +++ b/contract/r/gnoswap/test_token/bar/bar.gno @@ -0,0 +1,75 @@ +package bar + +import ( + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/grc20reg" +) + +var ( + Token, privateLedger = grc20.NewToken("Bar", "BAR", 6) + UserTeller = Token.CallerTeller() + owner = ownable.NewWithAddress("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d") // ADMIN +) + +func init() { + privateLedger.Mint(owner.Owner(), 100_000_000_000_000) + getter := func() *grc20.Token { return Token } + grc20reg.Register(getter, "") +} + +func TotalSupply() uint64 { + return UserTeller.TotalSupply() +} + +func BalanceOf(owner std.Address) uint64 { + return UserTeller.BalanceOf(owner) +} + +func Allowance(owner, spender std.Address) uint64 { + return UserTeller.Allowance(owner, spender) +} + +func Transfer(to std.Address, amount uint64) { + checkErr(UserTeller.Transfer(to, amount)) +} + +func Approve(spender std.Address, amount uint64) { + checkErr(UserTeller.Approve(spender, amount)) +} + +func TransferFrom(from, to std.Address, amount uint64) { + checkErr(UserTeller.TransferFrom(from, to, amount)) +} + +func Burn(from std.Address, amount uint64) { + owner.AssertCallerIsOwner() + checkErr(privateLedger.Burn(from, amount)) +} + +func Render(path string) string { + parts := strings.Split(path, "/") + c := len(parts) + + switch { + case path == "": + return Token.RenderHome() + case c == 2 && parts[0] == "balance": + owner := std.Address(parts[1]) + balance := UserTeller.BalanceOf(owner) + return ufmt.Sprintf("%d\n", balance) + default: + return "404\n" + } +} + +func checkErr(err error) { + if err != nil { + panic(err) + } +} diff --git a/contract/r/gnoswap/test_token/bar/gno.mod b/contract/r/gnoswap/test_token/bar/gno.mod new file mode 100644 index 000000000..3cb44e9bd --- /dev/null +++ b/contract/r/gnoswap/test_token/bar/gno.mod @@ -0,0 +1 @@ +module gno.land/r/onbloc/bar \ No newline at end of file diff --git a/contract/r/gnoswap/test_token/baz/baz.gno b/contract/r/gnoswap/test_token/baz/baz.gno new file mode 100644 index 000000000..ded2890ad --- /dev/null +++ b/contract/r/gnoswap/test_token/baz/baz.gno @@ -0,0 +1,75 @@ +package baz + +import ( + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/grc20reg" +) + +var ( + Token, privateLedger = grc20.NewToken("Baz", "BAZ", 6) + UserTeller = Token.CallerTeller() + owner = ownable.NewWithAddress("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d") // ADMIN +) + +func init() { + privateLedger.Mint(owner.Owner(), 100_000_000_000_000) + getter := func() *grc20.Token { return Token } + grc20reg.Register(getter, "") +} + +func TotalSupply() uint64 { + return UserTeller.TotalSupply() +} + +func BalanceOf(owner std.Address) uint64 { + return UserTeller.BalanceOf(owner) +} + +func Allowance(owner, spender std.Address) uint64 { + return UserTeller.Allowance(owner, spender) +} + +func Transfer(to std.Address, amount uint64) { + checkErr(UserTeller.Transfer(to, amount)) +} + +func Approve(spender std.Address, amount uint64) { + checkErr(UserTeller.Approve(spender, amount)) +} + +func TransferFrom(from, to std.Address, amount uint64) { + checkErr(UserTeller.TransferFrom(from, to, amount)) +} + +func Burn(from std.Address, amount uint64) { + owner.AssertCallerIsOwner() + checkErr(privateLedger.Burn(from, amount)) +} + +func Render(path string) string { + parts := strings.Split(path, "/") + c := len(parts) + + switch { + case path == "": + return Token.RenderHome() + case c == 2 && parts[0] == "balance": + owner := std.Address(parts[1]) + balance := UserTeller.BalanceOf(owner) + return ufmt.Sprintf("%d\n", balance) + default: + return "404\n" + } +} + +func checkErr(err error) { + if err != nil { + panic(err) + } +} diff --git a/contract/r/gnoswap/test_token/baz/gno.mod b/contract/r/gnoswap/test_token/baz/gno.mod new file mode 100644 index 000000000..52e2bfcff --- /dev/null +++ b/contract/r/gnoswap/test_token/baz/gno.mod @@ -0,0 +1 @@ +module gno.land/r/onbloc/baz \ No newline at end of file diff --git a/contract/r/gnoswap/test_token/foo/foo.gno b/contract/r/gnoswap/test_token/foo/foo.gno new file mode 100644 index 000000000..59b6e3a9f --- /dev/null +++ b/contract/r/gnoswap/test_token/foo/foo.gno @@ -0,0 +1,75 @@ +package foo + +import ( + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/grc20reg" +) + +var ( + Token, privateLedger = grc20.NewToken("Foo", "FOO", 6) + UserTeller = Token.CallerTeller() + owner = ownable.NewWithAddress("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d") // ADMIN +) + +func init() { + privateLedger.Mint(owner.Owner(), 100_000_000_000_000) + getter := func() *grc20.Token { return Token } + grc20reg.Register(getter, "") +} + +func TotalSupply() uint64 { + return UserTeller.TotalSupply() +} + +func BalanceOf(owner std.Address) uint64 { + return UserTeller.BalanceOf(owner) +} + +func Allowance(owner, spender std.Address) uint64 { + return UserTeller.Allowance(owner, spender) +} + +func Transfer(to std.Address, amount uint64) { + checkErr(UserTeller.Transfer(to, amount)) +} + +func Approve(spender std.Address, amount uint64) { + checkErr(UserTeller.Approve(spender, amount)) +} + +func TransferFrom(from, to std.Address, amount uint64) { + checkErr(UserTeller.TransferFrom(from, to, amount)) +} + +func Burn(from std.Address, amount uint64) { + owner.AssertCallerIsOwner() + checkErr(privateLedger.Burn(from, amount)) +} + +func Render(path string) string { + parts := strings.Split(path, "/") + c := len(parts) + + switch { + case path == "": + return Token.RenderHome() + case c == 2 && parts[0] == "balance": + owner := std.Address(parts[1]) + balance := UserTeller.BalanceOf(owner) + return ufmt.Sprintf("%d\n", balance) + default: + return "404\n" + } +} + +func checkErr(err error) { + if err != nil { + panic(err) + } +} diff --git a/contract/r/gnoswap/test_token/foo/gno.mod b/contract/r/gnoswap/test_token/foo/gno.mod new file mode 100644 index 000000000..d0c856160 --- /dev/null +++ b/contract/r/gnoswap/test_token/foo/gno.mod @@ -0,0 +1 @@ +module gno.land/r/onbloc/foo \ No newline at end of file diff --git a/contract/r/gnoswap/test_token/obl/gno.mod b/contract/r/gnoswap/test_token/obl/gno.mod new file mode 100644 index 000000000..1799581b2 --- /dev/null +++ b/contract/r/gnoswap/test_token/obl/gno.mod @@ -0,0 +1 @@ +module gno.land/r/onbloc/obl \ No newline at end of file diff --git a/contract/r/gnoswap/test_token/obl/obl.gno b/contract/r/gnoswap/test_token/obl/obl.gno new file mode 100644 index 000000000..a37ca265e --- /dev/null +++ b/contract/r/gnoswap/test_token/obl/obl.gno @@ -0,0 +1,75 @@ +package obl + +import ( + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/grc20reg" +) + +var ( + Token, privateLedger = grc20.NewToken("Obl", "OBL", 6) + UserTeller = Token.CallerTeller() + owner = ownable.NewWithAddress("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d") // ADMIN +) + +func init() { + privateLedger.Mint(owner.Owner(), 100_000_000_000_000) + getter := func() *grc20.Token { return Token } + grc20reg.Register(getter, "") +} + +func TotalSupply() uint64 { + return UserTeller.TotalSupply() +} + +func BalanceOf(owner std.Address) uint64 { + return UserTeller.BalanceOf(owner) +} + +func Allowance(owner, spender std.Address) uint64 { + return UserTeller.Allowance(owner, spender) +} + +func Transfer(to std.Address, amount uint64) { + checkErr(UserTeller.Transfer(to, amount)) +} + +func Approve(spender std.Address, amount uint64) { + checkErr(UserTeller.Approve(spender, amount)) +} + +func TransferFrom(from, to std.Address, amount uint64) { + checkErr(UserTeller.TransferFrom(from, to, amount)) +} + +func Burn(from std.Address, amount uint64) { + owner.AssertCallerIsOwner() + checkErr(privateLedger.Burn(from, amount)) +} + +func Render(path string) string { + parts := strings.Split(path, "/") + c := len(parts) + + switch { + case path == "": + return Token.RenderHome() + case c == 2 && parts[0] == "balance": + owner := std.Address(parts[1]) + balance := UserTeller.BalanceOf(owner) + return ufmt.Sprintf("%d\n", balance) + default: + return "404\n" + } +} + +func checkErr(err error) { + if err != nil { + panic(err) + } +} diff --git a/contract/r/gnoswap/test_token/qux/gno.mod b/contract/r/gnoswap/test_token/qux/gno.mod new file mode 100644 index 000000000..3671f03d4 --- /dev/null +++ b/contract/r/gnoswap/test_token/qux/gno.mod @@ -0,0 +1 @@ +module gno.land/r/onbloc/qux \ No newline at end of file diff --git a/contract/r/gnoswap/test_token/qux/qux.gno b/contract/r/gnoswap/test_token/qux/qux.gno new file mode 100644 index 000000000..9fd7b44c9 --- /dev/null +++ b/contract/r/gnoswap/test_token/qux/qux.gno @@ -0,0 +1,75 @@ +package qux + +import ( + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/grc20reg" +) + +var ( + Token, privateLedger = grc20.NewToken("Qux", "QUX", 6) + UserTeller = Token.CallerTeller() + owner = ownable.NewWithAddress("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d") // ADMIN +) + +func init() { + privateLedger.Mint(owner.Owner(), 100_000_000_000_000) + getter := func() *grc20.Token { return Token } + grc20reg.Register(getter, "") +} + +func TotalSupply() uint64 { + return UserTeller.TotalSupply() +} + +func BalanceOf(owner std.Address) uint64 { + return UserTeller.BalanceOf(owner) +} + +func Allowance(owner, spender std.Address) uint64 { + return UserTeller.Allowance(owner, spender) +} + +func Transfer(to std.Address, amount uint64) { + checkErr(UserTeller.Transfer(to, amount)) +} + +func Approve(spender std.Address, amount uint64) { + checkErr(UserTeller.Approve(spender, amount)) +} + +func TransferFrom(from, to std.Address, amount uint64) { + checkErr(UserTeller.TransferFrom(from, to, amount)) +} + +func Burn(from std.Address, amount uint64) { + owner.AssertCallerIsOwner() + checkErr(privateLedger.Burn(from, amount)) +} + +func Render(path string) string { + parts := strings.Split(path, "/") + c := len(parts) + + switch { + case path == "": + return Token.RenderHome() + case c == 2 && parts[0] == "balance": + owner := std.Address(parts[1]) + balance := UserTeller.BalanceOf(owner) + return ufmt.Sprintf("%d\n", balance) + default: + return "404\n" + } +} + +func checkErr(err error) { + if err != nil { + panic(err) + } +} diff --git a/contract/r/gnoswap/test_token/usdc/gno.mod b/contract/r/gnoswap/test_token/usdc/gno.mod new file mode 100644 index 000000000..3d7ccb60c --- /dev/null +++ b/contract/r/gnoswap/test_token/usdc/gno.mod @@ -0,0 +1 @@ +module gno.land/r/onbloc/usdc \ No newline at end of file diff --git a/contract/r/gnoswap/test_token/usdc/usdc.gno b/contract/r/gnoswap/test_token/usdc/usdc.gno new file mode 100644 index 000000000..d83bc93a2 --- /dev/null +++ b/contract/r/gnoswap/test_token/usdc/usdc.gno @@ -0,0 +1,75 @@ +package usdc + +import ( + "std" + "strings" + + "gno.land/p/demo/grc/grc20" + "gno.land/p/demo/ownable" + "gno.land/p/demo/ufmt" + + "gno.land/r/demo/grc20reg" +) + +var ( + Token, privateLedger = grc20.NewToken("Usd Coin", "USDC", 6) + UserTeller = Token.CallerTeller() + owner = ownable.NewWithAddress("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d") // ADMIN +) + +func init() { + privateLedger.Mint(owner.Owner(), 100_000_000_000_000) + getter := func() *grc20.Token { return Token } + grc20reg.Register(getter, "") +} + +func TotalSupply() uint64 { + return UserTeller.TotalSupply() +} + +func BalanceOf(owner std.Address) uint64 { + return UserTeller.BalanceOf(owner) +} + +func Allowance(owner, spender std.Address) uint64 { + return UserTeller.Allowance(owner, spender) +} + +func Transfer(to std.Address, amount uint64) { + checkErr(UserTeller.Transfer(to, amount)) +} + +func Approve(spender std.Address, amount uint64) { + checkErr(UserTeller.Approve(spender, amount)) +} + +func TransferFrom(from, to std.Address, amount uint64) { + checkErr(UserTeller.TransferFrom(from, to, amount)) +} + +func Burn(from std.Address, amount uint64) { + owner.AssertCallerIsOwner() + checkErr(privateLedger.Burn(from, amount)) +} + +func Render(path string) string { + parts := strings.Split(path, "/") + c := len(parts) + + switch { + case path == "": + return Token.RenderHome() + case c == 2 && parts[0] == "balance": + owner := std.Address(parts[1]) + balance := UserTeller.BalanceOf(owner) + return ufmt.Sprintf("%d\n", balance) + default: + return "404\n" + } +} + +func checkErr(err error) { + if err != nil { + panic(err) + } +} diff --git a/tests/grc20reg/approve_transferfrom.txtar b/tests/grc20reg/approve_transferfrom.txtar new file mode 100644 index 000000000..ff43ac448 --- /dev/null +++ b/tests/grc20reg/approve_transferfrom.txtar @@ -0,0 +1,53 @@ +loadpkg gno.land/p/demo/users + +loadpkg gno.land/r/demo/foo20 +loadpkg gno.land/r/demo/grc20reg + +loadpkg gno.land/r/demo/reg $WORK/reg + +## start a new node +gnoland start + +## faucet +# gnokey maketx call -pkgpath gno.land/r/demo/foo20 -func Faucet -gas-fee 1ugnot -gas-wanted 4000000 -broadcast -chainid=tendermint_test test1 + +## print reg addr +gnokey maketx call -pkgpath gno.land/r/demo/reg -func RelamAddr -gas-fee 1ugnot -gas-wanted 4000000 -broadcast -chainid=tendermint_test test1 +stdout 'g19tlskvga928es8ug2empargp0teul03apzjud9' + +## approve +gnokey maketx call -pkgpath gno.land/r/demo/foo20 -func Approve -args 'g19tlskvga928es8ug2empargp0teul03apzjud9' -args '100' -gas-fee 1ugnot -gas-wanted 4000000 -broadcast -chainid=tendermint_test test1 + +## transfer from +gnokey maketx call -pkgpath gno.land/r/demo/reg -func TransferFromWithReg -gas-fee 1ugnot -gas-wanted 4000000 -broadcast -chainid=tendermint_test test1 +stdout '' + +-- reg/reg.gno -- +package reg + +import ( + "std" + + "gno.land/r/demo/grc20reg" +) + +func RelamAddr() string { + addr := std.CurrentRealm().Addr().String() + return addr +} + +func TransferFromWithReg() { + caller := std.PrevRealm().Addr() + curr := std.CurrentRealm().Addr() + + + // using import + // foo20.TransferFrom(uCaller, uCurr, uint64(100)) + + // using grc20reg + fooTokenGetter := grc20reg.Get("gno.land/r/demo/foo20") + fooToken := fooTokenGetter() + userTeller := fooToken.CallerTeller() + + userTeller.TransferFrom(caller, curr, uint64(100)) +} \ No newline at end of file