Skip to content

Commit

Permalink
[FIXED] Fix pull heartbeat validation
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <[email protected]>
  • Loading branch information
piotrpio committed Sep 25, 2023
1 parent 67a55c2 commit 895e542
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 10 deletions.
22 changes: 16 additions & 6 deletions js.go
Original file line number Diff line number Diff line change
Expand Up @@ -2756,9 +2756,6 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
ttl = js.opts.wait
}
sub.mu.Unlock()
if o.hb != 0 && 2*o.hb >= ttl {
return nil, fmt.Errorf("%w: idle heartbeat value too large", ErrInvalidArg)
}

// Use the given context or setup a default one for the span
// of the pull batch request.
Expand All @@ -2784,6 +2781,14 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
}
defer cancel()

// if heartbeat is set, validate it against the context timeout
if o.hb > 0 {
deadline, _ := ctx.Deadline()
if 2*o.hb >= time.Until(deadline) {
return nil, fmt.Errorf("%w: idle heartbeat value too large", ErrInvalidArg)
}
}

// Check if context not done already before making the request.
select {
case <-ctx.Done():
Expand Down Expand Up @@ -3017,9 +3022,6 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
ttl = js.opts.wait
}
sub.mu.Unlock()
if o.hb != 0 && 2*o.hb >= ttl {
return nil, fmt.Errorf("%w: idle heartbeat value too large", ErrInvalidArg)
}

// Use the given context or setup a default one for the span
// of the pull batch request.
Expand Down Expand Up @@ -3050,6 +3052,14 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e
}
}()

// if heartbeat is set, validate it against the context timeout
if o.hb > 0 {
deadline, _ := ctx.Deadline()
if 2*o.hb >= time.Until(deadline) {
return nil, fmt.Errorf("%w: idle heartbeat value too large", ErrInvalidArg)
}
}

// Check if context not done already before making the request.
select {
case <-ctx.Done():
Expand Down
87 changes: 83 additions & 4 deletions test/js_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,7 @@ func TestPullSubscribeFetchWithHeartbeat(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
defer sub.Unsubscribe()
for i := 0; i < 5; i++ {
if _, err := js.Publish("foo", []byte("msg")); err != nil {
t.Fatalf("Unexpected error: %s", err)
Expand Down Expand Up @@ -1184,13 +1185,51 @@ func TestPullSubscribeFetchWithHeartbeat(t *testing.T) {
// heartbeat value too large
_, err = sub.Fetch(5, nats.PullHeartbeat(200*time.Millisecond), nats.MaxWait(300*time.Millisecond))
if !errors.Is(err, nats.ErrInvalidArg) {
t.Fatalf("Expected no heartbeat error; got: %v", err)
t.Fatalf("Expected invalid arg error; got: %v", err)
}

// heartbeat value invalid
_, err = sub.Fetch(5, nats.PullHeartbeat(-1))
if !errors.Is(err, nats.ErrInvalidArg) {
t.Fatalf("Expected no heartbeat error; got: %v", err)
t.Fatalf("Expected invalid arg error; got: %v", err)
}

// set short timeout on JetStream context
js, err = nc.JetStream(nats.MaxWait(100 * time.Millisecond))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

sub1, err := js.PullSubscribe("foo", "")
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
defer sub.Unsubscribe()

// should produce invalid arg error based on default timeout from JetStream context
_, err = sub1.Fetch(5, nats.PullHeartbeat(100*time.Millisecond))
if !errors.Is(err, nats.ErrInvalidArg) {
t.Fatalf("Expected invalid arg error; got: %v", err)
}

// overwrite default timeout with context timeout, fetch available messages
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
defer cancel()
msgs, err = sub1.Fetch(10, nats.PullHeartbeat(100*time.Millisecond), nats.Context(ctx))
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if len(msgs) != 5 {
t.Fatalf("Expected %d messages; got: %d", 5, len(msgs))
}
for _, msg := range msgs {
msg.Ack()
}

// overwrite default timeout with max wait, should time out because no messages are available
_, err = sub1.Fetch(5, nats.PullHeartbeat(100*time.Millisecond), nats.MaxWait(300*time.Millisecond))
if !errors.Is(err, nats.ErrTimeout) {
t.Fatalf("Expected timeout error; got: %v", err)
}
}

Expand All @@ -1213,6 +1252,7 @@ func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
defer sub.Unsubscribe()
for i := 0; i < 5; i++ {
if _, err := js.Publish("foo", []byte("msg")); err != nil {
t.Fatalf("Unexpected error: %s", err)
Expand Down Expand Up @@ -1282,16 +1322,55 @@ func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) {
}

// heartbeat value too large
_, err = sub.Fetch(5, nats.PullHeartbeat(200*time.Millisecond), nats.MaxWait(300*time.Millisecond))
_, err = sub.FetchBatch(5, nats.PullHeartbeat(200*time.Millisecond), nats.MaxWait(300*time.Millisecond))
if !errors.Is(err, nats.ErrInvalidArg) {
t.Fatalf("Expected no heartbeat error; got: %v", err)
}

// heartbeat value invalid
_, err = sub.Fetch(5, nats.PullHeartbeat(-1))
_, err = sub.FetchBatch(5, nats.PullHeartbeat(-1))
if !errors.Is(err, nats.ErrInvalidArg) {
t.Fatalf("Expected no heartbeat error; got: %v", err)
}

// set short timeout on JetStream context
js, err = nc.JetStream(nats.MaxWait(100 * time.Millisecond))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

sub1, err := js.PullSubscribe("foo", "")
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
defer sub.Unsubscribe()

// should produce invalid arg error based on default timeout from JetStream context
_, err = sub1.Fetch(5, nats.PullHeartbeat(100*time.Millisecond))
if !errors.Is(err, nats.ErrInvalidArg) {
t.Fatalf("Expected invalid arg error; got: %v", err)
}

// overwrite default timeout with context timeout, fetch available messages
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
defer cancel()
msgs, err = sub1.FetchBatch(10, nats.PullHeartbeat(100*time.Millisecond), nats.Context(ctx))
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
for msg := range msgs.Messages() {
msg.Ack()
}

// overwrite default timeout with max wait, should time out because no messages are available
msgs, err = sub1.FetchBatch(5, nats.PullHeartbeat(100*time.Millisecond), nats.MaxWait(300*time.Millisecond))
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
<-msgs.Done()
if msgs.Error() != nil {
t.Fatalf("Unexpected error: %s", msgs.Error())
}
}

func TestPullSubscribeFetchBatch(t *testing.T) {
Expand Down

0 comments on commit 895e542

Please sign in to comment.