Skip to content

Commit

Permalink
Support sending notification calls
Browse files Browse the repository at this point in the history
  • Loading branch information
magik6k committed Jan 23, 2023
1 parent e4a5f2d commit 44a9f01
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 27 deletions.
62 changes: 39 additions & 23 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,18 @@ func httpClient(ctx context.Context, addr string, namespace string, outs []inter
defer httpResp.Body.Close()

var resp clientResponse
if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
return clientResponse{}, xerrors.Errorf("http status %s unmarshaling response: %w", httpResp.Status, err)
}
if cr.req.ID != nil { // non-notification
if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
return clientResponse{}, xerrors.Errorf("http status %s unmarshaling response: %w", httpResp.Status, err)
}

if resp.ID, err = normalizeID(resp.ID); err != nil {
return clientResponse{}, xerrors.Errorf("failed to response ID: %w", err)
}
if resp.ID, err = normalizeID(resp.ID); err != nil {
return clientResponse{}, xerrors.Errorf("failed to response ID: %w", err)
}

if resp.ID != cr.req.ID {
return clientResponse{}, xerrors.New("request and response id didn't match")
if resp.ID != cr.req.ID {
return clientResponse{}, xerrors.New("request and response id didn't match")
}
}

return resp, nil
Expand Down Expand Up @@ -220,7 +222,7 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
errors: config.errors,
}

requests := c.setup()
requests := c.setupRequestChan()

stop := make(chan struct{})
exiting := make(chan struct{})
Expand Down Expand Up @@ -258,7 +260,7 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
}, nil
}

