Skip to content

Commit

Permalink
rpcserver
Browse files Browse the repository at this point in the history
  • Loading branch information
dvush committed Oct 25, 2024
1 parent 74dc459 commit c8733cf
Show file tree
Hide file tree
Showing 4 changed files with 721 additions and 0 deletions.
272 changes: 272 additions & 0 deletions rpcserver/jsonrpc_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
// Package rpcserver allows exposing functions like:
// func Foo(context, int) (int, error)
// as a JSON RPC methods
//
// This implementation is similar to the one in go-ethereum, but the idea is to eventually replace it as a default
// JSON RPC server implementation in Flasbhots projects and for this we need to reimplement some of the quirks of existing API.
package rpcserver

import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"

"github.com/ethereum/go-ethereum/common"
"github.com/flashbots/go-utils/signature"
)

var (
// this are the only errors that are returned as http errors with http error codes
errMethodNotAllowed = "only POST method is allowded"
errWrongContentType = "header Content-Type must be application/json"
errMarshalResponse = "failed to marshal response"

CodeParseError = -32700
CodeInvalidRequest = -32600
CodeMethodNotFound = -32601
CodeInvalidParams = -32602
CodeInternalError = -32603
CodeCustomError = -32000

DefaultMaxRequestBodySizeBytes = 30 * 1024 * 1024 // 30mb
)

const (
maxOriginIDLength = 255
)

type (
highPriorityKey struct{}
signerKey struct{}
originKey struct{}
)

type jsonRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
ID any `json:"id"`
Method string `json:"method"`
Params []json.RawMessage `json:"params"`
}

type jsonRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID any `json:"id"`
Result *json.RawMessage `json:"result,omitempty"`
Error *jsonRPCError `json:"error,omitempty"`
}

type jsonRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
Data *any `json:"data,omitempty"`
}

type JSONRPCHandler struct {
JSONRPCHandlerOpts
methods map[string]methodHandler
}

type Methods map[string]any

type JSONRPCHandlerOpts struct {
// Logger, can be nil
Log *slog.Logger
// Max size of the request payload
MaxRequestBodySizeBytes int64
// If true payload signature from X-Flashbots-Signature will be verified
// Result can be extracted from the context using GetSigner
VerifyRequestSignatureFromHeader bool
// If true signer from X-Flashbots-Signature will be extracted without verifying signature
// Result can be extracted from the context using GetSigner
ExtractUnverifiedRequestSignatureFromHeader bool
// If true high_prio header value will be extracted (true or false)
// Result can be extracted from the context using GetHighPriority
ExtractPriorityFromHeader bool
// If true extract value from x-flashbots-origin header
// Result can be extracted from the context using GetOrigin
ExtractOriginFromHeader bool
}

// NewJSONRPCHandler creates JSONRPC http.Handler from the map that maps method names to method functions
// each method function must:
// - have context as a first argument
// - return error as a last argument
// - have argument types that can be unmarshalled from JSON
// - have return types that can be marshalled to JSON
func NewJSONRPCHandler(methods Methods, opts JSONRPCHandlerOpts) (*JSONRPCHandler, error) {
if opts.MaxRequestBodySizeBytes == 0 {
opts.MaxRequestBodySizeBytes = int64(DefaultMaxRequestBodySizeBytes)
}

m := make(map[string]methodHandler)
for name, fn := range methods {
method, err := getMethodTypes(fn)
if err != nil {
return nil, err
}
m[name] = method
}
return &JSONRPCHandler{
JSONRPCHandlerOpts: opts,
methods: m,
}, nil
}

func (h *JSONRPCHandler) writeJSONRPCResponse(w http.ResponseWriter, response jsonRPCResponse) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
if h.Log != nil {
h.Log.Error("failed to marshall response", slog.Any("error", err))
}
http.Error(w, errMarshalResponse, http.StatusInternalServerError)
return
}
}

func (h *JSONRPCHandler) writeJSONRPCError(w http.ResponseWriter, id any, code int, msg string) {
res := jsonRPCResponse{
JSONRPC: "2.0",
ID: id,
Result: nil,
Error: &jsonRPCError{
Code: code,
Message: msg,
Data: nil,
},
}
h.writeJSONRPCResponse(w, res)
}

