diff --git a/channel.go b/channel.go index 82be419..a4c061c 100644 --- a/channel.go +++ b/channel.go @@ -70,4 +70,5 @@ func (c *Channel) removeClient(client *Client) { delete(c.clients, client) c.mu.Unlock() close(client.send) + client.removed <- struct{}{} } diff --git a/client.go b/client.go index af026e1..5884c55 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,8 @@ package sse type Client struct { lastEventID, channel string - send chan *Message + send chan *Message + removed chan struct{} } func newClient(lastEventID, channel string) *Client { @@ -12,6 +13,7 @@ func newClient(lastEventID, channel string) *Client { lastEventID, channel, make(chan *Message), + make(chan struct{}, 1), } } diff --git a/sse.go b/sse.go index 99c7ee8..4550ee9 100644 --- a/sse.go +++ b/sse.go @@ -83,7 +83,10 @@ func (s *Server) ServeHTTP(response http.ResponseWriter, request *http.Request) go func() { <-closeNotify - s.removeClient <- c + select { + case s.removeClient <- c: + case <-c.removed: + } }() response.WriteHeader(http.StatusOK) @@ -253,7 +256,6 @@ func (s *Server) dispatch() { case <-s.shutdown: s.close() close(s.addClient) - close(s.removeClient) close(s.closeChannel) close(s.shutdown) diff --git a/sse_test.go b/sse_test.go index 1241d16..6d674af 100644 --- a/sse_test.go +++ b/sse_test.go @@ -1,8 +1,12 @@ package sse import ( + "context" "fmt" + "go.uber.org/goleak" + "io/ioutil" "log" + "net/http" "os" "sync" "testing" @@ -87,3 +91,39 @@ func TestServer(t *testing.T) { t.Errorf("Expected %d messages but got %d", channelCount*clientCount, messageCount) } } + +func TestShutdown(t *testing.T) { + defer goleak.VerifyNone(t) + + srv := NewServer(nil) + + http.Handle("/events/", srv) + + httpServer := &http.Server{Addr: ":3000", Handler: nil} + + go func() { _ = httpServer.ListenAndServe() }() + + stop := make(chan struct{}) + + go func() { + r, err := http.Get("http://localhost:3000/events/chan") + if err != nil { + log.Fatalln(err) + return + } + // Stop while client is reading the response + stop <- struct{}{} + _, _ = ioutil.ReadAll(r.Body) + }() + + <-stop + + srv.Shutdown() + + ctx, done := context.WithTimeout(context.Background(), 600*time.Millisecond) + err := httpServer.Shutdown(ctx) + if err != nil { + log.Println(err) + } + done() +}