Skip to content

Commit

Permalink
feat(baseapp): Add Hybrid Protobuf handlers to MsgServiceRouter (#18071)
Browse files Browse the repository at this point in the history
Co-authored-by: unknown unknown <unknown@unknown>
  • Loading branch information
testinginprod and unknown unknown authored Oct 12, 2023
1 parent b8b32cf commit 5edabb5
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 109 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
* (client) [#17513](https://github.com/cosmos/cosmos-sdk/pull/17513) Allow overwritting `client.toml`. Use `client.CreateClientConfig` in place of `client.ReadFromClientConfig` and provide a custom template and a custom config.
* (x/bank) [#17569](https://github.com/cosmos/cosmos-sdk/pull/17569) Introduce a new message type, `MsgBurn `, to burn coins.
* (server) [#17094](https://github.com/cosmos/cosmos-sdk/pull/17094) Add duration `shutdown-grace` for resource clean up (closing database handles) before exit.
* (baseapp) [#18071](https://github.com/cosmos/cosmos-sdk/pull/18071) Add hybrid handlers to `MsgServiceRouter`.

### Improvements

Expand Down
30 changes: 11 additions & 19 deletions baseapp/grpcrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ import (

abci "github.com/cometbft/cometbft/abci/types"
gogogrpc "github.com/cosmos/gogoproto/grpc"
"github.com/cosmos/gogoproto/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/encoding"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"

"github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat"
Expand All @@ -23,9 +21,9 @@ import (
type GRPCQueryRouter struct {
// routes maps query handlers used in ABCIQuery.
routes map[string]GRPCQueryHandler
// handlerByMessageName maps the request name to the handler. It is a hybrid handler which seamlessly
// hybridHandlers maps the request name to the handler. It is a hybrid handler which seamlessly
// handles both gogo and protov2 messages.
handlerByMessageName map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error
hybridHandlers map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error
// binaryCodec is used to encode/decode binary protobuf messages.
binaryCodec codec.BinaryCodec
// cdc is the gRPC codec used by the router to correctly unmarshal messages.
Expand All @@ -45,8 +43,8 @@ var _ gogogrpc.Server = &GRPCQueryRouter{}
// NewGRPCQueryRouter creates a new GRPCQueryRouter
func NewGRPCQueryRouter() *GRPCQueryRouter {
return &GRPCQueryRouter{
routes: map[string]GRPCQueryHandler{},
handlerByMessageName: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
routes: map[string]GRPCQueryHandler{},
hybridHandlers: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
}
}

Expand Down Expand Up @@ -76,7 +74,7 @@ func (qrt *GRPCQueryRouter) RegisterService(sd *grpc.ServiceDesc, handler interf
if err != nil {
panic(err)
}
err = qrt.registerHandlerByMessageName(sd, method, handler)
err = qrt.registerHybridHandler(sd, method, handler)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -131,27 +129,21 @@ func (qrt *GRPCQueryRouter) registerABCIQueryHandler(sd *grpc.ServiceDesc, metho
return nil
}

func (qrt *GRPCQueryRouter) HandlersByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error {
return qrt.handlerByMessageName[name]
func (qrt *GRPCQueryRouter) HybridHandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error {
return qrt.hybridHandlers[name]
}

func (qrt *GRPCQueryRouter) registerHandlerByMessageName(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
func (qrt *GRPCQueryRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
// extract message name from method descriptor
methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName))
desc, err := proto.HybridResolver.FindDescriptorByName(methodFullName)
inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method)
if err != nil {
return fmt.Errorf("cannot find method descriptor %s", methodFullName)
}
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
if !ok {
return fmt.Errorf("invalid method descriptor %s", methodFullName)
return err
}
inputName := methodDesc.Input().FullName()
methodHandler, err := protocompat.MakeHybridHandler(qrt.binaryCodec, sd, method, handler)
if err != nil {
return err
}
qrt.handlerByMessageName[string(inputName)] = append(qrt.handlerByMessageName[string(inputName)], methodHandler)
qrt.hybridHandlers[string(inputName)] = append(qrt.hybridHandlers[string(inputName)], methodHandler)
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion baseapp/grpcrouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestGRPCQueryRouter(t *testing.T) {
func TestGRPCRouterHybridHandlers(t *testing.T) {
assertRouterBehaviour := func(helper *baseapp.QueryServiceTestHelper) {
// test getting the handler by name
handlers := helper.GRPCQueryRouter.HandlersByRequestName("testpb.EchoRequest")
handlers := helper.GRPCQueryRouter.HybridHandlerByRequestName("testpb.EchoRequest")
require.NotNil(t, handlers)
require.Len(t, handlers, 1)
handler := handlers[0]
Expand Down
14 changes: 14 additions & 0 deletions baseapp/internal/protocompat/protocompat.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,17 @@ func isProtov2(md grpc.MethodDesc) (isV2Type bool, err error) {
_, _ = md.Handler(nil, nil, pullRequestType, doNotExecute)
return
}

// RequestFullNameFromMethodDesc returns the fully-qualified name of the request message of the provided service's method.
func RequestFullNameFromMethodDesc(sd *grpc.ServiceDesc, method grpc.MethodDesc) (protoreflect.FullName, error) {
methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName))
desc, err := gogoproto.HybridResolver.FindDescriptorByName(methodFullName)
if err != nil {
return "", fmt.Errorf("cannot find method descriptor %s", methodFullName)
}
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
if !ok {
return "", fmt.Errorf("invalid method descriptor %s", methodFullName)
}
return methodDesc.Input().FullName(), nil
}
224 changes: 136 additions & 88 deletions baseapp/msg_service_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ import (
gogogrpc "github.com/cosmos/gogoproto/grpc"
"github.com/cosmos/gogoproto/proto"
"google.golang.org/grpc"
"google.golang.org/protobuf/runtime/protoiface"

errorsmod "cosmossdk.io/errors"

"github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
Expand All @@ -27,6 +30,7 @@ type MessageRouter interface {
type MsgServiceRouter struct {
interfaceRegistry codectypes.InterfaceRegistry
routes map[string]MsgServiceHandler
hybridHandlers map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error
circuitBreaker CircuitBreaker
}

Expand All @@ -35,7 +39,8 @@ var _ gogogrpc.Server = &MsgServiceRouter{}
// NewMsgServiceRouter creates a new MsgServiceRouter.
func NewMsgServiceRouter() *MsgServiceRouter {
return &MsgServiceRouter{
routes: map[string]MsgServiceHandler{},
routes: map[string]MsgServiceHandler{},
hybridHandlers: map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
}
}

Expand Down Expand Up @@ -67,113 +72,156 @@ func (msr *MsgServiceRouter) HandlerByTypeURL(typeURL string) MsgServiceHandler
func (msr *MsgServiceRouter) RegisterService(sd *grpc.ServiceDesc, handler interface{}) {
// Adds a top-level query handler based on the gRPC service name.
for _, method := range sd.Methods {
fqMethod := fmt.Sprintf("/%s/%s", sd.ServiceName, method.MethodName)
methodHandler := method.Handler

var requestTypeName string

// NOTE: This is how we pull the concrete request type for each handler for registering in the InterfaceRegistry.
// This approach is maybe a bit hacky, but less hacky than reflecting on the handler object itself.
// We use a no-op interceptor to avoid actually calling into the handler itself.
_, _ = methodHandler(nil, context.Background(), func(i interface{}) error {
msg, ok := i.(sdk.Msg)
if !ok {
// We panic here because there is no other alternative and the app cannot be initialized correctly
// this should only happen if there is a problem with code generation in which case the app won't
// work correctly anyway.
panic(fmt.Errorf("unable to register service method %s: %T does not implement sdk.Msg", fqMethod, i))
}
err := msr.registerMsgServiceHandler(sd, method, handler)
if err != nil {
panic(err)
}
err = msr.registerHybridHandler(sd, method, handler)
if err != nil {
panic(err)
}
}
}

func (msr *MsgServiceRouter) HybridHandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error {
return msr.hybridHandlers[msgName]
}

requestTypeName = sdk.MsgTypeURL(msg)
return nil
}, noopInterceptor)

// Check that the service Msg fully-qualified method name has already
// been registered (via RegisterInterfaces). If the user registers a
// service without registering according service Msg type, there might be
// some unexpected behavior down the road. Since we can't return an error
// (`Server.RegisterService` interface restriction) we panic (at startup).
reqType, err := msr.interfaceRegistry.Resolve(requestTypeName)
if err != nil || reqType == nil {
panic(
fmt.Errorf(
"type_url %s has not been registered yet. "+
"Before calling RegisterService, you must register all interfaces by calling the `RegisterInterfaces` "+
"method on module.BasicManager. Each module should call `msgservice.RegisterMsgServiceDesc` inside its "+
"`RegisterInterfaces` method with the `_Msg_serviceDesc` generated by proto-gen",
requestTypeName,
),
)
func (msr *MsgServiceRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method)
if err != nil {
return err
}
cdc := codec.NewProtoCodec(msr.interfaceRegistry)
hybridHandler, err := protocompat.MakeHybridHandler(cdc, sd, method, handler)
if err != nil {
return err
}
// if circuit breaker is not nil, then we decorate the hybrid handler with the circuit breaker
if msr.circuitBreaker == nil {
msr.hybridHandlers[string(inputName)] = hybridHandler
return nil
}
// decorate the hybrid handler with the circuit breaker
circuitBreakerHybridHandler := func(ctx context.Context, req, resp protoiface.MessageV1) error {
messageName := codectypes.MsgTypeURL(req)
allowed, err := msr.circuitBreaker.IsAllowed(ctx, messageName)
if err != nil {
return err
}
if !allowed {
return fmt.Errorf("circuit breaker disallows execution of message %s", messageName)
}
return hybridHandler(ctx, req, resp)
}
msr.hybridHandlers[string(inputName)] = circuitBreakerHybridHandler
return nil
}

// Check that each service is only registered once. If a service is
// registered more than once, then we should error. Since we can't
// return an error (`Server.RegisterService` interface restriction) we
// panic (at startup).
_, found := msr.routes[requestTypeName]
if found {
panic(
fmt.Errorf(
"msg service %s has already been registered. Please make sure to only register each service once. "+
"This usually means that there are conflicting modules registering the same msg service",
fqMethod,
),
)
func (msr *MsgServiceRouter) registerMsgServiceHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
fqMethod := fmt.Sprintf("/%s/%s", sd.ServiceName, method.MethodName)
methodHandler := method.Handler

var requestTypeName string

// NOTE: This is how we pull the concrete request type for each handler for registering in the InterfaceRegistry.
// This approach is maybe a bit hacky, but less hacky than reflecting on the handler object itself.
// We use a no-op interceptor to avoid actually calling into the handler itself.
_, _ = methodHandler(nil, context.Background(), func(i interface{}) error {
msg, ok := i.(sdk.Msg)
if !ok {
// We panic here because there is no other alternative and the app cannot be initialized correctly
// this should only happen if there is a problem with code generation in which case the app won't
// work correctly anyway.
panic(fmt.Errorf("unable to register service method %s: %T does not implement sdk.Msg", fqMethod, i))
}

msr.routes[requestTypeName] = func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx = ctx.WithEventManager(sdk.NewEventManager())
interceptor := func(goCtx context.Context, _ interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
goCtx = context.WithValue(goCtx, sdk.SdkContextKey, ctx)
return handler(goCtx, msg)
}
requestTypeName = sdk.MsgTypeURL(msg)
return nil
}, noopInterceptor)

// Check that the service Msg fully-qualified method name has already
// been registered (via RegisterInterfaces). If the user registers a
// service without registering according service Msg type, there might be
// some unexpected behavior down the road. Since we can't return an error
// (`Server.RegisterService` interface restriction) we panic (at startup).
reqType, err := msr.interfaceRegistry.Resolve(requestTypeName)
if err != nil || reqType == nil {
return fmt.Errorf(
"type_url %s has not been registered yet. "+
"Before calling RegisterService, you must register all interfaces by calling the `RegisterInterfaces` "+
"method on module.BasicManager. Each module should call `msgservice.RegisterMsgServiceDesc` inside its "+
"`RegisterInterfaces` method with the `_Msg_serviceDesc` generated by proto-gen",
requestTypeName,
)
}

if m, ok := msg.(sdk.HasValidateBasic); ok {
if err := m.ValidateBasic(); err != nil {
return nil, err
}
}
// Check that each service is only registered once. If a service is
// registered more than once, then we should error. Since we can't
// return an error (`Server.RegisterService` interface restriction) we
// panic (at startup).
_, found := msr.routes[requestTypeName]
if found {
return fmt.Errorf(
"msg service %s has already been registered. Please make sure to only register each service once. "+
"This usually means that there are conflicting modules registering the same msg service",
fqMethod,
)
}

if msr.circuitBreaker != nil {
msgURL := sdk.MsgTypeURL(msg)
isAllowed, err := msr.circuitBreaker.IsAllowed(ctx, msgURL)
if err != nil {
return nil, err
}
msr.routes[requestTypeName] = func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx = ctx.WithEventManager(sdk.NewEventManager())
interceptor := func(goCtx context.Context, _ interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
goCtx = context.WithValue(goCtx, sdk.SdkContextKey, ctx)
return handler(goCtx, msg)
}

if !isAllowed {
return nil, fmt.Errorf("circuit breaker disables execution of this message: %s", msgURL)
}
if m, ok := msg.(sdk.HasValidateBasic); ok {
if err := m.ValidateBasic(); err != nil {
return nil, err
}
}

// Call the method handler from the service description with the handler object.
// We don't do any decoding here because the decoding was already done.
res, err := methodHandler(handler, ctx, noopDecoder, interceptor)
if msr.circuitBreaker != nil {
msgURL := sdk.MsgTypeURL(msg)
isAllowed, err := msr.circuitBreaker.IsAllowed(ctx, msgURL)
if err != nil {
return nil, err
}

resMsg, ok := res.(proto.Message)
if !ok {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting proto.Message, got %T", resMsg)
if !isAllowed {
return nil, fmt.Errorf("circuit breaker disables execution of this message: %s", msgURL)
}
}

anyResp, err := codectypes.NewAnyWithValue(resMsg)
if err != nil {
return nil, err
}
// Call the method handler from the service description with the handler object.
// We don't do any decoding here because the decoding was already done.
res, err := methodHandler(handler, ctx, noopDecoder, interceptor)
if err != nil {
return nil, err
}

var events []abci.Event
if evtMgr := ctx.EventManager(); evtMgr != nil {
events = evtMgr.ABCIEvents()
}
resMsg, ok := res.(proto.Message)
if !ok {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting proto.Message, got %T", resMsg)
}

return &sdk.Result{
Events: events,
MsgResponses: []*codectypes.Any{anyResp},
}, nil
anyResp, err := codectypes.NewAnyWithValue(resMsg)
if err != nil {
return nil, err
}

var events []abci.Event
if evtMgr := ctx.EventManager(); evtMgr != nil {
events = evtMgr.ABCIEvents()
}

return &sdk.Result{
Events: events,
MsgResponses: []*codectypes.Any{anyResp},
}, nil
}
return nil
}

// SetInterfaceRegistry sets the interface registry for the router.
Expand Down
Loading

0 comments on commit 5edabb5

Please sign in to comment.