diff --git a/command.go b/command.go index c3f5aa22..71b15428 100644 --- a/command.go +++ b/command.go @@ -96,6 +96,13 @@ func Identify(js map[string]interface{}) (*Command, error) { return &Command{[]byte("IDENTIFY"), nil, body}, nil } +// Auth sends credentials for authentication +// +// After `Identify`, this is usually the first message sent, if auth is used. +func Auth(secret string) (*Command, error) { + return &Command{[]byte("AUTH"), nil, []byte(secret)}, nil +} + // Register creates a new Command to add a topic/channel for the connected nsqd func Register(topic string, channel string) *Command { params := [][]byte{[]byte(topic)} diff --git a/config.go b/config.go index 05d4dc05..279a02e4 100644 --- a/config.go +++ b/config.go @@ -56,6 +56,8 @@ type Config struct { maxInFlightMutex sync.RWMutex maxBackoffDuration time.Duration `opt:"max_backoff_duration" min:"0" max:"60m"` + + authSecret string `opt:"auth_secret"` } // NewConfig returns a new default configuration @@ -163,6 +165,8 @@ func NewConfig() *Config { // max_backoff_duration: the maximum amount of time to backoff when processing fails // 0 == no backoff // +// auth_secret: Secret for nsqd authentication. (requires nsqd 1.0+) +// func (c *Config) Set(option string, value interface{}) error { c.Lock() defer c.Unlock() diff --git a/conn.go b/conn.go index 6ab4c4d3..7fc35fb0 100644 --- a/conn.go +++ b/conn.go @@ -22,10 +22,19 @@ import ( // IdentifyResponse represents the metadata // returned from an IDENTIFY command to nsqd type IdentifyResponse struct { - MaxRdyCount int64 `json:"max_rdy_count"` - TLSv1 bool `json:"tls_v1"` - Deflate bool `json:"deflate"` - Snappy bool `json:"snappy"` + MaxRdyCount int64 `json:"max_rdy_count"` + TLSv1 bool `json:"tls_v1"` + Deflate bool `json:"deflate"` + Snappy bool `json:"snappy"` + AuthRequired bool `json:"auth_required"` +} + +// AuthResponse represents the metadata +// returned from an AUTH command to nsqd +type AuthResponse struct { + Identity string `json:"identity"` + IdentityUrl string `json:"identity_url"` + PermissionCount int64 `json:"permission_count"` } type msgResponse struct { @@ -133,6 +142,18 @@ func (c *Conn) Connect() (*IdentifyResponse, error) { return nil, err } + if resp != nil && resp.AuthRequired { + if c.config.authSecret == "" { + c.log(LogLevelError, "Auth Required") + return nil, errors.New("Auth Required") + } + err := c.auth(c.config.authSecret) + if err != nil { + c.log(LogLevelError, "Auth Failed %s", err) + return nil, err + } + } + c.wg.Add(2) atomic.StoreInt32(&c.readLoopRunning, 1) go c.readLoop() @@ -376,6 +397,37 @@ func (c *Conn) upgradeSnappy() error { return nil } +func (c *Conn) auth(secret string) error { + cmd, err := Auth(secret) + if err != nil { + return err + } + + err = c.WriteCommand(cmd) + if err != nil { + return err + } + + frameType, data, err := ReadUnpackedResponse(c) + if err != nil { + return err + } + + if frameType == FrameTypeError { + return errors.New("Error authenticating " + string(data)) + } + + resp := &AuthResponse{} + err = json.Unmarshal(data, resp) + if err != nil { + return err + } + + c.log(LogLevelInfo, "Auth accepted. Identity: %q %s Permissions: %d", resp.Identity, resp.IdentityUrl, resp.PermissionCount) + + return nil +} + func (c *Conn) readLoop() { for { if atomic.LoadInt32(&c.closeFlag) == 1 { diff --git a/producer.go b/producer.go index be68b682..410884de 100644 --- a/producer.go +++ b/producer.go @@ -33,9 +33,9 @@ type Producer struct { state int32 concurrentProducers int32 - stopFlag int32 - exitChan chan int - wg sync.WaitGroup + stopFlag int32 + exitChan chan int + wg sync.WaitGroup } // ProducerTransaction is returned by the async publish methods diff --git a/test.sh b/test.sh index 42831322..e8029a76 100755 --- a/test.sh +++ b/test.sh @@ -3,9 +3,19 @@ set -e # a helper script to run tests +if ! which nsqd >/dev/null; then + echo "missing nsqd binary" && exit 1 +fi + +if ! which nsqlookupd >/dev/null; then + echo "missing nsqlookupd binary" && exit 1 +fi + # run nsqlookupd +LOOKUP_LOGFILE=$(mktemp -t nsqlookupd.XXXXXXX) echo "starting nsqlookupd" -nsqlookupd >/dev/null 2>&1 & +echo " logging to $LOOKUP_LOGFILE" +nsqlookupd >$LOOKUP_LOGFILE 2>&1 & LOOKUPD_PID=$! cat >/tmp/cert.pem </dev/null 2>&1 & +echo " logging to $NSQD_LOGFILE" +nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=/tmp/cert.pem --tls-key=/tmp/key.pem >$NSQD_LOGFILE 2>&1 & NSQD_PID=$! sleep 0.3 cleanup() { - kill -s TERM $NSQD_PID - kill -s TERM $LOOKUPD_PID + echo "killing nsqd PID $NSQD_PID" + kill -s TERM $NSQD_PID || cat $NSQD_LOGFILE + echo "killing nsqlookupd PID $LOOKUPD_PID" + kill -s TERM $LOOKUPD_PID || cat $LOOKUP_LOGFILE } trap cleanup INT TERM EXIT