Skip to content

Commit

Permalink
Keep track of pointers and their length with a custom store (#857)
Browse files Browse the repository at this point in the history
* implement new alloc fun with global store

* make sure that only known bounds are reached

* use a map for the store

* remove use of HostPtr

* return pointer from go alloc

* apply review suggestions

* Fix the counter tests

* fix token test

* remove usage of SmartPtr in test

* use is_null to check for pointer nulleness

* silent deprecated tag

* fix clippy in test

* fix panic *expected* messages

* remove useless test

---------

Co-authored-by: Richard Pringle <[email protected]>
  • Loading branch information
iFrostizz and richardpringle authored Apr 25, 2024
1 parent 897be9b commit b7f3391
Show file tree
Hide file tree
Showing 16 changed files with 221 additions and 117 deletions.
18 changes: 9 additions & 9 deletions x/programs/cmd/simulator/vm/actions/program_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,32 +175,32 @@ type CallParam struct {

// WriteParams is a helper function that writes the given params to memory if non integer.
// Supported types include int, uint64 and string.
func WriteParams(m *program.Memory, p []CallParam) ([]program.SmartPtr, error) {
var params []program.SmartPtr
func WriteParams(m *program.Memory, p []CallParam) ([]uint32, error) {
var params []uint32
for _, param := range p {
switch v := param.Value.(type) {
case []byte:
smartPtr, err := program.BytesToSmartPtr(v, m)
smartPtr, err := program.AllocateBytes(v, m)
if err != nil {
return nil, err
}
params = append(params, smartPtr)
case ids.ID:
smartPtr, err := program.BytesToSmartPtr(v[:], m)
smartPtr, err := program.AllocateBytes(v[:], m)
if err != nil {
return nil, err
}
params = append(params, smartPtr)
case string:
smartPtr, err := program.BytesToSmartPtr([]byte(v), m)
smartPtr, err := program.AllocateBytes([]byte(v), m)
if err != nil {
return nil, err
}
params = append(params, smartPtr)
case program.SmartPtr:
case uint32:
params = append(params, v)
default:
ptr, err := argumentToSmartPtr(v, m)
ptr, err := writeToMem(v, m)
if err != nil {
return nil, err
}
Expand All @@ -218,11 +218,11 @@ func serializeParameter(obj interface{}) ([]byte, error) {
}

// Serialize the parameter and create a smart ptr
func argumentToSmartPtr(obj interface{}, memory *program.Memory) (program.SmartPtr, error) {
func writeToMem(obj interface{}, memory *program.Memory) (uint32, error) {
bytes, err := serializeParameter(obj)
if err != nil {
return 0, err
}

return program.BytesToSmartPtr(bytes, memory)
return program.AllocateBytes(bytes, memory)
}
39 changes: 32 additions & 7 deletions x/programs/examples/counter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,17 @@ func TestCounterProgram(t *testing.T) {
require.NoError(err)

// write alice's key to stack and get pointer
alicePtr, err := argumentToSmartPtr(alicePublicKey, mem)
alicePtr, err := writeToMem(alicePublicKey, mem)
require.NoError(err)

// create counter for alice on program 1
result, err := rt.Call(ctx, "initialize_address", callContext, alicePtr)
require.NoError(err)
require.Equal(int64(1), result[0])

alicePtr, err = writeToMem(alicePublicKey, mem)
require.NoError(err)

// validate counter at 0
result, err = rt.Call(ctx, "get_value", callContext, alicePtr)
require.NoError(err)
Expand Down Expand Up @@ -115,7 +118,7 @@ func TestCounterProgram(t *testing.T) {
require.NoError(err)

// write alice's key to stack and get pointer
alicePtr2, err := argumentToSmartPtr(alicePublicKey, mem2)
alicePtr2, err := writeToMem(alicePublicKey, mem2)
require.NoError(err)

callContext1 := program.Context{ProgramID: programID}
Expand All @@ -128,12 +131,19 @@ func TestCounterProgram(t *testing.T) {

// increment alice's counter on program 2 by 10
incAmount := int64(10)
incAmountPtr, err := argumentToSmartPtr(incAmount, mem2)
incAmountPtr, err := writeToMem(incAmount, mem2)
require.NoError(err)

alicePtr2, err = writeToMem(alicePublicKey, mem2)

require.NoError(err)
result, err = rt2.Call(ctx, "inc", callContext2, alicePtr2, incAmountPtr)
require.NoError(err)
require.Equal(int64(1), result[0])

alicePtr2, err = writeToMem(alicePublicKey, mem2)
require.NoError(err)

result, err = rt2.Call(ctx, "get_value", callContext2, alicePtr2)
require.NoError(err)
require.Equal(incAmount, result[0])
Expand All @@ -146,12 +156,19 @@ func TestCounterProgram(t *testing.T) {
require.NoError(err)

// increment alice's counter on program 1
onePtr, err := argumentToSmartPtr(int64(1), mem)
onePtr, err := writeToMem(int64(1), mem)
require.NoError(err)

alicePtr, err = writeToMem(alicePublicKey, mem)
require.NoError(err)

result, err = rt.Call(ctx, "inc", callContext1, alicePtr, onePtr)
require.NoError(err)
require.Equal(int64(1), result[0])

alicePtr, err = writeToMem(alicePublicKey, mem)
require.NoError(err)

result, err = rt.Call(ctx, "get_value", callContext1, alicePtr)
require.NoError(err)

Expand All @@ -160,20 +177,28 @@ func TestCounterProgram(t *testing.T) {
)

// write program id 2 to stack of program 1
target, err := argumentToSmartPtr(programID2, mem)
target, err := writeToMem(programID2, mem)
require.NoError(err)

maxUnitsProgramToProgram := int64(10000)
maxUnitsProgramToProgramPtr, err := argumentToSmartPtr(maxUnitsProgramToProgram, mem)
maxUnitsProgramToProgramPtr, err := writeToMem(maxUnitsProgramToProgram, mem)
require.NoError(err)

// increment alice's counter on program 2
fivePtr, err := argumentToSmartPtr(int64(5), mem)
fivePtr, err := writeToMem(int64(5), mem)
require.NoError(err)
alicePtr, err = writeToMem(alicePublicKey, mem)
require.NoError(err)
result, err = rt.Call(ctx, "inc_external", callContext1, target, maxUnitsProgramToProgramPtr, alicePtr, fivePtr)
require.NoError(err)
require.Equal(int64(1), result[0])

target, err = writeToMem(programID2, mem)
require.NoError(err)
alicePtr, err = writeToMem(alicePublicKey, mem)
require.NoError(err)
maxUnitsProgramToProgramPtr, err = writeToMem(maxUnitsProgramToProgram, mem)
require.NoError(err)
// expect alice's counter on program 2 to be 15
result, err = rt.Call(ctx, "get_value_external", callContext1, target, maxUnitsProgramToProgramPtr, alicePtr)
require.NoError(err)
Expand Down
10 changes: 3 additions & 7 deletions x/programs/examples/imports/program/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle
}

// getCallArgs returns the arguments to be passed to the program being invoked from [buffer].
func getCallArgs(_ context.Context, memory *program.Memory, buffer []byte) ([]program.SmartPtr, error) {
var args []program.SmartPtr
func getCallArgs(_ context.Context, memory *program.Memory, buffer []byte) ([]uint32, error) {
var args []uint32

for i := 0; i < len(buffer); {
// unpacks uint32
Expand All @@ -201,11 +201,7 @@ func getCallArgs(_ context.Context, memory *program.Memory, buffer []byte) ([]pr
if err != nil {
return nil, err
}
argPtr, err := program.NewSmartPtr(ptr, int(length))
if err != nil {
return nil, err
}
args = append(args, argPtr)
args = append(args, ptr)
}

return args, nil
Expand Down
62 changes: 56 additions & 6 deletions x/programs/examples/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (t *Token) Run(ctx context.Context) error {
}

// write alice's key to stack and get pointer
alicePtr, err := argumentToSmartPtr(alicePublicKey, mem)
alicePtr, err := writeToMem(alicePublicKey, mem)
if err != nil {
return err
}
Expand All @@ -146,7 +146,7 @@ func (t *Token) Run(ctx context.Context) error {
}

// write bob's key to stack and get pointer
bobPtr, err := argumentToSmartPtr(bobPublicKey, mem)
bobPtr, err := writeToMem(bobPublicKey, mem)
if err != nil {
return err
}
Expand All @@ -162,7 +162,7 @@ func (t *Token) Run(ctx context.Context) error {

// mint 100 tokens to alice
mintAlice := int64(1000)
mintAlicePtr, err := argumentToSmartPtr(mintAlice, mem)
mintAlicePtr, err := writeToMem(mintAlice, mem)
if err != nil {
return err
}
Expand All @@ -175,6 +175,11 @@ func (t *Token) Run(ctx context.Context) error {
zap.Int64("alice", mintAlice),
)

alicePtr, err = writeToMem(alicePublicKey, mem)
if err != nil {
return err
}

// check balance of alice
result, err = rt.Call(ctx, "get_balance", programContext, alicePtr)
if err != nil {
Expand All @@ -184,6 +189,11 @@ func (t *Token) Run(ctx context.Context) error {
zap.Int64("alice", result[0]),
)

bobPtr, err = writeToMem(bobPublicKey, mem)
if err != nil {
return err
}

// check balance of bob
result, err = rt.Call(ctx, "get_balance", programContext, bobPtr)
if err != nil {
Expand All @@ -195,10 +205,20 @@ func (t *Token) Run(ctx context.Context) error {

// transfer 50 from alice to bob
transferToBob := int64(50)
transferToBobPtr, err := argumentToSmartPtr(transferToBob, mem)
transferToBobPtr, err := writeToMem(transferToBob, mem)
if err != nil {
return err
}
bobPtr, err = writeToMem(bobPublicKey, mem)
if err != nil {
return err
}

alicePtr, err = writeToMem(alicePublicKey, mem)
if err != nil {
return err
}

_, err = rt.Call(ctx, "transfer", programContext, alicePtr, bobPtr, transferToBobPtr)
if err != nil {
return err
Expand All @@ -208,7 +228,17 @@ func (t *Token) Run(ctx context.Context) error {
zap.Int64("to bob", transferToBob),
)

onePtr, err := argumentToSmartPtr(int64(1), mem)
onePtr, err := writeToMem(int64(1), mem)
if err != nil {
return err
}

bobPtr, err = writeToMem(bobPublicKey, mem)
if err != nil {
return err
}

alicePtr, err = writeToMem(alicePublicKey, mem)
if err != nil {
return err
}
Expand All @@ -222,6 +252,11 @@ func (t *Token) Run(ctx context.Context) error {
zap.Int64("to bob", 1),
)

alicePtr, err = writeToMem(alicePublicKey, mem)
if err != nil {
return err
}

// get balance alice
result, err = rt.Call(ctx, "get_balance", programContext, alicePtr)
if err != nil {
Expand All @@ -231,6 +266,11 @@ func (t *Token) Run(ctx context.Context) error {
zap.Int64("alice", result[0]),
)

bobPtr, err = writeToMem(bobPublicKey, mem)
if err != nil {
return err
}

// get balance bob
result, err = rt.Call(ctx, "get_balance", programContext, bobPtr)
if err != nil {
Expand Down Expand Up @@ -259,7 +299,7 @@ func (t *Token) Run(ctx context.Context) error {
},
}

mintersPtr, err := argumentToSmartPtr(minters, mem)
mintersPtr, err := writeToMem(minters, mem)
if err != nil {
return err
}
Expand All @@ -274,6 +314,11 @@ func (t *Token) Run(ctx context.Context) error {
zap.Int32("to bob", minters[1].Amount),
)

alicePtr, err = writeToMem(alicePublicKey, mem)
if err != nil {
return err
}

// get balance alice
result, err = rt.Call(ctx, "get_balance", programContext, alicePtr)
if err != nil {
Expand All @@ -283,6 +328,11 @@ func (t *Token) Run(ctx context.Context) error {
zap.Int64("alice", result[0]),
)

bobPtr, err = writeToMem(bobPublicKey, mem)
if err != nil {
return err
}

// get balance bob
result, err = rt.Call(ctx, "get_balance", programContext, bobPtr)
if err != nil {
Expand Down
10 changes: 8 additions & 2 deletions x/programs/examples/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,20 @@ func TestTokenProgram(t *testing.T) {
require.NoError(err)

// write alice's key to stack and get pointer
alicePtr, err := argumentToSmartPtr(alicePublicKey, mem)
alicePtr, err := writeToMem(alicePublicKey, mem)
require.NoError(err)

// mint 100 tokens to alice
mintAlice := int64(1000)
mintAlicePtr, err := argumentToSmartPtr(mintAlice, mem)
mintAlicePtr, err := writeToMem(mintAlice, mem)
require.NoError(err)

_, err = rt.Call(ctx, "mint_to", callContext, alicePtr, mintAlicePtr)
require.NoError(err)

alicePtr, err = writeToMem(alicePublicKey, mem)
require.NoError(err)

// check balance of alice
result, err := rt.Call(ctx, "get_balance", callContext, alicePtr)
require.NoError(err)
Expand All @@ -75,6 +78,9 @@ func TestTokenProgram(t *testing.T) {
require.NoError(err)
require.Equal(int64(1000), aliceBalance)

alicePtr, err = writeToMem(alicePublicKey, mem)
require.NoError(err)

// burn alice tokens
_, err = rt.Call(ctx, "burn_from", callContext, alicePtr)
require.NoError(err)
Expand Down
4 changes: 2 additions & 2 deletions x/programs/examples/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ func newKey() (ed25519.PublicKey, error) {
}

// Serialize the parameter and create a smart ptr
func argumentToSmartPtr(obj interface{}, memory *program.Memory) (program.SmartPtr, error) {
func writeToMem(obj interface{}, memory *program.Memory) (uint32, error) {
bytes, err := borsh.Serialize(obj)
if err != nil {
return 0, err
}

return program.BytesToSmartPtr(bytes, memory)
return program.AllocateBytes(bytes, memory)
}

var (
Expand Down
Loading

0 comments on commit b7f3391

Please sign in to comment.