Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpcserver #28

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 272 additions & 0 deletions rpcserver/jsonrpc_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
// Package rpcserver allows exposing functions like:
// func Foo(context, int) (int, error)
// as a JSON RPC methods
//
// This implementation is similar to the one in go-ethereum, but the idea is to eventually replace it as a default
// JSON RPC server implementation in Flasbhots projects and for this we need to reimplement some of the quirks of existing API.
package rpcserver

import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"

"github.com/ethereum/go-ethereum/common"
"github.com/flashbots/go-utils/signature"
)

var (
// this are the only errors that are returned as http errors with http error codes
errMethodNotAllowed = "only POST method is allowded"
errWrongContentType = "header Content-Type must be application/json"
errMarshalResponse = "failed to marshal response"

CodeParseError = -32700
CodeInvalidRequest = -32600
CodeMethodNotFound = -32601
CodeInvalidParams = -32602
CodeInternalError = -32603
CodeCustomError = -32000

DefaultMaxRequestBodySizeBytes = 30 * 1024 * 1024 // 30mb
)

const (
maxOriginIDLength = 255
)

type (
highPriorityKey struct{}
signerKey struct{}
originKey struct{}
)

type jsonRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
ID any `json:"id"`
Method string `json:"method"`
Params []json.RawMessage `json:"params"`
}

type jsonRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID any `json:"id"`
Result *json.RawMessage `json:"result,omitempty"`
Error *jsonRPCError `json:"error,omitempty"`
}

type jsonRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
Data *any `json:"data,omitempty"`
}

type JSONRPCHandler struct {
JSONRPCHandlerOpts
methods map[string]methodHandler
}

type Methods map[string]any

type JSONRPCHandlerOpts struct {
// Logger, can be nil
Log *slog.Logger
// Max size of the request payload
MaxRequestBodySizeBytes int64
// If true payload signature from X-Flashbots-Signature will be verified
// Result can be extracted from the context using GetSigner
VerifyRequestSignatureFromHeader bool
// If true signer from X-Flashbots-Signature will be extracted without verifying signature
// Result can be extracted from the context using GetSigner
ExtractUnverifiedRequestSignatureFromHeader bool
// If true high_prio header value will be extracted (true or false)
// Result can be extracted from the context using GetHighPriority
ExtractPriorityFromHeader bool
// If true extract value from x-flashbots-origin header
// Result can be extracted from the context using GetOrigin
ExtractOriginFromHeader bool
}

// NewJSONRPCHandler creates JSONRPC http.Handler from the map that maps method names to method functions
// each method function must:
// - have context as a first argument
// - return error as a last argument
// - have argument types that can be unmarshalled from JSON
// - have return types that can be marshalled to JSON
func NewJSONRPCHandler(methods Methods, opts JSONRPCHandlerOpts) (*JSONRPCHandler, error) {
if opts.MaxRequestBodySizeBytes == 0 {
opts.MaxRequestBodySizeBytes = int64(DefaultMaxRequestBodySizeBytes)
}

m := make(map[string]methodHandler)
for name, fn := range methods {
method, err := getMethodTypes(fn)
if err != nil {
return nil, err
}
m[name] = method
}
return &JSONRPCHandler{
JSONRPCHandlerOpts: opts,
methods: m,
}, nil
}

func (h *JSONRPCHandler) writeJSONRPCResponse(w http.ResponseWriter, response jsonRPCResponse) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
if h.Log != nil {
h.Log.Error("failed to marshall response", slog.Any("error", err))
}
http.Error(w, errMarshalResponse, http.StatusInternalServerError)
return
}
}

func (h *JSONRPCHandler) writeJSONRPCError(w http.ResponseWriter, id any, code int, msg string) {
res := jsonRPCResponse{
JSONRPC: "2.0",
ID: id,
Result: nil,
Error: &jsonRPCError{
Code: code,
Message: msg,
Data: nil,
},
}
h.writeJSONRPCResponse(w, res)
}

