Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[x/programs] Improve overflow checks #600

Merged
merged 15 commits into from
Nov 17, 2023
20 changes: 10 additions & 10 deletions x/programs/examples/counter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ func TestCounterProgram(t *testing.T) {
// create counter for alice on program 1
result, err := rt.Call(ctx, "initialize_address", programIDPtr, alicePtr)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

// validate counter at 0
result, err = rt.Call(ctx, "get_value", programIDPtr, alicePtr)
require.NoError(err)
require.Equal(uint64(0), result[0])
require.Equal(int64(0), result[0])

// initialize second runtime to create second counter program with an empty
// meter.
Expand Down Expand Up @@ -102,16 +102,16 @@ func TestCounterProgram(t *testing.T) {
// initialize counter for alice on runtime 2
result, err = rt2.Call(ctx, "initialize_address", programID2Ptr, alicePtr2)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

// increment alice's counter on program 2 by 10
result, err = rt2.Call(ctx, "inc", programID2Ptr, alicePtr2, 10)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

result, err = rt2.Call(ctx, "get_value", programID2Ptr, alicePtr2)
require.NoError(err)
require.Equal(uint64(10), result[0])
require.Equal(int64(10), result[0])

// stop the runtime to prevent further execution
rt2.Stop()
Expand All @@ -127,13 +127,13 @@ func TestCounterProgram(t *testing.T) {
// increment alice's counter on program 1
result, err = rt.Call(ctx, "inc", programIDPtr, alicePtr, 1)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

result, err = rt.Call(ctx, "get_value", programIDPtr, alicePtr)
require.NoError(err)

log.Debug("count program 1",
zap.Uint64("alice", result[0]),
zap.Int64("alice", result[0]),
)

// write program id 2 to stack of program 1
Expand All @@ -142,17 +142,17 @@ func TestCounterProgram(t *testing.T) {

caller := programIDPtr
target := programID2Ptr
maxUnitsProgramToProgram := uint64(10000)
maxUnitsProgramToProgram := int64(10000)

// increment alice's counter on program 2
result, err = rt.Call(ctx, "inc_external", caller, target, maxUnitsProgramToProgram, alicePtr, 5)
require.NoError(err)
require.Equal(uint64(1), result[0])
require.Equal(int64(1), result[0])

// expect alice's counter on program 2 to be 15
result, err = rt.Call(ctx, "get_value_external", caller, target, maxUnitsProgramToProgram, alicePtr)
require.NoError(err)
require.Equal(uint64(15), result[0])
require.Equal(int64(15), result[0])

require.Greater(rt.Meter().GetBalance(), uint64(0))
}
6 changes: 3 additions & 3 deletions x/programs/examples/imports/program/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,16 @@ func (i *Import) callProgramFn(
return int64(res[0])
}

func getCallArgs(ctx context.Context, memory runtime.Memory, buffer []byte, invokeProgramID uint64) ([]uint64, error) {
func getCallArgs(ctx context.Context, memory runtime.Memory, buffer []byte, invokeProgramID int64) ([]int64, error) {
// first arg contains id of program to call
args := []uint64{invokeProgramID}
args := []int64{invokeProgramID}
p := codec.NewReader(buffer, len(buffer))
i := 0
for !p.Empty() {
size := p.UnpackInt64(true)
isInt := p.UnpackBool()
if isInt {
valueInt := p.UnpackUint64(true)
valueInt := p.UnpackInt64(true)
args = append(args, valueInt)
} else {
valueBytes := make([]byte, size)
Expand Down
28 changes: 14 additions & 14 deletions x/programs/examples/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ func (t *Token) Run(ctx context.Context) error {
}

t.log.Debug("init response",
zap.Uint64("init", resp[0]),
zap.Int64("init", resp[0]),
)

result, err := rt.Call(ctx, "get_total_supply", programIDPtr)
if err != nil {
return err
}
t.log.Debug("total supply",
zap.Uint64("minted", result[0]),
zap.Int64("minted", result[0]),
)

// generate alice keys
Expand Down Expand Up @@ -116,13 +116,13 @@ func (t *Token) Run(ctx context.Context) error {
)

// mint 100 tokens to alice
mintAlice := uint64(1000)
mintAlice := int64(1000)
_, err = rt.Call(ctx, "mint_to", programIDPtr, alicePtr, mintAlice)
if err != nil {
return err
}
t.log.Debug("minted",
zap.Uint64("alice", mintAlice),
zap.Int64("alice", mintAlice),
)

// check balance of alice
Expand All @@ -131,7 +131,7 @@ func (t *Token) Run(ctx context.Context) error {
return err
}
t.log.Debug("balance",
zap.Uint64("alice", result[0]),
zap.Int64("alice", result[0]),
)

// check balance of bob
Expand All @@ -140,27 +140,27 @@ func (t *Token) Run(ctx context.Context) error {
return err
}
t.log.Debug("balance",
zap.Uint64("bob", result[0]),
zap.Int64("bob", result[0]),
)

// transfer 50 from alice to bob
transferToBob := uint64(50)
transferToBob := int64(50)
_, err = rt.Call(ctx, "transfer", programIDPtr, alicePtr, bobPtr, transferToBob)
if err != nil {
return err
}
t.log.Debug("transferred",
zap.Uint64("alice", transferToBob),
zap.Uint64("to bob", transferToBob),
zap.Int64("alice", transferToBob),
zap.Int64("to bob", transferToBob),
)

_, err = rt.Call(ctx, "transfer", programIDPtr, alicePtr, bobPtr, 1)
if err != nil {
return err
}
t.log.Debug("transferred",
zap.Uint64("alice", transferToBob),
zap.Uint64("to bob", transferToBob),
zap.Int64("alice", transferToBob),
zap.Int64("to bob", transferToBob),
)

// get balance alice
Expand All @@ -169,15 +169,15 @@ func (t *Token) Run(ctx context.Context) error {
return err
}
t.log.Debug("balance",
zap.Uint64("alice", result[0]),
zap.Int64("alice", result[0]),
)

// get balance bob
result, err = rt.Call(ctx, "get_balance", programIDPtr, bobPtr)
if err != nil {
return err
}
t.log.Debug("balance", zap.Uint64("bob", result[0]))
t.log.Debug("balance", zap.Int64("bob", result[0]))

t.log.Debug("remaining balance",
zap.Uint64("unit", rt.Meter().GetBalance()),
Expand Down Expand Up @@ -221,7 +221,7 @@ func (t *Token) RunShort(ctx context.Context) error {
}

t.log.Debug("init response",
zap.Uint64("init", resp[0]),
zap.Int64("init", resp[0]),
)
return nil
}
11 changes: 7 additions & 4 deletions x/programs/examples/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@ import (
"github.com/ava-labs/hypersdk/x/programs/runtime"
)

func newKeyPtr(ctx context.Context, key ed25519.PublicKey, runtime runtime.Runtime) (uint64, error) {
ptr, err := runtime.Memory().Alloc(ed25519.PublicKeyLen)
func newKeyPtr(ctx context.Context, key ed25519.PublicKey, rt runtime.Runtime) (int64, error) {
ptr, err := rt.Memory().Alloc(ed25519.PublicKeyLen)
if err != nil {
return 0, err
}

// write programID to memory which we will later pass to the program
err = runtime.Memory().Write(ptr, key[:])
err = rt.Memory().Write(ptr, key[:])
if err != nil {
return 0, err
}

return ptr, err
if ptr > runtime.MaxInt64 {
samliok marked this conversation as resolved.
Show resolved Hide resolved
return 0, runtime.ErrIntegerConversionOverflow
}
return int64(ptr), err
}

func newKey() (ed25519.PrivateKey, ed25519.PublicKey, error) {
Expand Down
8 changes: 7 additions & 1 deletion x/programs/runtime/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

package runtime

import "github.com/ava-labs/avalanchego/utils/units"
import (
"math"

"github.com/ava-labs/avalanchego/utils/units"
)

const (
AllocFnName = "alloc"
Expand All @@ -12,4 +16,6 @@ const (
guestSuffix = "_guest"
wasiPreview1ModName = "wasi_snapshot_preview1"
MemoryPageSize = 64 * units.KiB
MaxInt64 = math.MaxInt64
MinInt64 = math.MinInt64
samliok marked this conversation as resolved.
Show resolved Hide resolved
)
6 changes: 4 additions & 2 deletions x/programs/runtime/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ type Runtime interface {
// units. The engine will handle the compile strategy and instantiate the
// module with the given imports. Initialize should only be called once.
Initialize(context.Context, []byte, uint64) error
// Call invokes the an exported guest function with the given parameters.
Call(context.Context, string, ...uint64) ([]uint64, error)
// Call invokes an exported guest function with the given parameters.
// Returns the results of the call or an error if the call failed.
// If there are 0 results this value is set to nil.
samliok marked this conversation as resolved.
Show resolved Hide resolved
Call(context.Context, string, ...int64) ([]int64, error)
hexfusion marked this conversation as resolved.
Show resolved Hide resolved
// Memory returns the runtime memory.
Memory() Memory
// Meter returns the runtime meter.
Expand Down
1 change: 1 addition & 0 deletions x/programs/runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ var (
ErrInsufficientUnits = errors.New("insufficient units")
ErrRuntimeStoreSet = errors.New("runtime store has already been set")
ErrNegativeValue = errors.New("negative value")
ErrIntegerConversionOverflow = errors.New("integer overflow during conversion")

// Trap errors
ErrTrapStackOverflow = errors.New("the current stack space was exhausted")
Expand Down
19 changes: 13 additions & 6 deletions x/programs/runtime/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (m *memory) Len() (uint64, error) {

// WriteBytes is a helper function that allocates memory and writes the given
// bytes to the memory returning the offset.
func WriteBytes(m Memory, buf []byte) (uint64, error) {
func WriteBytes(m Memory, buf []byte) (int64, error) {
hexfusion marked this conversation as resolved.
Show resolved Hide resolved
offset, err := m.Alloc(uint64(len(buf)))
if err != nil {
return 0, err
Expand All @@ -121,7 +121,11 @@ func WriteBytes(m Memory, buf []byte) (uint64, error) {
return 0, err
}

return offset, nil
if offset > MaxInt64 {
samliok marked this conversation as resolved.
Show resolved Hide resolved
return 0, fmt.Errorf("write bytes failed: %w", ErrInvalidMemoryAddress)
}

return int64(offset), nil
}

// CallParam defines a value to be passed to a guest function.
Expand All @@ -131,8 +135,8 @@ type CallParam struct {

// WriteParams is a helper function that writes the given params to memory if non integer.
// Supported types include int, uint64 and string.
func WriteParams(m Memory, p []CallParam) ([]uint64, error) {
params := []uint64{}
func WriteParams(m Memory, p []CallParam) ([]int64, error) {
params := []int64{}
for _, param := range p {
switch v := param.Value.(type) {
case string:
Expand All @@ -145,9 +149,12 @@ func WriteParams(m Memory, p []CallParam) ([]uint64, error) {
if v < 0 {
return nil, fmt.Errorf("failed to write param: %w", ErrNegativeValue)
}
params = append(params, uint64(v))
params = append(params, int64(v))
case uint64:
params = append(params, v)
if v > MaxInt64 {
return nil, fmt.Errorf("failed to write param: %w", ErrIntegerConversionOverflow)
}
params = append(params, int64(v))
default:
return nil, fmt.Errorf("%w: support types int, uint64 and string", ErrInvalidParamType)
}
Expand Down
21 changes: 14 additions & 7 deletions x/programs/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func getRegisteredImportModules(importTypes []*wasmtime.ImportType) []string {
return imports
}

func (r *WasmRuntime) Call(_ context.Context, name string, params ...uint64) ([]uint64, error) {
func (r *WasmRuntime) Call(_ context.Context, name string, params ...int64) ([]int64, error) {
var fnName string
switch name {
case AllocFnName, DeallocFnName, MemoryFnName:
Expand Down Expand Up @@ -184,11 +184,14 @@ func (r *WasmRuntime) Call(_ context.Context, name string, params ...uint64) ([]

switch v := result.(type) {
case int32:
value := uint64(result.(int32))
return []uint64{value}, nil
value := int64(result.(int32))
return []int64{value}, nil
case int64:
value := uint64(result.(int64))
return []uint64{value}, nil
value := result.(int64)
return []int64{value}, nil
case nil:
// the function had no return values
return nil, nil
default:
return nil, fmt.Errorf("invalid result type: %v", v)
}
Expand Down Expand Up @@ -239,14 +242,18 @@ func PreCompileWasmBytes(programBytes []byte, cfg *Config) ([]byte, error) {
}

// mapFunctionParams maps call input to the expected wasm function params.
func mapFunctionParams(input []uint64, values []*wasmtime.ValType) ([]interface{}, error) {
func mapFunctionParams(input []int64, values []*wasmtime.ValType) ([]interface{}, error) {
params := make([]interface{}, len(values))
for i, v := range values {
switch v.Kind() {
case wasmtime.KindI32:
samliok marked this conversation as resolved.
Show resolved Hide resolved
// ensure this value is within the range of an int32
if input[i] > int64(MaxInt64) || input[i] < int64(MinInt64) {
samliok marked this conversation as resolved.
Show resolved Hide resolved
return nil, fmt.Errorf("%w: %d", ErrIntegerConversionOverflow, input[i])
}
params[i] = int32(input[i])
case wasmtime.KindI64:
params[i] = int64(input[i])
params[i] = input[i]
default:
return nil, fmt.Errorf("%w: %v", ErrInvalidParamType, v.Kind())
}
Expand Down
6 changes: 3 additions & 3 deletions x/programs/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ func TestCallParams(t *testing.T) {
err = runtime.Initialize(ctx, wasm, maxUnits)
require.NoError(err)

resp, err := runtime.Call(ctx, "add", uint64(10), uint64(10))
resp, err := runtime.Call(ctx, "add", 10, 10)
require.NoError(err)
require.Equal(uint64(20), resp[0])
require.Equal(int64(20), resp[0])

// pass 3 params when 2 are expected.
_, err = runtime.Call(ctx, "add", uint64(10), uint64(10), uint64(10))
_, err = runtime.Call(ctx, "add", 10, 10, 10)
require.ErrorIs(err, ErrInvalidParamCount)
}
Loading