From 5edabb5cb01da453e9a962791ac74b8e37ca82da Mon Sep 17 00:00:00 2001 From: testinginprod <98415576+testinginprod@users.noreply.github.com> Date: Thu, 12 Oct 2023 17:25:28 +0200 Subject: [PATCH] feat(baseapp): Add Hybrid Protobuf handlers to MsgServiceRouter (#18071) Co-authored-by: unknown unknown --- CHANGELOG.md | 1 + baseapp/grpcrouter.go | 30 +-- baseapp/grpcrouter_test.go | 2 +- baseapp/internal/protocompat/protocompat.go | 14 ++ baseapp/msg_service_router.go | 224 ++++++++++++-------- baseapp/msg_service_router_test.go | 35 +++ server/mock/app.go | 50 ++++- 7 files changed, 247 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9938e9e6404a..0a44d372cb75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ * (client) [#17513](https://github.com/cosmos/cosmos-sdk/pull/17513) Allow overwritting `client.toml`. Use `client.CreateClientConfig` in place of `client.ReadFromClientConfig` and provide a custom template and a custom config. * (x/bank) [#17569](https://github.com/cosmos/cosmos-sdk/pull/17569) Introduce a new message type, `MsgBurn `, to burn coins. * (server) [#17094](https://github.com/cosmos/cosmos-sdk/pull/17094) Add duration `shutdown-grace` for resource clean up (closing database handles) before exit. +* (baseapp) [#18071](https://github.com/cosmos/cosmos-sdk/pull/18071) Add hybrid handlers to `MsgServiceRouter`. ### Improvements diff --git a/baseapp/grpcrouter.go b/baseapp/grpcrouter.go index 2c3980ccbbac..a1acc0b9cf2b 100644 --- a/baseapp/grpcrouter.go +++ b/baseapp/grpcrouter.go @@ -6,10 +6,8 @@ import ( abci "github.com/cometbft/cometbft/abci/types" gogogrpc "github.com/cosmos/gogoproto/grpc" - "github.com/cosmos/gogoproto/proto" "google.golang.org/grpc" "google.golang.org/grpc/encoding" - "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/runtime/protoiface" "github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat" @@ -23,9 +21,9 @@ import ( type GRPCQueryRouter struct { // routes maps query handlers used in ABCIQuery. routes map[string]GRPCQueryHandler - // handlerByMessageName maps the request name to the handler. It is a hybrid handler which seamlessly + // hybridHandlers maps the request name to the handler. It is a hybrid handler which seamlessly // handles both gogo and protov2 messages. - handlerByMessageName map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error + hybridHandlers map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error // binaryCodec is used to encode/decode binary protobuf messages. binaryCodec codec.BinaryCodec // cdc is the gRPC codec used by the router to correctly unmarshal messages. @@ -45,8 +43,8 @@ var _ gogogrpc.Server = &GRPCQueryRouter{} // NewGRPCQueryRouter creates a new GRPCQueryRouter func NewGRPCQueryRouter() *GRPCQueryRouter { return &GRPCQueryRouter{ - routes: map[string]GRPCQueryHandler{}, - handlerByMessageName: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{}, + routes: map[string]GRPCQueryHandler{}, + hybridHandlers: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{}, } } @@ -76,7 +74,7 @@ func (qrt *GRPCQueryRouter) RegisterService(sd *grpc.ServiceDesc, handler interf if err != nil { panic(err) } - err = qrt.registerHandlerByMessageName(sd, method, handler) + err = qrt.registerHybridHandler(sd, method, handler) if err != nil { panic(err) } @@ -131,27 +129,21 @@ func (qrt *GRPCQueryRouter) registerABCIQueryHandler(sd *grpc.ServiceDesc, metho return nil } -func (qrt *GRPCQueryRouter) HandlersByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error { - return qrt.handlerByMessageName[name] +func (qrt *GRPCQueryRouter) HybridHandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error { + return qrt.hybridHandlers[name] } -func (qrt *GRPCQueryRouter) registerHandlerByMessageName(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error { +func (qrt *GRPCQueryRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error { // extract message name from method descriptor - methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName)) - desc, err := proto.HybridResolver.FindDescriptorByName(methodFullName) + inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method) if err != nil { - return fmt.Errorf("cannot find method descriptor %s", methodFullName) - } - methodDesc, ok := desc.(protoreflect.MethodDescriptor) - if !ok { - return fmt.Errorf("invalid method descriptor %s", methodFullName) + return err } - inputName := methodDesc.Input().FullName() methodHandler, err := protocompat.MakeHybridHandler(qrt.binaryCodec, sd, method, handler) if err != nil { return err } - qrt.handlerByMessageName[string(inputName)] = append(qrt.handlerByMessageName[string(inputName)], methodHandler) + qrt.hybridHandlers[string(inputName)] = append(qrt.hybridHandlers[string(inputName)], methodHandler) return nil } diff --git a/baseapp/grpcrouter_test.go b/baseapp/grpcrouter_test.go index a7e338daaf1b..af5c24c4ccce 100644 --- a/baseapp/grpcrouter_test.go +++ b/baseapp/grpcrouter_test.go @@ -56,7 +56,7 @@ func TestGRPCQueryRouter(t *testing.T) { func TestGRPCRouterHybridHandlers(t *testing.T) { assertRouterBehaviour := func(helper *baseapp.QueryServiceTestHelper) { // test getting the handler by name - handlers := helper.GRPCQueryRouter.HandlersByRequestName("testpb.EchoRequest") + handlers := helper.GRPCQueryRouter.HybridHandlerByRequestName("testpb.EchoRequest") require.NotNil(t, handlers) require.Len(t, handlers, 1) handler := handlers[0] diff --git a/baseapp/internal/protocompat/protocompat.go b/baseapp/internal/protocompat/protocompat.go index 17ebac8ad97e..bf61320caa9f 100644 --- a/baseapp/internal/protocompat/protocompat.go +++ b/baseapp/internal/protocompat/protocompat.go @@ -206,3 +206,17 @@ func isProtov2(md grpc.MethodDesc) (isV2Type bool, err error) { _, _ = md.Handler(nil, nil, pullRequestType, doNotExecute) return } + +// RequestFullNameFromMethodDesc returns the fully-qualified name of the request message of the provided service's method. +func RequestFullNameFromMethodDesc(sd *grpc.ServiceDesc, method grpc.MethodDesc) (protoreflect.FullName, error) { + methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName)) + desc, err := gogoproto.HybridResolver.FindDescriptorByName(methodFullName) + if err != nil { + return "", fmt.Errorf("cannot find method descriptor %s", methodFullName) + } + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + if !ok { + return "", fmt.Errorf("invalid method descriptor %s", methodFullName) + } + return methodDesc.Input().FullName(), nil +} diff --git a/baseapp/msg_service_router.go b/baseapp/msg_service_router.go index b31a8dc5a139..506112490220 100644 --- a/baseapp/msg_service_router.go +++ b/baseapp/msg_service_router.go @@ -8,9 +8,12 @@ import ( gogogrpc "github.com/cosmos/gogoproto/grpc" "github.com/cosmos/gogoproto/proto" "google.golang.org/grpc" + "google.golang.org/protobuf/runtime/protoiface" errorsmod "cosmossdk.io/errors" + "github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat" + "github.com/cosmos/cosmos-sdk/codec" codectypes "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -27,6 +30,7 @@ type MessageRouter interface { type MsgServiceRouter struct { interfaceRegistry codectypes.InterfaceRegistry routes map[string]MsgServiceHandler + hybridHandlers map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error circuitBreaker CircuitBreaker } @@ -35,7 +39,8 @@ var _ gogogrpc.Server = &MsgServiceRouter{} // NewMsgServiceRouter creates a new MsgServiceRouter. func NewMsgServiceRouter() *MsgServiceRouter { return &MsgServiceRouter{ - routes: map[string]MsgServiceHandler{}, + routes: map[string]MsgServiceHandler{}, + hybridHandlers: map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error{}, } } @@ -67,113 +72,156 @@ func (msr *MsgServiceRouter) HandlerByTypeURL(typeURL string) MsgServiceHandler func (msr *MsgServiceRouter) RegisterService(sd *grpc.ServiceDesc, handler interface{}) { // Adds a top-level query handler based on the gRPC service name. for _, method := range sd.Methods { - fqMethod := fmt.Sprintf("/%s/%s", sd.ServiceName, method.MethodName) - methodHandler := method.Handler - - var requestTypeName string - - // NOTE: This is how we pull the concrete request type for each handler for registering in the InterfaceRegistry. - // This approach is maybe a bit hacky, but less hacky than reflecting on the handler object itself. - // We use a no-op interceptor to avoid actually calling into the handler itself. - _, _ = methodHandler(nil, context.Background(), func(i interface{}) error { - msg, ok := i.(sdk.Msg) - if !ok { - // We panic here because there is no other alternative and the app cannot be initialized correctly - // this should only happen if there is a problem with code generation in which case the app won't - // work correctly anyway. - panic(fmt.Errorf("unable to register service method %s: %T does not implement sdk.Msg", fqMethod, i)) - } + err := msr.registerMsgServiceHandler(sd, method, handler) + if err != nil { + panic(err) + } + err = msr.registerHybridHandler(sd, method, handler) + if err != nil { + panic(err) + } + } +} + +func (msr *MsgServiceRouter) HybridHandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error { + return msr.hybridHandlers[msgName] +} - requestTypeName = sdk.MsgTypeURL(msg) - return nil - }, noopInterceptor) - - // Check that the service Msg fully-qualified method name has already - // been registered (via RegisterInterfaces). If the user registers a - // service without registering according service Msg type, there might be - // some unexpected behavior down the road. Since we can't return an error - // (`Server.RegisterService` interface restriction) we panic (at startup). - reqType, err := msr.interfaceRegistry.Resolve(requestTypeName) - if err != nil || reqType == nil { - panic( - fmt.Errorf( - "type_url %s has not been registered yet. "+ - "Before calling RegisterService, you must register all interfaces by calling the `RegisterInterfaces` "+ - "method on module.BasicManager. Each module should call `msgservice.RegisterMsgServiceDesc` inside its "+ - "`RegisterInterfaces` method with the `_Msg_serviceDesc` generated by proto-gen", - requestTypeName, - ), - ) +func (msr *MsgServiceRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error { + inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method) + if err != nil { + return err + } + cdc := codec.NewProtoCodec(msr.interfaceRegistry) + hybridHandler, err := protocompat.MakeHybridHandler(cdc, sd, method, handler) + if err != nil { + return err + } + // if circuit breaker is not nil, then we decorate the hybrid handler with the circuit breaker + if msr.circuitBreaker == nil { + msr.hybridHandlers[string(inputName)] = hybridHandler + return nil + } + // decorate the hybrid handler with the circuit breaker + circuitBreakerHybridHandler := func(ctx context.Context, req, resp protoiface.MessageV1) error { + messageName := codectypes.MsgTypeURL(req) + allowed, err := msr.circuitBreaker.IsAllowed(ctx, messageName) + if err != nil { + return err } + if !allowed { + return fmt.Errorf("circuit breaker disallows execution of message %s", messageName) + } + return hybridHandler(ctx, req, resp) + } + msr.hybridHandlers[string(inputName)] = circuitBreakerHybridHandler + return nil +} - // Check that each service is only registered once. If a service is - // registered more than once, then we should error. Since we can't - // return an error (`Server.RegisterService` interface restriction) we - // panic (at startup). - _, found := msr.routes[requestTypeName] - if found { - panic( - fmt.Errorf( - "msg service %s has already been registered. Please make sure to only register each service once. "+ - "This usually means that there are conflicting modules registering the same msg service", - fqMethod, - ), - ) +func (msr *MsgServiceRouter) registerMsgServiceHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error { + fqMethod := fmt.Sprintf("/%s/%s", sd.ServiceName, method.MethodName) + methodHandler := method.Handler + + var requestTypeName string + + // NOTE: This is how we pull the concrete request type for each handler for registering in the InterfaceRegistry. + // This approach is maybe a bit hacky, but less hacky than reflecting on the handler object itself. + // We use a no-op interceptor to avoid actually calling into the handler itself. + _, _ = methodHandler(nil, context.Background(), func(i interface{}) error { + msg, ok := i.(sdk.Msg) + if !ok { + // We panic here because there is no other alternative and the app cannot be initialized correctly + // this should only happen if there is a problem with code generation in which case the app won't + // work correctly anyway. + panic(fmt.Errorf("unable to register service method %s: %T does not implement sdk.Msg", fqMethod, i)) } - msr.routes[requestTypeName] = func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { - ctx = ctx.WithEventManager(sdk.NewEventManager()) - interceptor := func(goCtx context.Context, _ interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - goCtx = context.WithValue(goCtx, sdk.SdkContextKey, ctx) - return handler(goCtx, msg) - } + requestTypeName = sdk.MsgTypeURL(msg) + return nil + }, noopInterceptor) + + // Check that the service Msg fully-qualified method name has already + // been registered (via RegisterInterfaces). If the user registers a + // service without registering according service Msg type, there might be + // some unexpected behavior down the road. Since we can't return an error + // (`Server.RegisterService` interface restriction) we panic (at startup). + reqType, err := msr.interfaceRegistry.Resolve(requestTypeName) + if err != nil || reqType == nil { + return fmt.Errorf( + "type_url %s has not been registered yet. "+ + "Before calling RegisterService, you must register all interfaces by calling the `RegisterInterfaces` "+ + "method on module.BasicManager. Each module should call `msgservice.RegisterMsgServiceDesc` inside its "+ + "`RegisterInterfaces` method with the `_Msg_serviceDesc` generated by proto-gen", + requestTypeName, + ) + } - if m, ok := msg.(sdk.HasValidateBasic); ok { - if err := m.ValidateBasic(); err != nil { - return nil, err - } - } + // Check that each service is only registered once. If a service is + // registered more than once, then we should error. Since we can't + // return an error (`Server.RegisterService` interface restriction) we + // panic (at startup). + _, found := msr.routes[requestTypeName] + if found { + return fmt.Errorf( + "msg service %s has already been registered. Please make sure to only register each service once. "+ + "This usually means that there are conflicting modules registering the same msg service", + fqMethod, + ) + } - if msr.circuitBreaker != nil { - msgURL := sdk.MsgTypeURL(msg) - isAllowed, err := msr.circuitBreaker.IsAllowed(ctx, msgURL) - if err != nil { - return nil, err - } + msr.routes[requestTypeName] = func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { + ctx = ctx.WithEventManager(sdk.NewEventManager()) + interceptor := func(goCtx context.Context, _ interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + goCtx = context.WithValue(goCtx, sdk.SdkContextKey, ctx) + return handler(goCtx, msg) + } - if !isAllowed { - return nil, fmt.Errorf("circuit breaker disables execution of this message: %s", msgURL) - } + if m, ok := msg.(sdk.HasValidateBasic); ok { + if err := m.ValidateBasic(); err != nil { + return nil, err } + } - // Call the method handler from the service description with the handler object. - // We don't do any decoding here because the decoding was already done. - res, err := methodHandler(handler, ctx, noopDecoder, interceptor) + if msr.circuitBreaker != nil { + msgURL := sdk.MsgTypeURL(msg) + isAllowed, err := msr.circuitBreaker.IsAllowed(ctx, msgURL) if err != nil { return nil, err } - resMsg, ok := res.(proto.Message) - if !ok { - return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting proto.Message, got %T", resMsg) + if !isAllowed { + return nil, fmt.Errorf("circuit breaker disables execution of this message: %s", msgURL) } + } - anyResp, err := codectypes.NewAnyWithValue(resMsg) - if err != nil { - return nil, err - } + // Call the method handler from the service description with the handler object. + // We don't do any decoding here because the decoding was already done. + res, err := methodHandler(handler, ctx, noopDecoder, interceptor) + if err != nil { + return nil, err + } - var events []abci.Event - if evtMgr := ctx.EventManager(); evtMgr != nil { - events = evtMgr.ABCIEvents() - } + resMsg, ok := res.(proto.Message) + if !ok { + return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting proto.Message, got %T", resMsg) + } - return &sdk.Result{ - Events: events, - MsgResponses: []*codectypes.Any{anyResp}, - }, nil + anyResp, err := codectypes.NewAnyWithValue(resMsg) + if err != nil { + return nil, err } + + var events []abci.Event + if evtMgr := ctx.EventManager(); evtMgr != nil { + events = evtMgr.ABCIEvents() + } + + return &sdk.Result{ + Events: events, + MsgResponses: []*codectypes.Any{anyResp}, + }, nil } + return nil } // SetInterfaceRegistry sets the interface registry for the router. diff --git a/baseapp/msg_service_router_test.go b/baseapp/msg_service_router_test.go index 4d8cff45a4d0..8ddb490e994d 100644 --- a/baseapp/msg_service_router_test.go +++ b/baseapp/msg_service_router_test.go @@ -86,6 +86,41 @@ func TestRegisterMsgServiceTwice(t *testing.T) { }) } +func TestHybridHandlerByMsgName(t *testing.T) { + // Setup baseapp and router. + var ( + appBuilder *runtime.AppBuilder + registry codectypes.InterfaceRegistry + ) + err := depinject.Inject( + depinject.Configs( + makeMinimalConfig(), + depinject.Supply(log.NewTestLogger(t)), + ), &appBuilder, ®istry) + require.NoError(t, err) + db := dbm.NewMemDB() + app := appBuilder.Build(db, nil) + testdata.RegisterInterfaces(registry) + + testdata.RegisterMsgServer( + app.MsgServiceRouter(), + testdata.MsgServerImpl{}, + ) + + handler := app.MsgServiceRouter().HybridHandlerByMsgName("testpb.MsgCreateDog") + + require.NotNil(t, handler) + require.NoError(t, app.Init()) + ctx := app.NewContext(true) + resp := new(testdata.MsgCreateDogResponse) + err = handler(ctx, &testdata.MsgCreateDog{ + Dog: &testdata.Dog{Name: "Spot"}, + Owner: "me", + }, resp) + require.NoError(t, err) + require.Equal(t, resp.Name, "Spot") +} + func TestMsgService(t *testing.T) { priv, _, _ := testdata.KeyTestPubAddr() diff --git a/server/mock/app.go b/server/mock/app.go index cc9d496b048d..f090f8c6df88 100644 --- a/server/mock/app.go +++ b/server/mock/app.go @@ -10,6 +10,10 @@ import ( abci "github.com/cometbft/cometbft/abci/types" db "github.com/cosmos/cosmos-db" "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" "cosmossdk.io/log" storetypes "cosmossdk.io/store/types" @@ -45,7 +49,7 @@ func NewApp(rootDir string, logger log.Logger) (servertypes.ABCI, error) { router.SetInterfaceRegistry(interfaceRegistry) newDesc := &grpc.ServiceDesc{ - ServiceName: "test", + ServiceName: "Test", Methods: []grpc.MethodDesc{ { MethodName: "Test", @@ -170,3 +174,47 @@ func MsgTestHandler(srv interface{}, ctx context.Context, dec func(interface{}) func (m MsgServerImpl) Test(ctx context.Context, msg *KVStoreTx) (*sdk.Result, error) { return KVStoreHandler(m.capKeyMainStore)(sdk.UnwrapSDKContext(ctx), msg) } + +func init() { + err := registerFauxDescriptor() + if err != nil { + panic(err) + } +} + +func registerFauxDescriptor() error { + fauxDescriptor, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ + Name: proto.String("faux_proto/test.proto"), + Dependency: nil, + PublicDependency: nil, + WeakDependency: nil, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("KVStoreTx"), + }, + }, + EnumType: nil, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("Test"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("Test"), + InputType: proto.String("KVStoreTx"), + OutputType: proto.String("KVStoreTx"), + }, + }, + }, + }, + Extension: nil, + Options: nil, + SourceCodeInfo: nil, + Syntax: nil, + Edition: nil, + }, nil) + if err != nil { + return err + } + + return protoregistry.GlobalFiles.RegisterFile(fauxDescriptor) +}