func (h *JSONRPCHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

if r.Method != http.MethodPost {
http.Error(w, errMethodNotAllowed, http.StatusMethodNotAllowed)
return
}

if r.Header.Get("Content-Type") != "application/json" {
http.Error(w, errWrongContentType, http.StatusUnsupportedMediaType)
return
}

r.Body = http.MaxBytesReader(w, r.Body, h.MaxRequestBodySizeBytes)
body, err := io.ReadAll(r.Body)
if err != nil {
msg := fmt.Sprintf("request body is too big, max size: %d", h.MaxRequestBodySizeBytes)
h.writeJSONRPCError(w, nil, CodeInvalidRequest, msg)
return
}

if h.VerifyRequestSignatureFromHeader {
signatureHeader := r.Header.Get("x-flashbots-signature")
signer, err := signature.Verify(signatureHeader, body)
if err != nil {
h.writeJSONRPCError(w, nil, CodeInvalidRequest, err.Error())
return
}
ctx = context.WithValue(ctx, signerKey{}, signer)
}

// read request
var req jsonRPCRequest
if err := json.Unmarshal(body, &req); err != nil {
h.writeJSONRPCError(w, nil, CodeParseError, err.Error())
return
}

if req.JSONRPC != "2.0" {
h.writeJSONRPCError(w, req.ID, CodeParseError, "invalid jsonrpc version")
return
}
if req.ID != nil {
// id must be string or number
switch req.ID.(type) {
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
default:
h.writeJSONRPCError(w, req.ID, CodeParseError, "invalid id type")
}
}

if h.ExtractPriorityFromHeader {
highPriority := r.Header.Get("high_prio") == "true"
ctx = context.WithValue(ctx, highPriorityKey{}, highPriority)
}

if h.ExtractUnverifiedRequestSignatureFromHeader {
signature := r.Header.Get("x-flashbots-signature")
if split := strings.Split(signature, ":"); len(split) > 0 {
signer := common.HexToAddress(split[0])
ctx = context.WithValue(ctx, signerKey{}, signer)
}
}

if h.ExtractOriginFromHeader {
origin := r.Header.Get("x-flashbots-origin")
if origin != "" {
if len(origin) > maxOriginIDLength {
h.writeJSONRPCError(w, req.ID, CodeInvalidRequest, "x-flashbots-origin header is too long")
return
}
ctx = context.WithValue(ctx, originKey{}, origin)
}
}

// get method
method, ok := h.methods[req.Method]
if !ok {
h.writeJSONRPCError(w, req.ID, CodeMethodNotFound, "method not found")
return
}

// call method
result, err := method.call(ctx, req.Params)
if err != nil {
h.writeJSONRPCError(w, req.ID, CodeCustomError, err.Error())
return
}

marshaledResult, err := json.Marshal(result)
if err != nil {
h.writeJSONRPCError(w, req.ID, CodeInternalError, err.Error())
return
}

// write response
rawMessageResult := json.RawMessage(marshaledResult)
res := jsonRPCResponse{
JSONRPC: "2.0",
ID: req.ID,
Result: &rawMessageResult,
Error: nil,
}
h.writeJSONRPCResponse(w, res)
}

func GetHighPriority(ctx context.Context) bool {
value, ok := ctx.Value(highPriorityKey{}).(bool)
if !ok {
return false
}
return value
}

func GetSigner(ctx context.Context) common.Address {
value, ok := ctx.Value(signerKey{}).(common.Address)
if !ok {
return common.Address{}
}
return value
}

func GetOrigin(ctx context.Context) string {
value, ok := ctx.Value(originKey{}).(string)
if !ok {
return ""
}
return value
}
122 changes: 122 additions & 0 deletions rpcserver/jsonrpc_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package rpcserver

import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/flashbots/go-utils/rpcclient"
"github.com/flashbots/go-utils/signature"
"github.com/stretchr/testify/require"
)

func testHandler(opts JSONRPCHandlerOpts) *JSONRPCHandler {
var (
errorArg = -1
errorOut = errors.New("custom error") //nolint:goerr113
)
handlerMethod := func(ctx context.Context, arg1 int) (dummyStruct, error) {
if arg1 == errorArg {
return dummyStruct{}, errorOut
}
return dummyStruct{arg1}, nil
}

handler, err := NewJSONRPCHandler(map[string]interface{}{
"function": handlerMethod,
}, opts)
if err != nil {
panic(err)
}
return handler
}

func TestHandler_ServeHTTP(t *testing.T) {
handler := testHandler(JSONRPCHandlerOpts{})

testCases := map[string]struct {
requestBody string
expectedResponse string
}{
"success": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"result":{"field":1}}`,
},
"error": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[-1]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"custom error"}}`,
},
"invalid json": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1]`,
expectedResponse: `{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"unexpected end of JSON input"}}`,
},
"method not found": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"not_found","params":[1]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"method not found"}}`,
},
"invalid params": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":[1,2]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"too much arguments"}}`, // TODO: return correct code here
},
"invalid params type": {
requestBody: `{"jsonrpc":"2.0","id":1,"method":"function","params":["1"]}`,
expectedResponse: `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"json: cannot unmarshal string into Go value of type int"}}`,
},
}

for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
body := bytes.NewReader([]byte(testCase.requestBody))
request, err := http.NewRequest(http.MethodPost, "/", body)
require.NoError(t, err)
request.Header.Add("Content-Type", "application/json")

rr := httptest.NewRecorder()

handler.ServeHTTP(rr, request)
require.Equal(t, http.StatusOK, rr.Code)

require.JSONEq(t, testCase.expectedResponse, rr.Body.String())
})
}
}

func TestJSONRPCServerWithClient(t *testing.T) {
handler := testHandler(JSONRPCHandlerOpts{})
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

client := rpcclient.NewClient(httpServer.URL)

var resp dummyStruct
err := client.CallFor(context.Background(), &resp, "function", 123)
require.NoError(t, err)
require.Equal(t, 123, resp.Field)
}

func TestJSONRPCServerWithSignatureWithClient(t *testing.T) {
handler := testHandler(JSONRPCHandlerOpts{VerifyRequestSignatureFromHeader: true})
httpServer := httptest.NewServer(handler)
defer httpServer.Close()

// first we do request without signature
client := rpcclient.NewClient(httpServer.URL)
resp, err := client.Call(context.Background(), "function", 123)
require.NoError(t, err)
require.Equal(t, "no signature provided", resp.Error.Message)

// call with signature
signer, err := signature.NewRandomSigner()
require.NoError(t, err)
client = rpcclient.NewClientWithOpts(httpServer.URL, &rpcclient.RPCClientOpts{
Signer: signer,
})

var structResp dummyStruct
err = client.CallFor(context.Background(), &structResp, "function", 123)
require.NoError(t, err)
require.Equal(t, 123, structResp.Field)
}
Loading
Loading