Skip to content

Commit

Permalink
Fix data races in the example
Browse files Browse the repository at this point in the history
Some users rely on this example as a starting point to their
applications. This commit fixes a data race that could cause issues in
any code that relied on the example as base.

Related to #72

Signed-off-by: Aitor Perez Cedres <[email protected]>
  • Loading branch information
Zerpet authored and lukebakken committed May 7, 2024
1 parent d0e93ff commit c6d8d52
Showing 1 changed file with 49 additions and 16 deletions.
65 changes: 49 additions & 16 deletions example_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ package amqp091_test
import (
"context"
"errors"
"fmt"
"log"
"os"
"sync"
"time"

amqp "github.com/rabbitmq/amqp091-go"
Expand All @@ -23,9 +23,10 @@ import (
// It doesn't automatically ack each message, but leaves that
// to the parent process, since it is usage-dependent.
//
// Try running this in one terminal, and `rabbitmq-server` in another.
// Try running this in one terminal, and rabbitmq-server in another.
//
// Stop & restart RabbitMQ to see how the queue reacts.
func Example() {
func Example_publish() {
queueName := "job_queue"
addr := "amqp://guest:guest@localhost:5672/"
queue := New(queueName, addr)
Expand All @@ -39,12 +40,14 @@ loop:
// Attempt to push a message every 2 seconds
case <-time.After(time.Second * 2):
if err := queue.Push(message); err != nil {
fmt.Printf("Push failed: %s\n", err)
log.Printf("Push failed: %s\n", err)
} else {
fmt.Println("Push succeeded!")
log.Println("Push succeeded!")
}
case <-ctx.Done():
queue.Close()
if err := queue.Close(); err != nil {
log.Printf("Close failed: %s\n", err)
}
break loop
}
}
Expand All @@ -55,15 +58,15 @@ func Example_consume() {
addr := "amqp://guest:guest@localhost:5672/"
queue := New(queueName, addr)

// Give the connection sometime to setup
// Give the connection sometime to set up
<-time.After(time.Second)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

deliveries, err := queue.Consume()
if err != nil {
fmt.Printf("Could not start consuming: %s\n", err)
log.Printf("Could not start consuming: %s\n", err)
return
}

Expand All @@ -78,19 +81,22 @@ func Example_consume() {
for {
select {
case <-ctx.Done():
queue.Close()
err := queue.Close()
if err != nil {
log.Printf("Close failed: %s\n", err)
}
return

case amqErr := <-chClosedCh:
// This case handles the event of closed channel e.g. abnormal shutdown
fmt.Printf("AMQP Channel closed due to: %s\n", amqErr)
log.Printf("AMQP Channel closed due to: %s\n", amqErr)

deliveries, err = queue.Consume()
if err != nil {
// If the AMQP channel is not ready, it will continue the loop. Next
// iteration will enter this case because chClosedCh is closed by the
// library
fmt.Println("Error trying to consume, will try again")
log.Println("Error trying to consume, will try again")
continue
}

Expand All @@ -101,16 +107,21 @@ func Example_consume() {

case delivery := <-deliveries:
// Ack a message every 2 seconds
fmt.Printf("Received message: %s\n", delivery.Body)
log.Printf("Received message: %s\n", delivery.Body)
if err := delivery.Ack(false); err != nil {
fmt.Printf("Error acknowledging message: %s\n", err)
log.Printf("Error acknowledging message: %s\n", err)
}
<-time.After(time.Second * 2)
}
}
}

// Client is the base struct for handling connection recovery, consumption and
// publishing. Note that this struct has an internal mutex to safeguard against
// data races. As you develop and iterate over this example, you may need to add
// further locks, or safeguards, to keep your application safe from data races
type Client struct {
m *sync.Mutex
queueName string
logger *log.Logger
connection *amqp.Connection
Expand Down Expand Up @@ -143,6 +154,7 @@ var (
// attempts to connect to the server.
func New(queueName, addr string) *Client {
client := Client{
m: &sync.Mutex{},
logger: log.New(os.Stdout, "", log.LstdFlags),
queueName: queueName,
done: make(chan bool),
Expand All @@ -155,7 +167,10 @@ func New(queueName, addr string) *Client {
// notifyConnClose, and then continuously attempt to reconnect.
func (client *Client) handleReconnect(addr string) {
for {
client.m.Lock()
client.isReady = false
client.m.Unlock()

client.logger.Println("Attempting to connect")

conn, err := client.connect(addr)
Expand Down Expand Up @@ -194,7 +209,9 @@ func (client *Client) connect(addr string) (*amqp.Connection, error) {
// and then continuously attempt to re-initialize both channels
func (client *Client) handleReInit(conn *amqp.Connection) bool {
for {
client.m.Lock()
client.isReady = false
client.m.Unlock()

err := client.init(conn)

Expand Down Expand Up @@ -251,7 +268,9 @@ func (client *Client) init(conn *amqp.Connection) error {
}

client.changeChannel(ch)
client.m.Lock()
client.isReady = true
client.m.Unlock()
client.logger.Println("Setup!")

return nil
Expand All @@ -275,13 +294,16 @@ func (client *Client) changeChannel(channel *amqp.Channel) {
client.channel.NotifyPublish(client.notifyConfirm)
}

// Push will push data onto the queue, and wait for a confirm.
// This will block until the server sends a confirm. Errors are
// Push will push data onto the queue, and wait for a confirmation.
// This will block until the server sends a confirmation. Errors are
// only returned if the push action itself fails, see UnsafePush.
func (client *Client) Push(data []byte) error {
client.m.Lock()
if !client.isReady {
client.m.Unlock()
return errors.New("failed to push: not connected")
}
client.m.Unlock()
for {
err := client.UnsafePush(data)
if err != nil {
Expand All @@ -306,9 +328,12 @@ func (client *Client) Push(data []byte) error {
// No guarantees are provided for whether the server will
// receive the message.
func (client *Client) UnsafePush(data []byte) error {
client.m.Lock()
if !client.isReady {
client.m.Unlock()
return errNotConnected
}
client.m.Unlock()

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
Expand All @@ -331,13 +356,16 @@ func (client *Client) UnsafePush(data []byte) error {
// successfully processed, or delivery.Nack when it fails.
// Ignoring this will cause data to build up on the server.
func (client *Client) Consume() (<-chan amqp.Delivery, error) {
client.m.Lock()
if !client.isReady {
client.m.Unlock()
return nil, errNotConnected
}
client.m.Unlock()

if err := client.channel.Qos(
1, // prefetchCount
0, // prefrechSize
0, // prefetchSize
false, // global
); err != nil {
return nil, err
Expand All @@ -356,6 +384,11 @@ func (client *Client) Consume() (<-chan amqp.Delivery, error) {

// Close will cleanly shut down the channel and connection.
func (client *Client) Close() error {
client.m.Lock()
// we read and write isReady in two locations, so we grab the lock and hold onto
// it until we are finished
defer client.m.Unlock()

if !client.isReady {
return errAlreadyClosed
}
Expand Down

0 comments on commit c6d8d52

Please sign in to comment.