Skip to content

Commit

Permalink
Support custom object-params
Browse files Browse the repository at this point in the history
  • Loading branch information
magik6k committed Jan 23, 2023
1 parent 44a9f01 commit eb12710
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 63 deletions.
60 changes: 44 additions & 16 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,16 @@ func (c *client) setupRequestChan() chan clientRequest {
case <-ctxDone: // send cancel request
ctxDone = nil

rp, err := json.Marshal([]param{{v: reflect.ValueOf(cr.req.ID)}})
if err != nil {
return clientResponse{}, xerrors.Errorf("marshalling cancel request: %w", err)
}

cancelReq := clientRequest{
req: request{
Jsonrpc: "2.0",
Method: wsCancel,
Params: []param{{v: reflect.ValueOf(cr.req.ID)}},
Params: rp,
},
ready: make(chan clientResponse, 1),
}
Expand Down Expand Up @@ -452,7 +457,11 @@ type rpcFunc struct {
valOut int
errOut int

hasCtx int
// hasCtx is 1 if the function has a context.Context as its first argument.
// Used as the number of the first non-context argument.
hasCtx int

hasRawParams bool
returnValueIsChannel bool

retry bool
Expand Down Expand Up @@ -507,20 +516,31 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
}
}

params := make([]param, len(args)-fn.hasCtx)
for i, arg := range args[fn.hasCtx:] {
enc, found := fn.client.paramEncoders[arg.Type()]
if found {
// custom param encoder
var err error
arg, err = enc(arg)
if err != nil {
return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
var serializedParams json.RawMessage

if fn.hasRawParams {
serializedParams = json.RawMessage(args[fn.hasCtx].Interface().(RawParams))
} else {
params := make([]param, len(args)-fn.hasCtx)
for i, arg := range args[fn.hasCtx:] {
enc, found := fn.client.paramEncoders[arg.Type()]
if found {
// custom param encoder
var err error
arg, err = enc(arg)
if err != nil {
return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
}
}
}

params[i] = param{
v: arg,
params[i] = param{
v: arg,
}
}
var err error
serializedParams, err = json.Marshal(params)
if err != nil {
return fn.processError(fmt.Errorf("marshaling params failed: %w", err))
}
}

Expand All @@ -545,7 +565,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
Jsonrpc: "2.0",
ID: id,
Method: fn.name,
Params: params,
Params: serializedParams,
}

if span != nil {
Expand Down Expand Up @@ -631,10 +651,18 @@ func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) {
return reflect.Value{}, xerrors.New("notify methods cannot return values")
}

fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan

if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
fun.hasCtx = 1
}
fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan
// note: hasCtx is also the number of the first non-context argument
if ftyp.NumIn() > fun.hasCtx && ftyp.In(fun.hasCtx) == rtRawParams {
if ftyp.NumIn() > fun.hasCtx+1 {
return reflect.Value{}, xerrors.New("raw params can't be mixed with other arguments")
}
fun.hasRawParams = true
}

return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil
}
104 changes: 72 additions & 32 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ import (
"github.com/filecoin-project/go-jsonrpc/metrics"
)

type RawParams json.RawMessage

var rtRawParams = reflect.TypeOf(RawParams{})

// todo is there a better way to tell 'struct with any number of fields'?
func DecodeParams[T any](p RawParams) (T, error) {
var t T
err := json.Unmarshal(p, &t)

// todo also handle list-encoding automagically (json.Unmarshal doesn't do that, does it?)

return t, err
}

