diff --git a/core/vm/contracts.go b/core/vm/contracts.go index d0e3e6913917..9a52616657bb 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/crypto/bls12381" "github.com/ethereum/go-ethereum/crypto/bn256" "github.com/ethereum/go-ethereum/params" + big2 "github.com/holiman/big" "golang.org/x/crypto/ripemd160" ) @@ -377,23 +378,19 @@ func (c *bigModExp) Run(input []byte) ([]byte, error) { } // Retrieve the operands and execute the exponentiation var ( - base = new(big.Int).SetBytes(getData(input, 0, baseLen)) - exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) - mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) + base = new(big2.Int).SetBytes(getData(input, 0, baseLen)) + exp = new(big2.Int).SetBytes(getData(input, baseLen, expLen)) + mod = new(big2.Int).SetBytes(getData(input, baseLen+expLen, modLen)) v []byte ) switch { case mod.BitLen() == 0: // Modulo 0 is undefined, return zero return common.LeftPadBytes([]byte{}, int(modLen)), nil - case base.Cmp(common.Big1) == 0: + case base.BitLen() == 1: // a bit length of 1 means it's 1 (or -1). //If base == 1, then we can just return base % mod (if mod >= 1, which it is) v = base.Mod(base, mod).Bytes() - //case mod.Bit(0) == 0: - // // Modulo is even - // v = math.FastExp(base, exp, mod).Bytes() default: - // Modulo is odd v = base.Exp(base, exp, mod).Bytes() } return common.LeftPadBytes(v, int(modLen)), nil