func (c *client) setup() chan clientRequest {
func (c *client) setupRequestChan() chan clientRequest {
requests := make(chan clientRequest)

c.doRequest = func(ctx context.Context, cr clientRequest) (clientResponse, error) {
Expand Down Expand Up @@ -290,6 +292,7 @@ func (c *client) setup() chan clientRequest {
Method: wsCancel,
Params: []param{{v: reflect.ValueOf(cr.req.ID)}},
},
ready: make(chan clientResponse, 1),
}
select {
case requests <- cancelReq:
Expand Down Expand Up @@ -452,7 +455,8 @@ type rpcFunc struct {
hasCtx int
returnValueIsChannel bool

retry bool
retry bool
notify bool
}

func (fn *rpcFunc) processResponse(resp clientResponse, rval reflect.Value) []reflect.Value {
Expand Down Expand Up @@ -487,7 +491,22 @@ func (fn *rpcFunc) processError(err error) []reflect.Value {
}

func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value) {
var id interface{} = atomic.AddInt64(&fn.client.idCtr, 1)
var id interface{}
if !fn.notify {
id = atomic.AddInt64(&fn.client.idCtr, 1)

// Prepare the ID to send on the wire.
// We track int64 ids as float64 in the inflight map (because that's what
// they'll be decoded to). encoding/json outputs numbers with their minimal
// encoding, avoding the decimal point when possible, i.e. 3 will never get
// converted to 3.0.
var err error
id, err = normalizeID(id)
if err != nil {
return fn.processError(fmt.Errorf("failed to normalize id")) // should probably panic
}
}

params := make([]param, len(args)-fn.hasCtx)
for i, arg := range args[fn.hasCtx:] {
enc, found := fn.client.paramEncoders[arg.Type()]
Expand Down Expand Up @@ -522,16 +541,6 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
retVal, chCtor = fn.client.makeOutChan(ctx, fn.ftyp, fn.valOut)
}

// Prepare the ID to send on the wire.
// We track int64 ids as float64 in the inflight map (because that's what
// they'll be decoded to). encoding/json outputs numbers with their minimal
// encoding, avoding the decimal point when possible, i.e. 3 will never get
// converted to 3.0.
id, err := normalizeID(id)
if err != nil {
return fn.processError(fmt.Errorf("failed to normalize id")) // should probably panic
}

req := request{
Jsonrpc: "2.0",
ID: id,
Expand All @@ -554,6 +563,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
minDelay: methodMinRetryDelay,
}

var err error
var resp clientResponse
// keep retrying if got a forced closed websocket conn and calling method
// has retry annotation
Expand All @@ -563,7 +573,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)
return fn.processError(fmt.Errorf("sendRequest failed: %w", err))
}

if resp.ID != req.ID {
if !fn.notify && resp.ID != req.ID {
return fn.processError(xerrors.New("request and response id didn't match"))
}

Expand Down Expand Up @@ -593,6 +603,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)

const (
ProxyTagRetry = "retry"
ProxyTagNotify = "notify"
ProxyTagRPCMethod = "rpc_method"
)

Expand All @@ -612,9 +623,14 @@ func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) {
ftyp: ftyp,
name: name,
retry: f.Tag.Get(ProxyTagRetry) == "true",
notify: f.Tag.Get(ProxyTagNotify) == "true",
}
fun.valOut, fun.errOut, fun.nout = processFuncOut(ftyp)

if fun.valOut != -1 && fun.notify {
return reflect.Value{}, xerrors.New("notify methods cannot return values")
}

if ftyp.NumIn() > 0 && ftyp.In(0) == contextType {
fun.hasCtx = 1
}
Expand Down
2 changes: 1 addition & 1 deletion options_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func WithReverseClient[RP any](namespace string) ServerOption {
stop := make(chan struct{}) // todo better stop?
cl.exiting = stop

requests := cl.setup()
requests := cl.setupRequestChan()
conn.requests = requests

calls := new(RP)
Expand Down
49 changes: 49 additions & 0 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,55 @@ func TestAliasedCall(t *testing.T) {
closer()
}

type NotifHandler struct {
notified chan struct{}
}

func (h *NotifHandler) Notif() {
close(h.notified)
}

func TestNotif(t *testing.T) {
tc := func(proto string) func(t *testing.T) {
return func(t *testing.T) {
// setup server

nh := &NotifHandler{
notified: make(chan struct{}),
}

rpcServer := NewServer()
rpcServer.Register("Notif", nh)

// httptest stuff
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()

// setup client
var client struct {
Notif func() error `notify:"true"`
}
closer, err := NewMergeClient(context.Background(), proto+"://"+testServ.Listener.Addr().String(), "Notif", []interface{}{
&client,
}, nil)
require.NoError(t, err)

// do the call!

// this will block if it's not sent as a notification
err = client.Notif()
require.NoError(t, err)

<-nh.notified

closer()
}
}

t.Run("ws", tc("ws"))
t.Run("http", tc("http"))
}

// 1. make server call on client **
// 2. make client handle **
// 3. alias on client **
Expand Down
20 changes: 17 additions & 3 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
action = fmt.Sprintf("send-request(%s,%v)", req.req.Method, req.req.ID)

c.writeLk.Lock()
if req.req.ID != nil {
if req.req.ID != nil { // non-notification
if c.incomingErr != nil { // No conn?, immediate fail
req.ready <- clientResponse{
Jsonrpc: "2.0",
Expand All @@ -671,9 +671,23 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
c.inflight[req.req.ID] = req
}
c.writeLk.Unlock()
if err := c.sendRequest(req.req); err != nil {
log.Errorf("sendReqest failed (Handle me): %s", err)
serr := c.sendRequest(req.req)
if serr != nil {
log.Errorf("sendReqest failed (Handle me): %s", serr)
}
if req.req.ID == nil { // notification, return immediately
resp := clientResponse{
Jsonrpc: "2.0",
}
if serr != nil {
resp.Error = &respError{
Code: eTempWSError,
Message: fmt.Sprintf("sendRequest: %s", serr),
}
}
req.ready <- resp
}

case <-c.pongs:
action = "pong"

Expand Down

0 comments on commit 44a9f01

Please sign in to comment.