diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go index 132875cc167a..f34b46ddbb64 100644 --- a/pubsub/pstest/fake.go +++ b/pubsub/pstest/fake.go @@ -288,6 +288,9 @@ func (s *Server) ClearMessages() { s.GServer.mu.Lock() s.GServer.msgs = nil s.GServer.msgsByID = make(map[string]*Message) + for _, sub := range s.GServer.subs { + sub.msgs = map[string]*message{} + } s.GServer.mu.Unlock() } diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go index b44988d3f3f7..871dfd0a0625 100644 --- a/pubsub/pstest/fake_test.go +++ b/pubsub/pstest/fake_test.go @@ -472,13 +472,19 @@ func TestPublishOrdered(t *testing.T) { } func TestClearMessages(t *testing.T) { - s := NewServer() - defer s.Close() + pclient, sclient, s, cleanup := newFake(context.TODO(), t) + defer cleanup() + + top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"}) + sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{ + Name: "projects/P/subscriptions/S", + Topic: top.Name, + AckDeadlineSeconds: 10, + }) for i := 0; i < 3; i++ { - s.Publish("projects/p/topics/t", []byte("hello"), nil) + s.Publish(top.Name, []byte("hello"), nil) } - s.Wait() msgs := s.Messages() if got, want := len(msgs), 3; got != want { t.Errorf("got %d messages, want %d", got, want) @@ -488,6 +494,14 @@ func TestClearMessages(t *testing.T) { if got, want := len(msgs), 0; got != want { t.Errorf("got %d messages, want %d", got, want) } + + res, err := sclient.Pull(context.Background(), &pb.PullRequest{Subscription: sub.Name}) + if err != nil { + t.Fatal(err) + } + if len(res.ReceivedMessages) != 0 { + t.Errorf("got %d messages, want zero", len(res.ReceivedMessages)) + } } // Note: this sets the fake's "now" time, so it is sensitive to concurrent changes to "now".