From c40c9ba315aaaa58eb4dc74ba9d8bbc5058f1dfa Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Tue, 10 Oct 2023 11:42:11 -0700 Subject: [PATCH] server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703) (#6705) --- benchmark/primitives/primitives_test.go | 39 ++++++++++ internal/transport/http2_server.go | 11 +-- internal/transport/transport_test.go | 35 +++++---- server.go | 71 ++++++++++++------ server_ext_test.go | 99 +++++++++++++++++++++++++ 5 files changed, 210 insertions(+), 45 deletions(-) create mode 100644 server_ext_test.go diff --git a/benchmark/primitives/primitives_test.go b/benchmark/primitives/primitives_test.go index dbbb313e8dcc..5d7c81090c15 100644 --- a/benchmark/primitives/primitives_test.go +++ b/benchmark/primitives/primitives_test.go @@ -425,3 +425,42 @@ func BenchmarkRLockUnlock(b *testing.B) { } }) } + +type ifNop interface { + nop() +} + +type alwaysNop struct{} + +func (alwaysNop) nop() {} + +type concreteNop struct { + isNop atomic.Bool + i int +} + +func (c *concreteNop) nop() { + if c.isNop.Load() { + return + } + c.i++ +} + +func BenchmarkInterfaceNop(b *testing.B) { + n := ifNop(alwaysNop{}) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + n.nop() + } + }) +} + +func BenchmarkConcreteNop(b *testing.B) { + n := &concreteNop{} + n.isNop.Store(true) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + n.nop() + } + }) +} diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 8d3a353c1d58..c06db679d89c 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, ID: http2.SettingMaxFrameSize, Val: http2MaxFrameLen, }} - // TODO(zhaoq): Have a better way to signal "no limit" because 0 is - // permitted in the HTTP2 spec. 
- maxStreams := config.MaxStreams - if maxStreams == 0 { - maxStreams = math.MaxUint32 - } else { + if config.MaxStreams != math.MaxUint32 { isettings = append(isettings, http2.Setting{ ID: http2.SettingMaxConcurrentStreams, - Val: maxStreams, + Val: config.MaxStreams, }) } dynamicWindow := true @@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, framer: framer, readerDone: make(chan struct{}), writerDone: make(chan struct{}), - maxStreams: maxStreams, + maxStreams: config.MaxStreams, inTapHandle: config.InTapHandle, fc: &trInFlow{limit: uint32(icwz)}, state: reachable, diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 258ef7411cf0..bb27a6b63e98 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -337,6 +337,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT return } rawConn := conn + if serverConfig.MaxStreams == 0 { + serverConfig.MaxStreams = math.MaxUint32 + } transport, err := NewServerTransport(conn, serverConfig) if err != nil { return @@ -443,8 +446,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server return server } -func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) { - return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{}) +func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) { + return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{}) } func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) { @@ -539,7 +542,7 @@ func (s) TestInflightStreamClosing(t *testing.T) { // Tests that when streamID > MaxStreamId, the current client transport drains. func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() defer server.stop() callHdr := &CallHdr{ @@ -584,7 +587,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) { } func (s) TestClientSendAndReceive(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -624,7 +627,7 @@ func (s) TestClientSendAndReceive(t *testing.T) { } func (s) TestClientErrorNotify(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() go server.stop() // ct.reader should detect the error and activate ct.Error(). @@ -658,7 +661,7 @@ func performOneRPC(ct ClientTransport) { } func (s) TestClientMix(t *testing.T) { - s, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + s, ct, cancel := setUp(t, 0, normal) defer cancel() time.AfterFunc(time.Second, s.stop) go func(ct ClientTransport) { @@ -672,7 +675,7 @@ func (s) TestClientMix(t *testing.T) { } func (s) TestLargeMessage(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -807,7 +810,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) { // proceed until they complete naturally, while not allowing creation of new // streams during this window. 
func (s) TestGracefulClose(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong) + server, ct, cancel := setUp(t, 0, pingpong) defer cancel() defer func() { // Stop the server's listener to make the server's goroutines terminate @@ -873,7 +876,7 @@ func (s) TestGracefulClose(t *testing.T) { } func (s) TestLargeMessageSuspension(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended) + server, ct, cancel := setUp(t, 0, suspended) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -981,7 +984,7 @@ func (s) TestMaxStreams(t *testing.T) { } func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended) + server, ct, cancel := setUp(t, 0, suspended) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -1453,7 +1456,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) { var encodingTestStatus = status.New(codes.Internal, "\n") func (s) TestEncodingRequiredStatus(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) + server, ct, cancel := setUp(t, 0, encodingRequiredStatus) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -1481,7 +1484,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) { } func (s) TestInvalidHeaderField(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) + server, ct, cancel := setUp(t, 0, invalidHeaderField) defer cancel() callHdr := &CallHdr{ Host: "localhost", @@ -1503,7 +1506,7 @@ func (s) TestInvalidHeaderField(t *testing.T) { } func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) { - server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField) + server, ct, cancel := setUp(t, 0, invalidHeaderField) defer cancel() defer server.stop() defer ct.Close(fmt.Errorf("closed manually by test")) @@ -2171,7 +2174,7 @@ func (s) TestPingPong1MB(t *testing.T) { // This is a stress-test of flow control logic. func runPingPongTest(t *testing.T, msgSize int) { - server, client, cancel := setUp(t, 0, 0, pingpong) + server, client, cancel := setUp(t, 0, pingpong) defer cancel() defer server.stop() defer client.Close(fmt.Errorf("closed manually by test")) @@ -2253,7 +2256,7 @@ func (s) TestHeaderTblSize(t *testing.T) { } }() - server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + server, ct, cancel := setUp(t, 0, normal) defer cancel() defer ct.Close(fmt.Errorf("closed manually by test")) defer server.stop() @@ -2612,7 +2615,7 @@ func TestConnectionError_Unwrap(t *testing.T) { func (s) TestPeerSetInServerContext(t *testing.T) { // create client and server transports. - server, client, cancel := setUp(t, 0, math.MaxUint32, normal) + server, client, cancel := setUp(t, 0, normal) defer cancel() defer server.stop() defer client.Close(fmt.Errorf("closed manually by test")) diff --git a/server.go b/server.go index 244123c6c5a8..eeae92fbe020 100644 --- a/server.go +++ b/server.go @@ -115,12 +115,6 @@ type serviceInfo struct { mdata any } -type serverWorkerData struct { - st transport.ServerTransport - wg *sync.WaitGroup - stream *transport.Stream -} - // Server is a gRPC server to serve RPC requests. 
 type Server struct {
 	opts serverOptions
 
@@ -145,7 +139,7 @@ type Server struct {
 	channelzID *channelz.Identifier
 	czData     *channelzData
 
-	serverWorkerChannel chan *serverWorkerData
+	serverWorkerChannel chan func()
 }
 
 type serverOptions struct {
@@ -179,6 +173,7 @@ type serverOptions struct {
 }
 
 var defaultServerOptions = serverOptions{
+	maxConcurrentStreams:  math.MaxUint32,
 	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
 	maxSendMessageSize:    defaultServerMaxSendMessageSize,
 	connectionTimeout:     120 * time.Second,
@@ -404,6 +399,9 @@ func MaxSendMsgSize(m int) ServerOption {
 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
 // of concurrent streams to each ServerTransport.
 func MaxConcurrentStreams(n uint32) ServerOption {
+	if n == 0 {
+		n = math.MaxUint32
+	}
 	return newFuncServerOption(func(o *serverOptions) {
 		o.maxConcurrentStreams = n
 	})
@@ -605,24 +603,19 @@ const serverWorkerResetThreshold = 1 << 16
 // [1] https://github.com/golang/go/issues/18138
 func (s *Server) serverWorker() {
 	for completed := 0; completed < serverWorkerResetThreshold; completed++ {
-		data, ok := <-s.serverWorkerChannel
+		f, ok := <-s.serverWorkerChannel
 		if !ok {
 			return
 		}
-		s.handleSingleStream(data)
+		f()
 	}
 	go s.serverWorker()
 }
 
-func (s *Server) handleSingleStream(data *serverWorkerData) {
-	defer data.wg.Done()
-	s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream))
-}
-
 // initServerWorkers creates worker goroutines and a channel to process incoming
 // connections to reduce the time spent overall on runtime.morestack.
 func (s *Server) initServerWorkers() {
-	s.serverWorkerChannel = make(chan *serverWorkerData)
+	s.serverWorkerChannel = make(chan func())
 	for i := uint32(0); i < s.opts.numServerWorkers; i++ {
 		go s.serverWorker()
 	}
@@ -982,21 +975,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
 	defer st.Close(errors.New("finished serving streams for the server transport"))
 	var wg sync.WaitGroup
 
+	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
 	st.HandleStreams(func(stream *transport.Stream) {
 		wg.Add(1)
+
+		streamQuota.acquire()
+		f := func() {
+			defer streamQuota.release()
+			defer wg.Done()
+			s.handleStream(st, stream, s.traceInfo(st, stream))
+		}
+
 		if s.opts.numServerWorkers > 0 {
-			data := &serverWorkerData{st: st, wg: &wg, stream: stream}
 			select {
-			case s.serverWorkerChannel <- data:
+			case s.serverWorkerChannel <- f:
 				return
 			default:
 				// If all stream workers are busy, fallback to the default code path.
 			}
 		}
-		go func() {
-			defer wg.Done()
-			s.handleStream(st, stream, s.traceInfo(st, stream))
-		}()
+		go f()
 	}, func(ctx context.Context, method string) context.Context {
 		if !EnableTracing {
 			return ctx
@@ -2091,3 +2089,34 @@ func validateSendCompressor(name, clientCompressors string) error {
 	}
 	return fmt.Errorf("client does not support compressor %q", name)
 }
+
+// atomicSemaphore implements a blocking, counting semaphore. acquire should be
+// called synchronously; release may be called asynchronously.
+type atomicSemaphore struct {
+	n    atomic.Int64
+	wait chan struct{}
+}
+
+func (q *atomicSemaphore) acquire() {
+	if q.n.Add(-1) < 0 {
+		// We ran out of quota. Block until a release happens.
+		<-q.wait
+	}
+}
+
+func (q *atomicSemaphore) release() {
+	// N.B. the "<= 0" check below should allow for this to work with multiple
+	// concurrent calls to acquire, but also note that with synchronous calls to
+	// acquire, as our system does, n will never be less than -1. There are
+	// fairness issues (queuing) to consider if this were to be generalized.
+	if q.n.Add(1) <= 0 {
+		// An acquire was waiting on us. Unblock it.
+		q.wait <- struct{}{}
+	}
+}
+
+func newHandlerQuota(n uint32) *atomicSemaphore {
+	a := &atomicSemaphore{wait: make(chan struct{}, 1)}
+	a.n.Store(int64(n))
+	return a
+}
diff --git a/server_ext_test.go b/server_ext_test.go
new file mode 100644
index 000000000000..df79755f3255
--- /dev/null
+++ b/server_ext_test.go
@@ -0,0 +1,99 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc_test
+
+import (
+	"context"
+	"io"
+	"testing"
+	"time"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/internal/grpcsync"
+	"google.golang.org/grpc/internal/stubserver"
+
+	testgrpc "google.golang.org/grpc/interop/grpc_testing"
+)
+
+// TestServer_MaxHandlers ensures that no more than MaxConcurrentStreams server
+// handlers are active at one time.
+func (s) TestServer_MaxHandlers(t *testing.T) {
+	started := make(chan struct{})
+	blockCalls := grpcsync.NewEvent()
+
+	// This stub server does not properly respect the stream context, so it will
+	// not exit when the context is canceled.
+	ss := stubserver.StubServer{
+		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
+			started <- struct{}{}
+			<-blockCalls.Done()
+			return nil
+		},
+	}
+	if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}); err != nil {
+		t.Fatal("Error starting server:", err)
+	}
+	defer ss.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	// Start one RPC to the server.
+	ctx1, cancel1 := context.WithCancel(ctx)
+	_, err := ss.Client.FullDuplexCall(ctx1)
+	if err != nil {
+		t.Fatal("Error starting call:", err)
+	}
+
+	// Wait for the handler to be invoked.
+	select {
+	case <-started:
+	case <-ctx.Done():
+		t.Fatalf("Timed out waiting for RPC to start on server.")
+	}
+
+	// Cancel it on the client. The server handler will still be running.
+	cancel1()
+
+	ctx2, cancel2 := context.WithCancel(ctx)
+	defer cancel2()
+	s, err := ss.Client.FullDuplexCall(ctx2)
+	if err != nil {
+		t.Fatal("Error starting call:", err)
+	}
+
+	// After 100ms, allow the first call to unblock. That should allow the
+	// second RPC to run and finish.
+	select {
+	case <-started:
+		blockCalls.Fire()
+		t.Fatalf("RPC started unexpectedly.")
+	case <-time.After(100 * time.Millisecond):
+		blockCalls.Fire()
+	}
+
+	select {
+	case <-started:
+	case <-ctx.Done():
+		t.Fatalf("Timed out waiting for second RPC to start on server.")
+	}
+	if _, err := s.Recv(); err != io.EOF {
+		t.Fatal("Received unexpected RPC error:", err)
+	}
+}
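
Editor's note (not part of the patch): the standalone sketch below is a minimal illustration, under assumptions, of the dispatch pattern this patch introduces in serveStreams: acquire a slot from a blocking counting semaphore before starting a handler goroutine, and release the slot when the handler returns. It mirrors the behavior of atomicSemaphore/newHandlerQuota, but the names used here (quota, newQuota, maxHandlers) are hypothetical and do not exist in grpc-go; treat it as a sketch, not as gRPC code. On the user-facing side, the limit itself is set per server, e.g. grpc.NewServer(grpc.MaxConcurrentStreams(100)); with this patch a value of 0 is treated as math.MaxUint32 rather than producing a zero-stream limit.

// Standalone sketch of the "semaphore-gated handler dispatch" pattern.
// Assumed/illustrative names: quota, newQuota, maxHandlers.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// quota mirrors the patch's atomicSemaphore: acquire is called synchronously
// from a single dispatching goroutine; release may be called from any handler.
type quota struct {
	n    atomic.Int64
	wait chan struct{}
}

func newQuota(n int64) *quota {
	q := &quota{wait: make(chan struct{}, 1)}
	q.n.Store(n)
	return q
}

func (q *quota) acquire() {
	if q.n.Add(-1) < 0 {
		// Out of slots: block until a handler releases one.
		<-q.wait
	}
}

func (q *quota) release() {
	if q.n.Add(1) <= 0 {
		// The dispatcher is blocked in acquire; unblock it.
		q.wait <- struct{}{}
	}
}

func main() {
	const maxHandlers = 2 // analogous to MaxConcurrentStreams
	q := newQuota(maxHandlers)

	var running atomic.Int64
	var wg sync.WaitGroup

	// Dispatch 10 "streams"; at most maxHandlers handlers run at any moment.
	for i := 0; i < 10; i++ {
		q.acquire() // blocks once maxHandlers handlers are active
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			defer q.release()
			cur := running.Add(1)
			defer running.Add(-1)
			fmt.Printf("handler %d running (%d active)\n", id, cur)
			time.Sleep(10 * time.Millisecond) // stand-in for real handler work
		}(i)
	}
	wg.Wait()
}

A single-slot wait channel is sufficient in this sketch for the same reason given in the patch's release comment: acquire is only ever called from the one goroutine that dispatches handlers, so at most one acquirer can be blocked at a time and the counter never drops below -1.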