func (h *JSONRPCHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

if r.Method != http.MethodPost {
http.Error(w, errMethodNotAllowed, http.StatusMethodNotAllowed)
return
}

if r.Header.Get("Content-Type") != "application/json" {
http.Error(w, errWrongContentType, http.StatusUnsupportedMediaType)
return
}

r.Body = http.MaxBytesReader(w, r.Body, h.MaxRequestBodySizeBytes)
body, err := io.ReadAll(r.Body)
if err != nil {
msg := fmt.Sprintf("request body is too big, max size: %d", h.MaxRequestBodySizeBytes)
h.writeJSONRPCError(w, nil, CodeInvalidRequest, msg)
return
}

if h.VerifyRequestSignatureFromHeader {
signatureHeader := r.Header.Get("x-flashbots-signature")
signer, err := signature.Verify(signatureHeader, body)
if err != nil {
h.writeJSONRPCError(w, nil, CodeInvalidRequest, err.Error())
return
}
ctx = context.WithValue(ctx, signerKey{}, signer)
}

// read request
var req jsonRPCRequest
if err := json.Unmarshal(body, &req); err != nil {
h.writeJSONRPCError(w, nil, CodeParseError, err.Error())
return
}

if req.JSONRPC != "2.0" {
h.writeJSONRPCError(w, req.ID, CodeParseError, "invalid jsonrpc version")
return
}
if req.ID != nil {
// id must be string or number
switch req.ID.(type) {
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
default:
h.writeJSONRPCError(w, req.ID, CodeParseError, "invalid id type")
}
}

if h.ExtractPriorityFromHeader {
highPriority := r.Header.Get("high_prio") == "true"
ctx = context.WithValue(ctx, highPriorityKey{}, highPriority)
}

if h.ExtractUnverifiedRequestSignatureFromHeader {
signature := r.Header.Get("x-flashbots-signature")
if split := strings.Split(signature, ":"); len(split) > 0 {
signer := common.HexToAddress(split[0])
ctx = context.WithValue(ctx, signerKey{}, signer)
}
}

if h.ExtractOriginFromHeader {
origin := r.Header.Get("x-flashbots-origin")
if origin != "" {
if len(origin) > maxOriginIDLength {
h.writeJSONRPCError(w, req.ID, CodeInvalidRequest, "x-flashbots-origin header is too long")
return
}
ctx = context.WithValue(ctx, originKey{}, origin)
}
}

// get method
method, ok := h.methods[req.Method]
if !ok {
h.writeJSONRPCError(w, req.ID, CodeMethodNotFound, "method not found")
return
}

// call method
result, err := method.call(ctx, req.Params)
if err != nil {
h.writeJSONRPCError(w, req.ID, CodeCustomError, err.Error())
return
}

marshaledResult, err := json.Marshal(result)
if err != nil {
h.writeJSONRPCError(w, req.ID, CodeInternalError, err.Error())
return
}

// write response
rawMessageResult := json.RawMessage(marshaledResult)
res := jsonRPCResponse{
JSONRPC: "2.0",
ID: req.ID,
Result: &rawMessageResult,
Error: nil,
}
h.writeJSONRPCResponse(w, res)
}

func GetHighPriority(ctx context.Context) bool {
value, ok := ctx.Value(highPriorityKey{}).(bool)
if !ok {
return false
}
return value
}

func GetSigner(ctx context.Context) common.Address {
value, ok := ctx.Value(signerKey{}).(common.Address)
if !ok {
return common.Address{}
}
return value
}

func GetOrigin(ctx context.Context) string {
value, ok := ctx.Value(originKey{}).(string)
if !ok {
return ""
}
return value
}
122 changes: 122 additions & 0 deletions rpcserver/jsonrpc_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package rpcserver

import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/flashbots/go-utils/rpcclient"
"github.com/flashbots/go-utils/signature"
"github.com/stretchr/testify/require"
)

func testHandler(opts JSONRPCHandlerOpts) *JSONRPCHandler {
var (
errorArg = -1
errorOut = errors.New("custom error") //nolint:goerr113
)
handlerMethod := func(ctx context.Context, arg1 int) (dummyStruct, error) {
if arg1 == errorArg {
return dummyStruct{}, errorOut
}
return dummyStruct{arg1}, nil
}

handler, err := NewJSONRPCHandler(map[string]interface{}{
"function": handlerMethod,
}, opts)
if err != nil {
panic(err)
}
return handler
}

func TestHandler_ServeHTTP(t *testing.T) {
handler := testHandler(JSONRPCHandlerOpts{})

testCases := map[string]struct {
requestBody string
expectedResponse string
}{
"success": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"result":{"field":1}}`,
},
"error": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[-1]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"custom error"}}`,
},
"invalid json": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1]`,
expectedResponse: `{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"unexpected end of JSON input"}}`,
},
"method not found": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"not_found","params":[1]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"method not found"}}`,
},
"invalid params": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1,2]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"too much arguments"}}`, // TODO: return correct code here
},
"invalid params type": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":["1"]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"json: cannot unmarshal string into Go value of type int"}}`,
},
}

for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
body := bytes.NewReader([]byte(testCase.requestBody))
request, err := http.NewRequest(http.MethodPost, "/", body)
require.NoError(t, err)
request.Header.Add("Content-Type", "application/json")

rr := httptest.NewRecorder()

handler.ServeHTTP(rr, request)
require.Equal(t, http.StatusOK, rr.Code)

require.JSONEq(t, testCase.expectedResponse, rr.Body.String())
})
}
}

func TestJSONRPCServerWithClient(t *testing.T) {
handler := testHandler(JSONRPCHandlerOpts{})
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

client := rpcclient.NewClient(httpServer.URL)

var resp dummyStruct
err := client.CallFor(context.Background(), &resp, "function", 123)
require.NoError(t, err)
require.Equal(t, 123, resp.Field)
}

func TestJSONRPCServerWithSignatureWithClient(t *testing.T) {
handler := testHandler(JSONRPCHandlerOpts{VerifyRequestSignatureFromHeader: true})
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

// first we do request without signature
client := rpcclient.NewClient(httpServer.URL)
resp, err := client.Call(context.Background(), "function", 123)
require.NoError(t, err)
require.Equal(t, "no signature provided", resp.Error.Message)

// call with signature
signer, err := signature.NewRandomSigner()
require.NoError(t, err)
client = rpcclient.NewClientWithOpts(httpServer.URL, &rpcclient.RPCClientOpts{
Signer: signer,
})

var structResp dummyStruct
err = client.CallFor(context.Background(), &structResp, "function", 123)
require.NoError(t, err)
require.Equal(t, 123, structResp.Field)
}
Loading

0 comments on commit c8733cf

Please sign in to comment.