From 55db897aab533d2f00f444685aad6ccb15674f76 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 14 Jun 2023 10:10:55 +0200 Subject: [PATCH] fix(state): JSON marshaling for sdk Address wrapper (#2348) While debugging #2322 , I noticed that it is not only an RPC issue, but also a broken API one. All RPC methods that used the type `state.Address` were unusable, both via the client and via raw JSON calls. This is because the server was unable to marshal the address string value back into the interface type. To circumvent this, I have embedded the sdk.Address type in the same way that we embed the fraud proof type, to allow us to unmarshal it back into a concrete type. I have also added unit tests to fix this. In addition, it fixes the RPC parsing - so it closes #2322 . --- api/docgen/examples.go | 2 ++ api/gateway/state.go | 2 +- cmd/celestia/rpc.go | 16 +++------- nodebuilder/state/mocks/api.go | 7 +++-- state/address_test.go | 57 ++++++++++++++++++++++++++++++++++ state/core_access.go | 6 ++-- state/integration_test.go | 2 +- state/state.go | 34 ++++++++++++++++++-- 8 files changed, 104 insertions(+), 22 deletions(-) create mode 100644 state/address_test.go diff --git a/api/docgen/examples.go b/api/docgen/examples.go index b19b06a7f8..80a8c64d93 100644 --- a/api/docgen/examples.go +++ b/api/docgen/examples.go @@ -83,6 +83,8 @@ func init() { } addToExampleValues(valAddr) + addToExampleValues(state.Address{Address: addr}) + var txResponse *state.TxResponse err = json.Unmarshal([]byte(exampleTxResponse), &txResponse) if err != nil { diff --git a/api/gateway/state.go b/api/gateway/state.go index 63ada8ebb2..c895e35bde 100644 --- a/api/gateway/state.go +++ b/api/gateway/state.go @@ -72,7 +72,7 @@ func (h *Handler) handleBalanceRequest(w http.ResponseWriter, r *http.Request) { } addr = valAddr.Bytes() } - bal, err = h.state.BalanceForAddress(r.Context(), addr) + bal, err = h.state.BalanceForAddress(r.Context(), state.Address{Address: addr}) } else { bal, err = h.state.Balance(r.Context()) } diff --git a/cmd/celestia/rpc.go b/cmd/celestia/rpc.go index de8c55e3f4..767fca872c 100644 --- a/cmd/celestia/rpc.go +++ b/cmd/celestia/rpc.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" - "errors" "fmt" "io" "log" @@ -15,7 +14,6 @@ import ( "strconv" "strings" - "github.com/cosmos/cosmos-sdk/types" "github.com/spf13/cobra" "github.com/celestiaorg/nmt/namespace" @@ -392,18 +390,12 @@ func sendJSONRPCRequest(namespace, method string, params []interface{}) { } func parseAddressFromString(addrStr string) (state.Address, error) { - var addr state.AccAddress - addr, err := types.AccAddressFromBech32(addrStr) + var address state.Address + err := address.UnmarshalJSON([]byte(addrStr)) if err != nil { - // first check if it is a validator address and can be converted - valAddr, err := types.ValAddressFromBech32(addrStr) - if err != nil { - return nil, errors.New("address must be a valid account or validator address ") - } - return valAddr, nil + return address, err } - - return addr, nil + return address, nil } func parseSignatureForHelpstring(methodSig reflect.StructField) string { diff --git a/nodebuilder/state/mocks/api.go b/nodebuilder/state/mocks/api.go index 814b9a0113..dbd1d5dabe 100644 --- a/nodebuilder/state/mocks/api.go +++ b/nodebuilder/state/mocks/api.go @@ -10,6 +10,7 @@ import ( math "cosmossdk.io/math" blob "github.com/celestiaorg/celestia-node/blob" + state "github.com/celestiaorg/celestia-node/state" types "github.com/cosmos/cosmos-sdk/types" types0 "github.com/cosmos/cosmos-sdk/x/staking/types" gomock "github.com/golang/mock/gomock" @@ -40,10 +41,10 @@ func (m *MockModule) EXPECT() *MockModuleMockRecorder { } // AccountAddress mocks base method. -func (m *MockModule) AccountAddress(arg0 context.Context) (types.Address, error) { +func (m *MockModule) AccountAddress(arg0 context.Context) (state.Address, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AccountAddress", arg0) - ret0, _ := ret[0].(types.Address) + ret0, _ := ret[0].(state.Address) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -70,7 +71,7 @@ func (mr *MockModuleMockRecorder) Balance(arg0 interface{}) *gomock.Call { } // BalanceForAddress mocks base method. -func (m *MockModule) BalanceForAddress(arg0 context.Context, arg1 types.Address) (*types.Coin, error) { +func (m *MockModule) BalanceForAddress(arg0 context.Context, arg1 state.Address) (*types.Coin, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BalanceForAddress", arg0, arg1) ret0, _ := ret[0].(*types.Coin) diff --git a/state/address_test.go b/state/address_test.go new file mode 100644 index 0000000000..d701b38aa8 --- /dev/null +++ b/state/address_test.go @@ -0,0 +1,57 @@ +package state + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddressMarshalling(t *testing.T) { + testCases := []struct { + name string + addressString string + addressFromStr func(string) (interface{}, error) + marshalJSON func(interface{}) ([]byte, error) + unmarshalJSON func([]byte) (interface{}, error) + }{ + { + name: "Account Address", + addressString: "celestia1377k5an3f94v6wyaceu0cf4nq6gk2jtpc46g7h", + addressFromStr: func(s string) (interface{}, error) { return sdk.AccAddressFromBech32(s) }, + marshalJSON: func(addr interface{}) ([]byte, error) { return addr.(sdk.AccAddress).MarshalJSON() }, + unmarshalJSON: func(b []byte) (interface{}, error) { + var addr sdk.AccAddress + err := addr.UnmarshalJSON(b) + return addr, err + }, + }, + { + name: "Validator Address", + addressString: "celestiavaloper1q3v5cugc8cdpud87u4zwy0a74uxkk6u4q4gx4p", + addressFromStr: func(s string) (interface{}, error) { return sdk.ValAddressFromBech32(s) }, + marshalJSON: func(addr interface{}) ([]byte, error) { return addr.(sdk.ValAddress).MarshalJSON() }, + unmarshalJSON: func(b []byte) (interface{}, error) { + var addr sdk.ValAddress + err := addr.UnmarshalJSON(b) + return addr, err + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + addr, err := tc.addressFromStr(tc.addressString) + require.NoError(t, err) + + addrBytes, err := tc.marshalJSON(addr) + assert.NoError(t, err) + assert.Equal(t, []byte("\""+tc.addressString+"\""), addrBytes) + + addrUnmarshalled, err := tc.unmarshalJSON(addrBytes) + assert.NoError(t, err) + assert.Equal(t, addr, addrUnmarshalled) + }) + } +} diff --git a/state/core_access.go b/state/core_access.go index ca0947166d..7b59f3e714 100644 --- a/state/core_access.go +++ b/state/core_access.go @@ -194,9 +194,9 @@ func (ca *CoreAccessor) SubmitPayForBlob( func (ca *CoreAccessor) AccountAddress(context.Context) (Address, error) { addr, err := ca.signer.GetSignerInfo().GetAddress() if err != nil { - return nil, err + return Address{nil}, err } - return addr, nil + return Address{addr}, nil } func (ca *CoreAccessor) Balance(ctx context.Context) (*Balance, error) { @@ -204,7 +204,7 @@ func (ca *CoreAccessor) Balance(ctx context.Context) (*Balance, error) { if err != nil { return nil, err } - return ca.BalanceForAddress(ctx, addr) + return ca.BalanceForAddress(ctx, Address{addr}) } func (ca *CoreAccessor) BalanceForAddress(ctx context.Context, addr Address) (*Balance, error) { diff --git a/state/integration_test.go b/state/integration_test.go index e7d2496397..8862de1bf8 100644 --- a/state/integration_test.go +++ b/state/integration_test.go @@ -110,7 +110,7 @@ func (s *IntegrationTestSuite) TestGetBalance() { require := s.Require() expectedBal := sdk.NewCoin(app.BondDenom, sdk.NewInt(int64(99999999999999999))) for _, acc := range s.accounts { - bal, err := s.accessor.BalanceForAddress(context.Background(), s.getAddress(acc)) + bal, err := s.accessor.BalanceForAddress(context.Background(), Address{s.getAddress(acc)}) require.NoError(err) require.Equal(&expectedBal, bal) } diff --git a/state/state.go b/state/state.go index 987a783239..d55bb6901c 100644 --- a/state/state.go +++ b/state/state.go @@ -1,6 +1,9 @@ package state import ( + "fmt" + "strings" + "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" coretypes "github.com/tendermint/tendermint/types" @@ -15,8 +18,11 @@ type Tx = coretypes.Tx // TxResponse is an alias to the TxResponse type from Cosmos-SDK. type TxResponse = sdk.TxResponse -// Address is an alias to the Address type from Cosmos-SDK. -type Address = sdk.Address +// Address is an alias to the Address type from Cosmos-SDK. It is embedded into a struct to provide +// a non-interface type for JSON serialization. +type Address struct { + sdk.Address +} // ValAddress is an alias to the ValAddress type from Cosmos-SDK. type ValAddress = sdk.ValAddress @@ -26,3 +32,27 @@ type AccAddress = sdk.AccAddress // Int is an alias to the Int type from Cosmos-SDK. type Int = math.Int + +func (a *Address) UnmarshalJSON(data []byte) error { + // To convert the string back to a concrete type, we have to determine the correct implementation + var addr AccAddress + addrString := strings.Trim(string(data), "\"") + addr, err := sdk.AccAddressFromBech32(addrString) + if err != nil { + // first check if it is a validator address and can be converted + valAddr, err := sdk.ValAddressFromBech32(addrString) + if err != nil { + return fmt.Errorf("address must be a valid account or validator address: %w", err) + } + a.Address = valAddr + return nil + } + + a.Address = addr + return nil +} + +func (a Address) MarshalJSON() ([]byte, error) { + // The address is marshaled into a simple string value + return []byte("\"" + a.Address.String() + "\""), nil +}