// methodHandler is a handler for a single method
type methodHandler struct {
paramReceivers []reflect.Type
Expand All @@ -28,7 +42,8 @@ type methodHandler struct {
receiver reflect.Value
handlerFunc reflect.Value

hasCtx int
hasCtx int
hasRawParams bool

errOut int
valOut int
Expand All @@ -40,7 +55,7 @@ type request struct {
Jsonrpc string `json:"jsonrpc"`
ID interface{} `json:"id,omitempty"`
Method string `json:"method"`
Params []param `json:"params"`
Params json.RawMessage `json:"params"`
Meta map[string]string `json:"meta,omitempty"`
}

Expand Down Expand Up @@ -135,9 +150,16 @@ func (s *handler) register(namespace string, r interface{}) {
hasCtx = 1
}

hasRawParams := false
ins := funcType.NumIn() - 1 - hasCtx
recvs := make([]reflect.Type, ins)
for i := 0; i < ins; i++ {
if hasRawParams && i > 0 {
panic("raw params must be the last parameter")
}
if funcType.In(i+1+hasCtx) == rtRawParams {
hasRawParams = true
}
recvs[i] = method.Type.In(i + 1 + hasCtx)
}

Expand All @@ -150,7 +172,8 @@ func (s *handler) register(namespace string, r interface{}) {
handlerFunc: method.Func,
receiver: val,

hasCtx: hasCtx,
hasCtx: hasCtx,
hasRawParams: hasRawParams,

errOut: errOut,
valOut: valOut,
Expand Down Expand Up @@ -291,13 +314,6 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer
}
}

if len(req.Params) != handler.nParams {
rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(req.Params), handler.nParams))
stats.Record(ctx, metrics.RPCRequestError.M(1))
done(false)
return
}

outCh := handler.valOut != -1 && handler.handlerFunc.Type().Out(handler.valOut).Kind() == reflect.Chan
defer done(outCh)

Expand All @@ -313,30 +329,54 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer
callParams[1] = reflect.ValueOf(ctx)
}

for i := 0; i < handler.nParams; i++ {
var rp reflect.Value

typ := handler.paramReceivers[i]
dec, found := s.paramDecoders[typ]
if !found {
rp = reflect.New(typ)
if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil {
rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err))
stats.Record(ctx, metrics.RPCRequestError.M(1))
return
}
rp = rp.Elem()
} else {
var err error
rp, err = dec(ctx, req.Params[i].data)
if err != nil {
rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err))
stats.Record(ctx, metrics.RPCRequestError.M(1))
return
}
if handler.hasRawParams {
// When hasRawParams is true, there is only one parameter and it is a
// json.RawMessage.

callParams[1+handler.hasCtx] = reflect.ValueOf(RawParams(req.Params))
} else {
// "normal" param list; no good way to do named params in Golang

var ps []param
err := json.Unmarshal(req.Params, &ps)
if err != nil {
rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling param array: %w", err))
stats.Record(ctx, metrics.RPCRequestError.M(1))
return
}

if len(ps) != handler.nParams {
rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(ps), handler.nParams))
stats.Record(ctx, metrics.RPCRequestError.M(1))
done(false)
return
}

callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Interface())
for i := 0; i < handler.nParams; i++ {
var rp reflect.Value

typ := handler.paramReceivers[i]
dec, found := s.paramDecoders[typ]
if !found {
rp = reflect.New(typ)
if err := json.NewDecoder(bytes.NewReader(ps[i].data)).Decode(rp.Interface()); err != nil {
rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err))
stats.Record(ctx, metrics.RPCRequestError.M(1))
return
}
rp = rp.Elem()
} else {
var err error
rp, err = dec(ctx, ps[i].data)
if err != nil {
rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err))
stats.Record(ctx, metrics.RPCRequestError.M(1))
return
}
}

callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Interface())
}
}

// /////////////////
Expand Down
49 changes: 43 additions & 6 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1220,12 +1220,49 @@ func TestNotif(t *testing.T) {
t.Run("http", tc("http"))
}

// 1. make server call on client **
// 2. make client handle **
// 3. alias on client **
// 4. alias call on server **
// 6. custom/object param type
// 7. notif mode proxy tag
type RawParamHandler struct {
}

type CustomParams struct {
I int
}

func (h *RawParamHandler) Call(ctx context.Context, ps RawParams) (int, error) {
p, err := DecodeParams[CustomParams](ps)
if err != nil {
return 0, err
}
return p.I + 1, nil
}

func TestCallWithRawParams(t *testing.T) {
// setup server

rpcServer := NewServer()
rpcServer.Register("Raw", &RawParamHandler{})

// httptest stuff
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()

// setup client
var client struct {
Call func(ctx context.Context, ps RawParams) (int, error)
}
closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Raw", []interface{}{
&client,
}, nil)
require.NoError(t, err)

// do the call!

// this will block if it's not sent as a notification
n, err := client.Call(context.Background(), []byte(`{"I": 1}`))
require.NoError(t, err)
require.Equal(t, 2, n)

closer()
}

type RevCallTestServerHandler struct {
}
Expand Down
Loading

0 comments on commit eb12710

Please sign in to comment.