diff --git a/rpcserver/jsonrpc_server.go b/rpcserver/jsonrpc_server.go new file mode 100644 index 0000000..be4207c --- /dev/null +++ b/rpcserver/jsonrpc_server.go @@ -0,0 +1,272 @@ +// Package rpcserver allows exposing functions like: +// func Foo(context, int) (int, error) +// as a JSON RPC methods +// +// This implementation is similar to the one in go-ethereum, but the idea is to eventually replace it as a default +// JSON RPC server implementation in Flasbhots projects and for this we need to reimplement some of the quirks of existing API. +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + + "github.com/ethereum/go-ethereum/common" + "github.com/flashbots/go-utils/signature" +) + +var ( + // this are the only errors that are returned as http errors with http error codes + errMethodNotAllowed = "only POST method is allowded" + errWrongContentType = "header Content-Type must be application/json" + errMarshalResponse = "failed to marshal response" + + CodeParseError = -32700 + CodeInvalidRequest = -32600 + CodeMethodNotFound = -32601 + CodeInvalidParams = -32602 + CodeInternalError = -32603 + CodeCustomError = -32000 + + DefaultMaxRequestBodySizeBytes = 30 * 1024 * 1024 // 30mb +) + +const ( + maxOriginIDLength = 255 +) + +type ( + highPriorityKey struct{} + signerKey struct{} + originKey struct{} +) + +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Method string `json:"method"` + Params []json.RawMessage `json:"params"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Result *json.RawMessage `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` +} + +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data *any `json:"data,omitempty"` +} + +type JSONRPCHandler struct { + JSONRPCHandlerOpts + methods map[string]methodHandler +} + +type Methods map[string]any + +type JSONRPCHandlerOpts struct { + // Logger, can be nil + Log *slog.Logger + // Max size of the request payload + MaxRequestBodySizeBytes int64 + // If true payload signature from X-Flashbots-Signature will be verified + // Result can be extracted from the context using GetSigner + VerifyRequestSignatureFromHeader bool + // If true signer from X-Flashbots-Signature will be extracted without verifying signature + // Result can be extracted from the context using GetSigner + ExtractUnverifiedRequestSignatureFromHeader bool + // If true high_prio header value will be extracted (true or false) + // Result can be extracted from the context using GetHighPriority + ExtractPriorityFromHeader bool + // If true extract value from x-flashbots-origin header + // Result can be extracted from the context using GetOrigin + ExtractOriginFromHeader bool +} + +// NewJSONRPCHandler creates JSONRPC http.Handler from the map that maps method names to method functions +// each method function must: +// - have context as a first argument +// - return error as a last argument +// - have argument types that can be unmarshalled from JSON +// - have return types that can be marshalled to JSON +func NewJSONRPCHandler(methods Methods, opts JSONRPCHandlerOpts) (*JSONRPCHandler, error) { + if opts.MaxRequestBodySizeBytes == 0 { + opts.MaxRequestBodySizeBytes = int64(DefaultMaxRequestBodySizeBytes) + } + + m := make(map[string]methodHandler) + for name, fn := range methods { + method, err := getMethodTypes(fn) + if err != nil { + return nil, err + } + m[name] = method + } + return &JSONRPCHandler{ + JSONRPCHandlerOpts: opts, + methods: m, + }, nil +} + +func (h *JSONRPCHandler) writeJSONRPCResponse(w http.ResponseWriter, response jsonRPCResponse) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + if h.Log != nil { + h.Log.Error("failed to marshall response", slog.Any("error", err)) + } + http.Error(w, errMarshalResponse, http.StatusInternalServerError) + return + } +} + +func (h *JSONRPCHandler) writeJSONRPCError(w http.ResponseWriter, id any, code int, msg string) { + res := jsonRPCResponse{ + JSONRPC: "2.0", + ID: id, + Result: nil, + Error: &jsonRPCError{ + Code: code, + Message: msg, + Data: nil, + }, + } + h.writeJSONRPCResponse(w, res) +} + +func (h *JSONRPCHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if r.Method != http.MethodPost { + http.Error(w, errMethodNotAllowed, http.StatusMethodNotAllowed) + return + } + + if r.Header.Get("Content-Type") != "application/json" { + http.Error(w, errWrongContentType, http.StatusUnsupportedMediaType) + return + } + + r.Body = http.MaxBytesReader(w, r.Body, h.MaxRequestBodySizeBytes) + body, err := io.ReadAll(r.Body) + if err != nil { + msg := fmt.Sprintf("request body is too big, max size: %d", h.MaxRequestBodySizeBytes) + h.writeJSONRPCError(w, nil, CodeInvalidRequest, msg) + return + } + + if h.VerifyRequestSignatureFromHeader { + signatureHeader := r.Header.Get("x-flashbots-signature") + signer, err := signature.Verify(signatureHeader, body) + if err != nil { + h.writeJSONRPCError(w, nil, CodeInvalidRequest, err.Error()) + return + } + ctx = context.WithValue(ctx, signerKey{}, signer) + } + + // read request + var req jsonRPCRequest + if err := json.Unmarshal(body, &req); err != nil { + h.writeJSONRPCError(w, nil, CodeParseError, err.Error()) + return + } + + if req.JSONRPC != "2.0" { + h.writeJSONRPCError(w, req.ID, CodeParseError, "invalid jsonrpc version") + return + } + if req.ID != nil { + // id must be string or number + switch req.ID.(type) { + case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + default: + h.writeJSONRPCError(w, req.ID, CodeParseError, "invalid id type") + } + } + + if h.ExtractPriorityFromHeader { + highPriority := r.Header.Get("high_prio") == "true" + ctx = context.WithValue(ctx, highPriorityKey{}, highPriority) + } + + if h.ExtractUnverifiedRequestSignatureFromHeader { + signature := r.Header.Get("x-flashbots-signature") + if split := strings.Split(signature, ":"); len(split) > 0 { + signer := common.HexToAddress(split[0]) + ctx = context.WithValue(ctx, signerKey{}, signer) + } + } + + if h.ExtractOriginFromHeader { + origin := r.Header.Get("x-flashbots-origin") + if origin != "" { + if len(origin) > maxOriginIDLength { + h.writeJSONRPCError(w, req.ID, CodeInvalidRequest, "x-flashbots-origin header is too long") + return + } + ctx = context.WithValue(ctx, originKey{}, origin) + } + } + + // get method + method, ok := h.methods[req.Method] + if !ok { + h.writeJSONRPCError(w, req.ID, CodeMethodNotFound, "method not found") + return + } + + // call method + result, err := method.call(ctx, req.Params) + if err != nil { + h.writeJSONRPCError(w, req.ID, CodeCustomError, err.Error()) + return + } + + marshaledResult, err := json.Marshal(result) + if err != nil { + h.writeJSONRPCError(w, req.ID, CodeInternalError, err.Error()) + return + } + + // write response + rawMessageResult := json.RawMessage(marshaledResult) + res := jsonRPCResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: &rawMessageResult, + Error: nil, + } + h.writeJSONRPCResponse(w, res) +} + +func GetHighPriority(ctx context.Context) bool { + value, ok := ctx.Value(highPriorityKey{}).(bool) + if !ok { + return false + } + return value +} + +func GetSigner(ctx context.Context) common.Address { + value, ok := ctx.Value(signerKey{}).(common.Address) + if !ok { + return common.Address{} + } + return value +} + +func GetOrigin(ctx context.Context) string { + value, ok := ctx.Value(originKey{}).(string) + if !ok { + return "" + } + return value +} diff --git a/rpcserver/jsonrpc_server_test.go b/rpcserver/jsonrpc_server_test.go new file mode 100644 index 0000000..7f96bc9 --- /dev/null +++ b/rpcserver/jsonrpc_server_test.go @@ -0,0 +1,122 @@ +package rpcserver + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/flashbots/go-utils/rpcclient" + "github.com/flashbots/go-utils/signature" + "github.com/stretchr/testify/require" +) + +func testHandler(opts JSONRPCHandlerOpts) *JSONRPCHandler { + var ( + errorArg = -1 + errorOut = errors.New("custom error") //nolint:goerr113 + ) + handlerMethod := func(ctx context.Context, arg1 int) (dummyStruct, error) { + if arg1 == errorArg { + return dummyStruct{}, errorOut + } + return dummyStruct{arg1}, nil + } + + handler, err := NewJSONRPCHandler(map[string]interface{}{ + "function": handlerMethod, + }, opts) + if err != nil { + panic(err) + } + return handler +} + +func TestHandler_ServeHTTP(t *testing.T) { + handler := testHandler(JSONRPCHandlerOpts{}) + + testCases := map[string]struct { + requestBody string + expectedResponse string + }{ + "success": { + requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1]}`, + expectedResponse: `{"jsonrpc":"2.0","id":1,"result":{"field":1}}`, + }, + "error": { + requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[-1]}`, + expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"custom error"}}`, + }, + "invalid json": { + requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1]`, + expectedResponse: `{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"unexpected end of JSON input"}}`, + }, + "method not found": { + requestBody: `{"jsonrpc":"2.0","id":1,"method":"not_found","params":[1]}`, + expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"method not found"}}`, + }, + "invalid params": { + requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1,2]}`, + expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"too much arguments"}}`, // TODO: return correct code here + }, + "invalid params type": { + requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":["1"]}`, + expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"json: cannot unmarshal string into Go value of type int"}}`, + }, + } + + for name, testCase := range testCases { + t.Run(name, func(t *testing.T) { + body := bytes.NewReader([]byte(testCase.requestBody)) + request, err := http.NewRequest(http.MethodPost, "/", body) + require.NoError(t, err) + request.Header.Add("Content-Type", "application/json") + + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, request) + require.Equal(t, http.StatusOK, rr.Code) + + require.JSONEq(t, testCase.expectedResponse, rr.Body.String()) + }) + } +} + +func TestJSONRPCServerWithClient(t *testing.T) { + handler := testHandler(JSONRPCHandlerOpts{}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + client := rpcclient.NewClient(httpServer.URL) + + var resp dummyStruct + err := client.CallFor(context.Background(), &resp, "function", 123) + require.NoError(t, err) + require.Equal(t, 123, resp.Field) +} + +func TestJSONRPCServerWithSignatureWithClient(t *testing.T) { + handler := testHandler(JSONRPCHandlerOpts{VerifyRequestSignatureFromHeader: true}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // first we do request without signature + client := rpcclient.NewClient(httpServer.URL) + resp, err := client.Call(context.Background(), "function", 123) + require.NoError(t, err) + require.Equal(t, "no signature provided", resp.Error.Message) + + // call with signature + signer, err := signature.NewRandomSigner() + require.NoError(t, err) + client = rpcclient.NewClientWithOpts(httpServer.URL, &rpcclient.RPCClientOpts{ + Signer: signer, + }) + + var structResp dummyStruct + err = client.CallFor(context.Background(), &structResp, "function", 123) + require.NoError(t, err) + require.Equal(t, 123, structResp.Field) +} diff --git a/rpcserver/reflect.go b/rpcserver/reflect.go new file mode 100644 index 0000000..fddd653 --- /dev/null +++ b/rpcserver/reflect.go @@ -0,0 +1,104 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "reflect" +) + +var ( + ErrNotFunction = errors.New("not a function") + ErrMustReturnError = errors.New("function must return error as a last return value") + ErrMustHaveContext = errors.New("function must have context.Context as a first argument") + ErrTooManyReturnValues = errors.New("too many return values") + + ErrTooMuchArguments = errors.New("too much arguments") +) + +type methodHandler struct { + in []reflect.Type + out []reflect.Type + fn any +} + +func getMethodTypes(fn interface{}) (methodHandler, error) { + fnType := reflect.TypeOf(fn) + if fnType.Kind() != reflect.Func { + return methodHandler{}, ErrNotFunction + } + numIn := fnType.NumIn() + in := make([]reflect.Type, numIn) + for i := 0; i < numIn; i++ { + in[i] = fnType.In(i) + } + // first input argument must be context.Context + if numIn == 0 || in[0] != reflect.TypeOf((*context.Context)(nil)).Elem() { + return methodHandler{}, ErrMustHaveContext + } + + numOut := fnType.NumOut() + out := make([]reflect.Type, numOut) + for i := 0; i < numOut; i++ { + out[i] = fnType.Out(i) + } + + // function must contain error as a last return value + if numOut == 0 || !out[numOut-1].Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return methodHandler{}, ErrMustReturnError + } + + // function can return only one value + if numOut > 2 { + return methodHandler{}, ErrTooManyReturnValues + } + + return methodHandler{in, out, fn}, nil +} + +func (h methodHandler) call(ctx context.Context, params []json.RawMessage) (any, error) { + args, err := extractArgumentsFromJSONparamsArray(h.in[1:], params) + if err != nil { + return nil, err + } + + // prepend context.Context + args = append([]reflect.Value{reflect.ValueOf(ctx)}, args...) + + // call function + results := reflect.ValueOf(h.fn).Call(args) + + // check error + var outError error + if !results[len(results)-1].IsNil() { + errVal, ok := results[len(results)-1].Interface().(error) + if !ok { + return nil, ErrMustReturnError + } + outError = errVal + } + + if len(results) == 1 { + return nil, outError + } else { + return results[0].Interface(), outError + } +} + +func extractArgumentsFromJSONparamsArray(in []reflect.Type, params []json.RawMessage) ([]reflect.Value, error) { + if len(params) > len(in) { + return nil, ErrTooMuchArguments + } + + args := make([]reflect.Value, len(in)) + for i, argType := range in { + arg := reflect.New(argType) + if i < len(params) { + if err := json.Unmarshal(params[i], arg.Interface()); err != nil { + return nil, err + } + } + args[i] = arg.Elem() + } + return args, nil +} diff --git a/rpcserver/reflect_test.go b/rpcserver/reflect_test.go new file mode 100644 index 0000000..5acdc78 --- /dev/null +++ b/rpcserver/reflect_test.go @@ -0,0 +1,223 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +type ctxKey string + +func rawParams(raw string) []json.RawMessage { + var params []json.RawMessage + err := json.Unmarshal([]byte(raw), ¶ms) + if err != nil { + panic(err) + } + return params +} + +func TestGetMethodTypes(t *testing.T) { + funcWithTypes := func(ctx context.Context, arg1 int, arg2 float32) error { + return nil + } + methodTypes, err := getMethodTypes(funcWithTypes) + require.NoError(t, err) + require.Equal(t, 3, len(methodTypes.in)) + require.Equal(t, 1, len(methodTypes.out)) + + funcWithoutArgs := func(ctx context.Context) error { + return nil + } + methodTypes, err = getMethodTypes(funcWithoutArgs) + require.NoError(t, err) + + funcWithouCtx := func(arg1 int, arg2 float32) error { + return nil + } + methodTypes, err = getMethodTypes(funcWithouCtx) + require.ErrorIs(t, err, ErrMustHaveContext) + + funcWithouError := func(ctx context.Context, arg1 int, arg2 float32) (int, float32) { + return 0, 0 + } + methodTypes, err = getMethodTypes(funcWithouError) + require.ErrorIs(t, err, ErrMustReturnError) + + funcWithTooManyReturnValues := func(ctx context.Context, arg1 int, arg2 float32) (int, float32, error) { + return 0, 0, nil + } + methodTypes, err = getMethodTypes(funcWithTooManyReturnValues) + require.ErrorIs(t, err, ErrTooManyReturnValues) +} + +type dummyStruct struct { + Field int `json:"field"` +} + +func TestExtractArgumentsFromJSON(t *testing.T) { + funcWithTypes := func(context.Context, int, float32, []int, dummyStruct) error { + return nil + } + methodTypes, err := getMethodTypes(funcWithTypes) + require.NoError(t, err) + + jsonArgs := rawParams(`[1, 2.0, [2, 3, 5], {"field": 11}]`) + args, err := extractArgumentsFromJSONparamsArray(methodTypes.in[1:], jsonArgs) + require.NoError(t, err) + require.Equal(t, 4, len(args)) + require.Equal(t, int(1), args[0].Interface()) + require.Equal(t, float32(2.0), args[1].Interface()) + require.Equal(t, []int{2, 3, 5}, args[2].Interface()) + require.Equal(t, dummyStruct{Field: 11}, args[3].Interface()) + + funcWithoutArgs := func(context.Context) error { + return nil + } + methodTypes, err = getMethodTypes(funcWithoutArgs) + require.NoError(t, err) + jsonArgs = rawParams(`[]`) + args, err = extractArgumentsFromJSONparamsArray(methodTypes.in[1:], jsonArgs) + require.NoError(t, err) + require.Equal(t, 0, len(args)) +} + +func TestCall_old(t *testing.T) { + var ( + errorArg = 0 + errorOut = errors.New("function error") //nolint:goerr113 + ) + funcWithTypes := func(ctx context.Context, arg int) (dummyStruct, error) { + value := ctx.Value(ctxKey("key")).(string) //nolint:forcetypeassert + require.Equal(t, "value", value) + + if arg == errorArg { + return dummyStruct{}, errorOut + } + return dummyStruct{arg}, nil + } + methodTypes, err := getMethodTypes(funcWithTypes) + require.NoError(t, err) + + ctx := context.WithValue(context.Background(), ctxKey("key"), "value") + + jsonArgs := rawParams(`[1]`) + result, err := methodTypes.call(ctx, jsonArgs) + require.NoError(t, err) + require.Equal(t, dummyStruct{1}, result) + + jsonArgs = rawParams(fmt.Sprintf(`[%d]`, errorArg)) + result, err = methodTypes.call(ctx, jsonArgs) + require.ErrorIs(t, err, errorOut) + require.Equal(t, dummyStruct{}, result) +} + +func TestCall(t *testing.T) { + // for testing error return + var ( + errorArg = 0 + errorOut = errors.New("function error") //nolint:goerr113 + ) + functionWithTypes := func(ctx context.Context, arg int) (dummyStruct, error) { + // test context + value := ctx.Value(ctxKey("key")).(string) //nolint:forcetypeassert + require.Equal(t, "value", value) + + if arg == errorArg { + return dummyStruct{}, errorOut + } + return dummyStruct{arg}, nil + } + functionNoArgs := func(ctx context.Context) (dummyStruct, error) { + // test context + value := ctx.Value(ctxKey("key")).(string) //nolint:forcetypeassert + require.Equal(t, "value", value) + + return dummyStruct{1}, nil + } + functionNoArgsError := func(ctx context.Context) (dummyStruct, error) { + // test context + value := ctx.Value(ctxKey("key")).(string) //nolint:forcetypeassert + require.Equal(t, "value", value) + + return dummyStruct{}, errorOut + } + functionNoReturn := func(ctx context.Context, arg int) error { + // test context + value := ctx.Value(ctxKey("key")).(string) //nolint:forcetypeassert + require.Equal(t, "value", value) + return nil + } + functonNoReturnError := func(ctx context.Context, arg int) error { + // test context + value := ctx.Value(ctxKey("key")).(string) //nolint:forcetypeassert + require.Equal(t, "value", value) + + return errorOut + } + + testCases := map[string]struct { + function interface{} + args string + expectedValue interface{} + expectedError error + }{ + "functionWithTypes": { + function: functionWithTypes, + args: `[1]`, + expectedValue: dummyStruct{1}, + expectedError: nil, + }, + "functionWithTypesError": { + function: functionWithTypes, + args: fmt.Sprintf(`[%d]`, errorArg), + expectedValue: dummyStruct{}, + expectedError: errorOut, + }, + "functionNoArgs": { + function: functionNoArgs, + args: `[]`, + expectedValue: dummyStruct{1}, + expectedError: nil, + }, + "functionNoArgsError": { + function: functionNoArgsError, + args: `[]`, + expectedValue: dummyStruct{}, + expectedError: errorOut, + }, + "functionNoReturn": { + function: functionNoReturn, + args: `[1]`, + expectedValue: nil, + expectedError: nil, + }, + "functionNoReturnError": { + function: functonNoReturnError, + args: `[1]`, + expectedValue: nil, + expectedError: errorOut, + }, + } + + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + methodTypes, err := getMethodTypes(testCase.function) + require.NoError(t, err) + + ctx := context.WithValue(context.Background(), ctxKey("key"), "value") + + result, err := methodTypes.call(ctx, rawParams(testCase.args)) + if testCase.expectedError == nil { + require.NoError(t, err) + } else { + require.ErrorIs(t, err, testCase.expectedError) + } + require.Equal(t, testCase.expectedValue, result) + }) + } +}