From eb1271087152468da8a246537b494a68503baf70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Magiera?= Date: Mon, 9 Jan 2023 17:17:00 +0100 Subject: [PATCH] Support custom object-params --- client.go | 60 +++++++++++++++++++++-------- handler.go | 104 +++++++++++++++++++++++++++++++++++---------------- rpc_test.go | 49 +++++++++++++++++++++--- websocket.go | 54 +++++++++++++++++++++----- 4 files changed, 204 insertions(+), 63 deletions(-) diff --git a/client.go b/client.go index 62fee15..13fd2ea 100644 --- a/client.go +++ b/client.go @@ -286,11 +286,16 @@ func (c *client) setupRequestChan() chan clientRequest { case <-ctxDone: // send cancel request ctxDone = nil + rp, err := json.Marshal([]param{{v: reflect.ValueOf(cr.req.ID)}}) + if err != nil { + return clientResponse{}, xerrors.Errorf("marshalling cancel request: %w", err) + } + cancelReq := clientRequest{ req: request{ Jsonrpc: "2.0", Method: wsCancel, - Params: []param{{v: reflect.ValueOf(cr.req.ID)}}, + Params: rp, }, ready: make(chan clientResponse, 1), } @@ -452,7 +457,11 @@ type rpcFunc struct { valOut int errOut int - hasCtx int + // hasCtx is 1 if the function has a context.Context as its first argument. + // Used as the number of the first non-context argument. + hasCtx int + + hasRawParams bool returnValueIsChannel bool retry bool @@ -507,20 +516,31 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value) } } - params := make([]param, len(args)-fn.hasCtx) - for i, arg := range args[fn.hasCtx:] { - enc, found := fn.client.paramEncoders[arg.Type()] - if found { - // custom param encoder - var err error - arg, err = enc(arg) - if err != nil { - return fn.processError(fmt.Errorf("sendRequest failed: %w", err)) + var serializedParams json.RawMessage + + if fn.hasRawParams { + serializedParams = json.RawMessage(args[fn.hasCtx].Interface().(RawParams)) + } else { + params := make([]param, len(args)-fn.hasCtx) + for i, arg := range args[fn.hasCtx:] { + enc, found := fn.client.paramEncoders[arg.Type()] + if found { + // custom param encoder + var err error + arg, err = enc(arg) + if err != nil { + return fn.processError(fmt.Errorf("sendRequest failed: %w", err)) + } } - } - params[i] = param{ - v: arg, + params[i] = param{ + v: arg, + } + } + var err error + serializedParams, err = json.Marshal(params) + if err != nil { + return fn.processError(fmt.Errorf("marshaling params failed: %w", err)) } } @@ -545,7 +565,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value) Jsonrpc: "2.0", ID: id, Method: fn.name, - Params: params, + Params: serializedParams, } if span != nil { @@ -631,10 +651,18 @@ func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) { return reflect.Value{}, xerrors.New("notify methods cannot return values") } + fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan + if ftyp.NumIn() > 0 && ftyp.In(0) == contextType { fun.hasCtx = 1 } - fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan + // note: hasCtx is also the number of the first non-context argument + if ftyp.NumIn() > fun.hasCtx && ftyp.In(fun.hasCtx) == rtRawParams { + if ftyp.NumIn() > fun.hasCtx+1 { + return reflect.Value{}, xerrors.New("raw params can't be mixed with other arguments") + } + fun.hasRawParams = true + } return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil } diff --git a/handler.go b/handler.go index 906581c..da1e689 100644 --- a/handler.go +++ b/handler.go @@ -20,6 +20,20 @@ import ( "github.com/filecoin-project/go-jsonrpc/metrics" ) +type RawParams json.RawMessage + +var rtRawParams = reflect.TypeOf(RawParams{}) + +// todo is there a better way to tell 'struct with any number of fields'? +func DecodeParams[T any](p RawParams) (T, error) { + var t T + err := json.Unmarshal(p, &t) + + // todo also handle list-encoding automagically (json.Unmarshal doesn't do that, does it?) + + return t, err +} + // methodHandler is a handler for a single method type methodHandler struct { paramReceivers []reflect.Type @@ -28,7 +42,8 @@ type methodHandler struct { receiver reflect.Value handlerFunc reflect.Value - hasCtx int + hasCtx int + hasRawParams bool errOut int valOut int @@ -40,7 +55,7 @@ type request struct { Jsonrpc string `json:"jsonrpc"` ID interface{} `json:"id,omitempty"` Method string `json:"method"` - Params []param `json:"params"` + Params json.RawMessage `json:"params"` Meta map[string]string `json:"meta,omitempty"` } @@ -135,9 +150,16 @@ func (s *handler) register(namespace string, r interface{}) { hasCtx = 1 } + hasRawParams := false ins := funcType.NumIn() - 1 - hasCtx recvs := make([]reflect.Type, ins) for i := 0; i < ins; i++ { + if hasRawParams && i > 0 { + panic("raw params must be the last parameter") + } + if funcType.In(i+1+hasCtx) == rtRawParams { + hasRawParams = true + } recvs[i] = method.Type.In(i + 1 + hasCtx) } @@ -150,7 +172,8 @@ func (s *handler) register(namespace string, r interface{}) { handlerFunc: method.Func, receiver: val, - hasCtx: hasCtx, + hasCtx: hasCtx, + hasRawParams: hasRawParams, errOut: errOut, valOut: valOut, @@ -291,13 +314,6 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer } } - if len(req.Params) != handler.nParams { - rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(req.Params), handler.nParams)) - stats.Record(ctx, metrics.RPCRequestError.M(1)) - done(false) - return - } - outCh := handler.valOut != -1 && handler.handlerFunc.Type().Out(handler.valOut).Kind() == reflect.Chan defer done(outCh) @@ -313,30 +329,54 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer callParams[1] = reflect.ValueOf(ctx) } - for i := 0; i < handler.nParams; i++ { - var rp reflect.Value - - typ := handler.paramReceivers[i] - dec, found := s.paramDecoders[typ] - if !found { - rp = reflect.New(typ) - if err := json.NewDecoder(bytes.NewReader(req.Params[i].data)).Decode(rp.Interface()); err != nil { - rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err)) - stats.Record(ctx, metrics.RPCRequestError.M(1)) - return - } - rp = rp.Elem() - } else { - var err error - rp, err = dec(ctx, req.Params[i].data) - if err != nil { - rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err)) - stats.Record(ctx, metrics.RPCRequestError.M(1)) - return - } + if handler.hasRawParams { + // When hasRawParams is true, there is only one parameter and it is a + // json.RawMessage. + + callParams[1+handler.hasCtx] = reflect.ValueOf(RawParams(req.Params)) + } else { + // "normal" param list; no good way to do named params in Golang + + var ps []param + err := json.Unmarshal(req.Params, &ps) + if err != nil { + rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling param array: %w", err)) + stats.Record(ctx, metrics.RPCRequestError.M(1)) + return + } + + if len(ps) != handler.nParams { + rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(ps), handler.nParams)) + stats.Record(ctx, metrics.RPCRequestError.M(1)) + done(false) + return } - callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Interface()) + for i := 0; i < handler.nParams; i++ { + var rp reflect.Value + + typ := handler.paramReceivers[i] + dec, found := s.paramDecoders[typ] + if !found { + rp = reflect.New(typ) + if err := json.NewDecoder(bytes.NewReader(ps[i].data)).Decode(rp.Interface()); err != nil { + rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err)) + stats.Record(ctx, metrics.RPCRequestError.M(1)) + return + } + rp = rp.Elem() + } else { + var err error + rp, err = dec(ctx, ps[i].data) + if err != nil { + rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err)) + stats.Record(ctx, metrics.RPCRequestError.M(1)) + return + } + } + + callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Interface()) + } } // ///////////////// diff --git a/rpc_test.go b/rpc_test.go index 0e0ef70..207d3a0 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -1220,12 +1220,49 @@ func TestNotif(t *testing.T) { t.Run("http", tc("http")) } -// 1. make server call on client ** -// 2. make client handle ** -// 3. alias on client ** -// 4. alias call on server ** -// 6. custom/object param type -// 7. notif mode proxy tag +type RawParamHandler struct { +} + +type CustomParams struct { + I int +} + +func (h *RawParamHandler) Call(ctx context.Context, ps RawParams) (int, error) { + p, err := DecodeParams[CustomParams](ps) + if err != nil { + return 0, err + } + return p.I + 1, nil +} + +func TestCallWithRawParams(t *testing.T) { + // setup server + + rpcServer := NewServer() + rpcServer.Register("Raw", &RawParamHandler{}) + + // httptest stuff + testServ := httptest.NewServer(rpcServer) + defer testServ.Close() + + // setup client + var client struct { + Call func(ctx context.Context, ps RawParams) (int, error) + } + closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Raw", []interface{}{ + &client, + }, nil) + require.NoError(t, err) + + // do the call! + + // this will block if it's not sent as a notification + n, err := client.Call(context.Background(), []byte(`{"I": 1}`)) + require.NoError(t, err) + require.Equal(t, 2, n) + + closer() +} type RevCallTestServerHandler struct { } diff --git a/websocket.go b/websocket.go index 6722bea..5f8bfa9 100644 --- a/websocket.go +++ b/websocket.go @@ -30,8 +30,8 @@ type frame struct { Meta map[string]string `json:"meta,omitempty"` // request - Method string `json:"method,omitempty"` - Params []param `json:"params,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` // response Result json.RawMessage `json:"result,omitempty"` @@ -233,11 +233,17 @@ func (c *wsConn) handleOutChans() { cases = cases[:n] caseToID = caseToID[:n-internal] + rp, err := json.Marshal([]param{{v: reflect.ValueOf(id)}}) + if err != nil { + log.Error(err) + continue + } + if err := c.sendRequest(request{ Jsonrpc: "2.0", ID: nil, // notification Method: chClose, - Params: []param{{v: reflect.ValueOf(id)}}, + Params: rp, }); err != nil { log.Warnf("closed out channel sendRequest failed: %s", err) } @@ -245,11 +251,17 @@ func (c *wsConn) handleOutChans() { } // forward message + rp, err := json.Marshal([]param{{v: reflect.ValueOf(caseToID[chosen-internal])}, {v: val}}) + if err != nil { + log.Errorw("marshaling params for sendRequest failed", "err", err) + continue + } + if err := c.sendRequest(request{ Jsonrpc: "2.0", ID: nil, // notification Method: chValue, - Params: []param{{v: reflect.ValueOf(caseToID[chosen-internal])}, {v: val}}, + Params: rp, }); err != nil { log.Warnf("sendRequest failed: %s", err) return @@ -291,10 +303,16 @@ func (c *wsConn) handleChanOut(ch reflect.Value, req interface{}) error { func (c *wsConn) handleCtxAsync(actx context.Context, id interface{}) { <-actx.Done() + rp, err := json.Marshal([]param{{v: reflect.ValueOf(id)}}) + if err != nil { + log.Errorw("marshaling params for sendRequest failed", "err", err) + return + } + if err := c.sendRequest(request{ Jsonrpc: "2.0", Method: wsCancel, - Params: []param{{v: reflect.ValueOf(id)}}, + Params: rp, }); err != nil { log.Warnw("failed to send request", "method", wsCancel, "id", id, "error", err.Error()) } @@ -306,8 +324,14 @@ func (c *wsConn) cancelCtx(req frame) { log.Warnf("%s call with ID set, won't respond", wsCancel) } + var params []param + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) + return + } + var id interface{} - if err := json.Unmarshal(req.Params[0].data, &id); err != nil { + if err := json.Unmarshal(params[0].data, &id); err != nil { log.Error("handle me:", err) return } @@ -326,8 +350,14 @@ func (c *wsConn) cancelCtx(req frame) { // // func (c *wsConn) handleChanMessage(frame frame) { + var params []param + if err := json.Unmarshal(frame.Params, ¶ms); err != nil { + log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) + return + } + var chid uint64 - if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil { + if err := json.Unmarshal(params[0].data, &chid); err != nil { log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) return } @@ -338,12 +368,18 @@ func (c *wsConn) handleChanMessage(frame frame) { return } - hnd(frame.Params[1].data, true) + hnd(params[1].data, true) } func (c *wsConn) handleChanClose(frame frame) { + var params []param + if err := json.Unmarshal(frame.Params, ¶ms); err != nil { + log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) + return + } + var chid uint64 - if err := json.Unmarshal(frame.Params[0].data, &chid); err != nil { + if err := json.Unmarshal(params[0].data, &chid); err != nil { log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) return }