diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go index 6d993bd3aaee..132875cc167a 100644 --- a/pubsub/pstest/fake.go +++ b/pubsub/pstest/fake.go @@ -1037,10 +1037,10 @@ func (s *subscription) pull(max int) []*pb.ReceivedMessage { s.publishToDeadLetter(m) continue } + (*m.deliveries)++ if s.proto.DeadLetterPolicy != nil { m.proto.DeliveryAttempt = int32(*m.deliveries) } - (*m.deliveries)++ m.ackDeadline = now.Add(s.ackTimeout) msgs = append(msgs, m.proto) if len(msgs) >= max { @@ -1122,6 +1122,13 @@ func (s *subscription) deliver() { // // Must be called with the lock held. func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) { + // Optimistically increment DeliveryAttempt assuming we'll be able to deliver the message. This is + // safe since the lock is held for the duration of this function, and the channel receiver does not + // modify the message. + if s.proto.DeadLetterPolicy != nil { + m.proto.DeliveryAttempt = int32(*m.deliveries) + 1 + } + for i := 0; i < len(s.streams); i++ { idx := (i + start) % len(s.streams) @@ -1139,6 +1146,10 @@ func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) ( default: } } + // Restore the correct value of DeliveryAttempt if we were not able to deliver the message. + if s.proto.DeadLetterPolicy != nil { + m.proto.DeliveryAttempt = int32(*m.deliveries) + } return 0, false } diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go index 8757254d696c..b44988d3f3f7 100644 --- a/pubsub/pstest/fake_test.go +++ b/pubsub/pstest/fake_test.go @@ -346,8 +346,8 @@ func TestSubscriptionDeadLetter(t *testing.T) { } for _, m := range pull.ReceivedMessages { - if int32(i) != m.DeliveryAttempt { - t.Fatalf("message delivery attempt not the expected one. expected: %d, actual: %d", i, m.DeliveryAttempt) + if int32(i+1) != m.DeliveryAttempt { + t.Fatalf("message delivery attempt not the expected one. expected: %d, actual: %d", i+1, m.DeliveryAttempt) } _, err := server.GServer.ModifyAckDeadline(ctx, &pb.ModifyAckDeadlineRequest{ Subscription: sub.Name, @@ -551,11 +551,19 @@ func TestStreamingPull(t *testing.T) { pclient, sclient, srv, cleanup := newFake(context.TODO(), t) defer cleanup() + deadLetterTopic := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{ + Name: "projects/P/topics/deadLetter", + }) + top := mustCreateTopic(context.TODO(), t, pclient, &pb.Topic{Name: "projects/P/topics/T"}) sub := mustCreateSubscription(context.TODO(), t, sclient, &pb.Subscription{ Name: "projects/P/subscriptions/S", Topic: top.Name, AckDeadlineSeconds: 10, + DeadLetterPolicy: &pb.DeadLetterPolicy{ + DeadLetterTopic: deadLetterTopic.Name, + MaxDeliveryAttempts: 3, + }, }) want := publish(t, srv, pclient, top, []*pb.PubsubMessage{ @@ -563,7 +571,13 @@ func TestStreamingPull(t *testing.T) { {Data: []byte("d2")}, {Data: []byte("d3")}, }) - got := pubsubMessages(streamingPullN(context.TODO(), t, len(want), sclient, sub)) + received := streamingPullN(context.TODO(), t, len(want), sclient, sub) + for _, m := range received { + if m.DeliveryAttempt != 1 { + t.Errorf("got DeliveryAttempt==%d, want 1", m.DeliveryAttempt) + } + } + got := pubsubMessages(received) if diff := testutil.Diff(got, want); diff != "" { t.Error(diff) }