Skip to content

Commit

Permalink
[IMPROVED] Fetch and FetchBatch for draining and closed subscriptions (
Browse files Browse the repository at this point in the history
…#1582)

Signed-off-by: Piotr Piotrowski <[email protected]>
  • Loading branch information
piotrpio authored Mar 18, 2024
1 parent 85e6223 commit 6dfefd9
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 6 deletions.
6 changes: 3 additions & 3 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type (
// It is returned by [Consumer.Messages] method.
MessagesContext interface {
// Next retrieves next message on a stream. It will block until the next
// message is available. If the context is cancelled, Next will return
// message is available. If the context is canceled, Next will return
// ErrMsgIteratorClosed error.
Next() (Msg, error)

Expand Down Expand Up @@ -531,7 +531,7 @@ var (
)

// Next retrieves next message on a stream. It will block until the next
// message is available. If the context is cancelled, Next will return
// message is available. If the context is canceled, Next will return
// ErrMsgIteratorClosed error.
func (s *pullSubscription) Next() (Msg, error) {
s.Lock()
Expand Down Expand Up @@ -1081,7 +1081,7 @@ type backoffOpts struct {
// for all subsequent retries after reaching the limit
customBackoff []time.Duration
// cancel channel
// if set, retry will be cancelled when this channel is closed
// if set, retry will be canceled when this channel is closed
cancel <-chan struct{}
}

Expand Down
16 changes: 14 additions & 2 deletions js.go
Original file line number Diff line number Diff line change
Expand Up @@ -2861,7 +2861,13 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
}
var hbTimer *time.Timer
var hbErr error
if err == nil && len(msgs) < batch {
sub.mu.Lock()
subClosed := sub.closed || sub.draining
sub.mu.Unlock()
if subClosed {
err = errors.Join(ErrBadSubscription, ErrSubscriptionClosed)
}
if err == nil && len(msgs) < batch && !subClosed {
// For batch real size of 1, it does not make sense to set no_wait in
// the request.
noWait := batch-len(msgs) > 1
Expand Down Expand Up @@ -3129,8 +3135,14 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
result.msgs <- msg
}
}
if len(result.msgs) == batch || result.err != nil {
sub.mu.Lock()
subClosed := sub.closed || sub.draining
sub.mu.Unlock()
if len(result.msgs) == batch || result.err != nil || subClosed {
close(result.msgs)
if subClosed && len(result.msgs) == 0 {
return nil, errors.Join(ErrBadSubscription, ErrSubscriptionClosed)
}
result.done <- struct{}{}
return result, nil
}
Expand Down
3 changes: 3 additions & 0 deletions jserrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ var (
// ErrNoHeartbeat is returned when no heartbeat is received from server when sending requests with pull consumer.
ErrNoHeartbeat JetStreamError = &jsError{message: "no heartbeat received"}

// ErrSubscriptionClosed is returned when attempting to send pull request to a closed subscription
ErrSubscriptionClosed JetStreamError = &jsError{message: "subscription closed"}

// DEPRECATED: ErrInvalidDurableName is no longer returned and will be removed in future releases.
// Use ErrInvalidConsumerName instead.
ErrInvalidDurableName = errors.New("nats: invalid durable name")
Expand Down
109 changes: 108 additions & 1 deletion test/js_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,64 @@ func TestPullSubscribeFetchWithHeartbeat(t *testing.T) {
}
}

func TestPullSubscribeFetchDrain(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, s)

nc, js := jsClient(t, s)
defer nc.Close()

_, err := js.AddStream(&nats.StreamConfig{
Name: "TEST",
Subjects: []string{"foo"},
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

defer js.PurgeStream("TEST")
sub, err := js.PullSubscribe("foo", "")
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
for i := 0; i < 100; i++ {
if _, err := js.Publish("foo", []byte("msg")); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
}
// fill buffer with messages
cinfo, err := sub.ConsumerInfo()
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
nextSubject := fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.TEST.%s", cinfo.Name)
replySubject := strings.Replace(sub.Subject, "*", "abc", 1)
payload := `{"batch":10,"no_wait":true}`
if err := nc.PublishRequest(nextSubject, replySubject, []byte(payload)); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
time.Sleep(100 * time.Millisecond)

// now drain the subscription, messages should be in the buffer
sub.Drain()
msgs, err := sub.Fetch(100)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
for _, msg := range msgs {
msg.Ack()
}
if len(msgs) != 10 {
t.Fatalf("Expected %d messages; got: %d", 10, len(msgs))
}

// subsequent fetch should return error, subscription is already drained
_, err = sub.Fetch(10, nats.MaxWait(100*time.Millisecond))
if !errors.Is(err, nats.ErrSubscriptionClosed) {
t.Fatalf("Expected error: %s; got: %s", nats.ErrSubscriptionClosed, err)
}
}

func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, s)
Expand Down Expand Up @@ -1761,6 +1819,55 @@ func TestPullSubscribeFetchBatch(t *testing.T) {
t.Errorf("Expected error: %s; got: %s", nats.ErrNoDeadlineContext, err)
}
})

t.Run("close subscription", func(t *testing.T) {
defer js.PurgeStream("TEST")
sub, err := js.PullSubscribe("foo", "")
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
for i := 0; i < 100; i++ {
if _, err := js.Publish("foo", []byte("msg")); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
}
// fill buffer with messages
cinfo, err := sub.ConsumerInfo()
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
nextSubject := fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.TEST.%s", cinfo.Name)
replySubject := strings.Replace(sub.Subject, "*", "abc", 1)
payload := `{"batch":10,"no_wait":true}`
if err := nc.PublishRequest(nextSubject, replySubject, []byte(payload)); err != nil {
t.Fatalf("Unexpected error: %s", err)
}
time.Sleep(100 * time.Millisecond)

// now drain the subscription, messages should be in the buffer
sub.Drain()
res, err := sub.FetchBatch(100)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
msgs := make([]*nats.Msg, 0)
for msg := range res.Messages() {
msgs = append(msgs, msg)
msg.Ack()
}
if res.Error() != nil {
t.Fatalf("Unexpected error: %s", res.Error())
}
if len(msgs) != 10 {
t.Fatalf("Expected %d messages; got: %d", 10, len(msgs))
}

// subsequent fetch should return error, subscription is already drained
_, err = sub.FetchBatch(10, nats.MaxWait(100*time.Millisecond))
if !errors.Is(err, nats.ErrSubscriptionClosed) {
t.Fatalf("Expected error: %s; got: %s", nats.ErrSubscriptionClosed, err)
}
})
}

func TestPullSubscribeConsumerDeleted(t *testing.T) {
Expand Down Expand Up @@ -7646,7 +7753,7 @@ func testJetStreamFetchOptions(t *testing.T, srvs ...*jsServer) {
if err == nil {
t.Fatal("Unexpected success")
}
if err != nats.ErrBadSubscription {
if !errors.Is(err, nats.ErrBadSubscription) {
t.Fatalf("Unexpected error: %v", err)
}
})
Expand Down

0 comments on commit 6dfefd9

Please sign in to comment.