diff --git a/api/docgen/examples.go b/api/docgen/examples.go index b19b06a7f8..80a8c64d93 100644 --- a/api/docgen/examples.go +++ b/api/docgen/examples.go @@ -83,6 +83,8 @@ func init() { } addToExampleValues(valAddr) + addToExampleValues(state.Address{Address: addr}) + var txResponse *state.TxResponse err = json.Unmarshal([]byte(exampleTxResponse), &txResponse) if err != nil { diff --git a/api/gateway/state.go b/api/gateway/state.go index 63ada8ebb2..c895e35bde 100644 --- a/api/gateway/state.go +++ b/api/gateway/state.go @@ -72,7 +72,7 @@ func (h *Handler) handleBalanceRequest(w http.ResponseWriter, r *http.Request) { } addr = valAddr.Bytes() } - bal, err = h.state.BalanceForAddress(r.Context(), addr) + bal, err = h.state.BalanceForAddress(r.Context(), state.Address{Address: addr}) } else { bal, err = h.state.Balance(r.Context()) } diff --git a/cmd/celestia/rpc.go b/cmd/celestia/rpc.go index de8c55e3f4..767fca872c 100644 --- a/cmd/celestia/rpc.go +++ b/cmd/celestia/rpc.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" - "errors" "fmt" "io" "log" @@ -15,7 +14,6 @@ import ( "strconv" "strings" - "github.com/cosmos/cosmos-sdk/types" "github.com/spf13/cobra" "github.com/celestiaorg/nmt/namespace" @@ -392,18 +390,12 @@ func sendJSONRPCRequest(namespace, method string, params []interface{}) { } func parseAddressFromString(addrStr string) (state.Address, error) { - var addr state.AccAddress - addr, err := types.AccAddressFromBech32(addrStr) + var address state.Address + err := address.UnmarshalJSON([]byte(addrStr)) if err != nil { - // first check if it is a validator address and can be converted - valAddr, err := types.ValAddressFromBech32(addrStr) - if err != nil { - return nil, errors.New("address must be a valid account or validator address ") - } - return valAddr, nil + return address, err } - - return addr, nil + return address, nil } func parseSignatureForHelpstring(methodSig reflect.StructField) string { diff --git a/nodebuilder/state/mocks/api.go b/nodebuilder/state/mocks/api.go index 814b9a0113..dbd1d5dabe 100644 --- a/nodebuilder/state/mocks/api.go +++ b/nodebuilder/state/mocks/api.go @@ -10,6 +10,7 @@ import ( math "cosmossdk.io/math" blob "github.com/celestiaorg/celestia-node/blob" + state "github.com/celestiaorg/celestia-node/state" types "github.com/cosmos/cosmos-sdk/types" types0 "github.com/cosmos/cosmos-sdk/x/staking/types" gomock "github.com/golang/mock/gomock" @@ -40,10 +41,10 @@ func (m *MockModule) EXPECT() *MockModuleMockRecorder { } // AccountAddress mocks base method. -func (m *MockModule) AccountAddress(arg0 context.Context) (types.Address, error) { +func (m *MockModule) AccountAddress(arg0 context.Context) (state.Address, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AccountAddress", arg0) - ret0, _ := ret[0].(types.Address) + ret0, _ := ret[0].(state.Address) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -70,7 +71,7 @@ func (mr *MockModuleMockRecorder) Balance(arg0 interface{}) *gomock.Call { } // BalanceForAddress mocks base method. -func (m *MockModule) BalanceForAddress(arg0 context.Context, arg1 types.Address) (*types.Coin, error) { +func (m *MockModule) BalanceForAddress(arg0 context.Context, arg1 state.Address) (*types.Coin, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BalanceForAddress", arg0, arg1) ret0, _ := ret[0].(*types.Coin) diff --git a/state/address_test.go b/state/address_test.go new file mode 100644 index 0000000000..d701b38aa8 --- /dev/null +++ b/state/address_test.go @@ -0,0 +1,57 @@ +package state + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddressMarshalling(t *testing.T) { + testCases := []struct { + name string + addressString string + addressFromStr func(string) (interface{}, error) + marshalJSON func(interface{}) ([]byte, error) + unmarshalJSON func([]byte) (interface{}, error) + }{ + { + name: "Account Address", + addressString: "celestia1377k5an3f94v6wyaceu0cf4nq6gk2jtpc46g7h", + addressFromStr: func(s string) (interface{}, error) { return sdk.AccAddressFromBech32(s) }, + marshalJSON: func(addr interface{}) ([]byte, error) { return addr.(sdk.AccAddress).MarshalJSON() }, + unmarshalJSON: func(b []byte) (interface{}, error) { + var addr sdk.AccAddress + err := addr.UnmarshalJSON(b) + return addr, err + }, + }, + { + name: "Validator Address", + addressString: "celestiavaloper1q3v5cugc8cdpud87u4zwy0a74uxkk6u4q4gx4p", + addressFromStr: func(s string) (interface{}, error) { return sdk.ValAddressFromBech32(s) }, + marshalJSON: func(addr interface{}) ([]byte, error) { return addr.(sdk.ValAddress).MarshalJSON() }, + unmarshalJSON: func(b []byte) (interface{}, error) { + var addr sdk.ValAddress + err := addr.UnmarshalJSON(b) + return addr, err + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + addr, err := tc.addressFromStr(tc.addressString) + require.NoError(t, err) + + addrBytes, err := tc.marshalJSON(addr) + assert.NoError(t, err) + assert.Equal(t, []byte("\""+tc.addressString+"\""), addrBytes) + + addrUnmarshalled, err := tc.unmarshalJSON(addrBytes) + assert.NoError(t, err) + assert.Equal(t, addr, addrUnmarshalled) + }) + } +} diff --git a/state/core_access.go b/state/core_access.go index ca0947166d..7b59f3e714 100644 --- a/state/core_access.go +++ b/state/core_access.go @@ -194,9 +194,9 @@ func (ca *CoreAccessor) SubmitPayForBlob( func (ca *CoreAccessor) AccountAddress(context.Context) (Address, error) { addr, err := ca.signer.GetSignerInfo().GetAddress() if err != nil { - return nil, err + return Address{nil}, err } - return addr, nil + return Address{addr}, nil } func (ca *CoreAccessor) Balance(ctx context.Context) (*Balance, error) { @@ -204,7 +204,7 @@ func (ca *CoreAccessor) Balance(ctx context.Context) (*Balance, error) { if err != nil { return nil, err } - return ca.BalanceForAddress(ctx, addr) + return ca.BalanceForAddress(ctx, Address{addr}) } func (ca *CoreAccessor) BalanceForAddress(ctx context.Context, addr Address) (*Balance, error) { diff --git a/state/integration_test.go b/state/integration_test.go index e7d2496397..8862de1bf8 100644 --- a/state/integration_test.go +++ b/state/integration_test.go @@ -110,7 +110,7 @@ func (s *IntegrationTestSuite) TestGetBalance() { require := s.Require() expectedBal := sdk.NewCoin(app.BondDenom, sdk.NewInt(int64(99999999999999999))) for _, acc := range s.accounts { - bal, err := s.accessor.BalanceForAddress(context.Background(), s.getAddress(acc)) + bal, err := s.accessor.BalanceForAddress(context.Background(), Address{s.getAddress(acc)}) require.NoError(err) require.Equal(&expectedBal, bal) } diff --git a/state/state.go b/state/state.go index 987a783239..d55bb6901c 100644 --- a/state/state.go +++ b/state/state.go @@ -1,6 +1,9 @@ package state import ( + "fmt" + "strings" + "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" coretypes "github.com/tendermint/tendermint/types" @@ -15,8 +18,11 @@ type Tx = coretypes.Tx // TxResponse is an alias to the TxResponse type from Cosmos-SDK. type TxResponse = sdk.TxResponse -// Address is an alias to the Address type from Cosmos-SDK. -type Address = sdk.Address +// Address is an alias to the Address type from Cosmos-SDK. It is embedded into a struct to provide +// a non-interface type for JSON serialization. +type Address struct { + sdk.Address +} // ValAddress is an alias to the ValAddress type from Cosmos-SDK. type ValAddress = sdk.ValAddress @@ -26,3 +32,27 @@ type AccAddress = sdk.AccAddress // Int is an alias to the Int type from Cosmos-SDK. type Int = math.Int + +func (a *Address) UnmarshalJSON(data []byte) error { + // To convert the string back to a concrete type, we have to determine the correct implementation + var addr AccAddress + addrString := strings.Trim(string(data), "\"") + addr, err := sdk.AccAddressFromBech32(addrString) + if err != nil { + // first check if it is a validator address and can be converted + valAddr, err := sdk.ValAddressFromBech32(addrString) + if err != nil { + return fmt.Errorf("address must be a valid account or validator address: %w", err) + } + a.Address = valAddr + return nil + } + + a.Address = addr + return nil +} + +func (a Address) MarshalJSON() ([]byte, error) { + // The address is marshaled into a simple string value + return []byte("\"" + a.Address.String() + "\""), nil +}