Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lookupd: support filtering response based on auth state #370

Merged
merged 3 commits into from
Jun 17, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Godeps
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
{
"ImportPath": "github.com/bitly/go-nsq",
"Comment": "v1.0.0-alpha",
"Rev": "cb45642b33d79977f329c6dea6b06a60ed44a317"
"Rev": "23d799909149ced627bb259d61625edc36bd8a20"
},
{
"ImportPath": "github.com/bitly/go-simplejson",
Expand Down
2 changes: 1 addition & 1 deletion apps/nsq_pubsub/nsq_pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (s *StreamServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

cfg := nsq.NewConfig()
cfg.Set("max_in_flight", *maxInFlight)
cfg.MaxInFlight = *maxInFlight
r, err := nsq.NewConsumer(topicName, channelName, cfg)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down
2 changes: 1 addition & 1 deletion apps/nsq_tail/nsq_tail.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func main() {
}

cfg := nsq.NewConfig()
cfg.Set("max_in_flight", *maxInFlight)
cfg.MaxInFlight = *maxInFlight

err := util.ParseReaderOpts(cfg, readerOpts)
if err != nil {
Expand Down
8 changes: 1 addition & 7 deletions apps/nsq_to_file/nsq_to_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,7 @@ func newReaderFileLogger(topic string) (*ReaderFileLogger, error) {
if err != nil {
return nil, err
}
cfg.Set("max_in_flight", *maxInFlight)

// TODO: remove, deprecated
if hasArg("verbose") {
log.Printf("WARNING: --verbose is deprecated in favor of --reader-opt=verbose")
cfg.Set("verbose", true)
}
cfg.MaxInFlight = *maxInFlight

r, err := nsq.NewConsumer(topic, *channel, cfg)
if err != nil {
Expand Down
10 changes: 2 additions & 8 deletions apps/nsq_to_http/nsq_to_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,12 @@ func main() {
if err != nil {
log.Fatalf(err.Error())
}
cfg.Set("max_in_flight", *maxInFlight)

// TODO: remove, deprecated
if hasArg("verbose") {
log.Printf("WARNING: --verbose is deprecated in favor of --reader-opt=verbose")
cfg.Set("verbose", true)
}
cfg.MaxInFlight = *maxInFlight

// TODO: remove, deprecated
if hasArg("max-backoff-duration") {
log.Printf("WARNING: --max-backoff-duration is deprecated in favor of --reader-opt=max_backoff_duration=X")
cfg.Set("max_backoff_duration", *maxBackoffDuration)
cfg.MaxBackoffDuration = *maxBackoffDuration
}

r, err := nsq.NewConsumer(*topic, *channel, cfg)
Expand Down
17 changes: 7 additions & 10 deletions apps/nsq_to_nsq/nsq_to_nsq.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,18 +355,12 @@ func main() {
if err != nil {
log.Fatalf(err.Error())
}
cfg.Set("max_in_flight", *maxInFlight)

// TODO: remove, deprecated
if hasArg("verbose") {
log.Printf("WARNING: --verbose is deprecated in favor of --reader-opt=verbose")
cfg.Set("verbose", true)
}
cfg.MaxInFlight = *maxInFlight

// TODO: remove, deprecated
if hasArg("max-backoff-duration") {
log.Printf("WARNING: --max-backoff-duration is deprecated in favor of --reader-opt=max_backoff_duration=X")
cfg.Set("max_backoff_duration", *maxBackoffDuration)
cfg.MaxBackoffDuration = *maxBackoffDuration
}

r, err := nsq.NewConsumer(*topic, *channel, cfg)
Expand All @@ -376,10 +370,13 @@ func main() {
r.SetLogger(log.New(os.Stderr, "", log.LstdFlags), nsq.LogLevelInfo)

wcfg := nsq.NewConfig()
wcfg.Set("heartbeat_interval", nsq.DefaultClientTimeout/2)
producers := make(map[string]*nsq.Producer)
for _, addr := range destNsqdTCPAddrs {
producers[addr] = nsq.NewProducer(addr, wcfg)
producer, err := nsq.NewProducer(addr, wcfg)
if err != nil {
log.Fatalf("failed creating producer %s", err)
}
producers[addr] = producer
}

handler := &PublishHandler{
Expand Down
17 changes: 9 additions & 8 deletions apps/nsqd/nsqd.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ var (
flagSet = flag.NewFlagSet("nsqd", flag.ExitOnError)

// basic options
config = flagSet.String("config", "", "path to config file")
showVersion = flagSet.Bool("version", false, "print version string")
verbose = flagSet.Bool("verbose", false, "enable verbose logging")
workerId = flagSet.Int64("worker-id", 0, "unique identifier (int) for this worker (will default to a hash of hostname)")
httpsAddress = flagSet.String("https-address", "", "<addr>:<port> to listen on for HTTPS clients")
httpAddress = flagSet.String("http-address", "0.0.0.0:4151", "<addr>:<port> to listen on for HTTP clients")
tcpAddress = flagSet.String("tcp-address", "0.0.0.0:4150", "<addr>:<port> to listen on for TCP clients")
authHttpAddress = flagSet.String("auth-http-address", "", "<addr>:<port> to query auth server")
config = flagSet.String("config", "", "path to config file")
showVersion = flagSet.Bool("version", false, "print version string")
verbose = flagSet.Bool("verbose", false, "enable verbose logging")
workerId = flagSet.Int64("worker-id", 0, "unique identifier (int) for this worker (will default to a hash of hostname)")
httpsAddress = flagSet.String("https-address", "", "<addr>:<port> to listen on for HTTPS clients")
httpAddress = flagSet.String("http-address", "0.0.0.0:4151", "<addr>:<port> to listen on for HTTP clients")
tcpAddress = flagSet.String("tcp-address", "0.0.0.0:4150", "<addr>:<port> to listen on for TCP clients")
authHttpAddresses = util.StringArray{}

broadcastAddress = flagSet.String("broadcast-address", "", "address that will be registered with lookupd (defaults to the OS hostname)")
lookupdTCPAddrs = util.StringArray{}
Expand Down Expand Up @@ -80,6 +80,7 @@ var (
func init() {
flagSet.Var(&lookupdTCPAddrs, "lookupd-tcp-address", "lookupd TCP address (may be given multiple times)")
flagSet.Var(&e2eProcessingLatencyPercentiles, "e2e-processing-latency-percentile", "message processing time percentiles to keep track of (can be specified multiple times or comma separated, default none)")
flagSet.Var(&authHttpAddresses, "auth-http-address", "<addr>:<port> to query auth server (may be given multiple times)")
}

func main() {
Expand Down
26 changes: 10 additions & 16 deletions nsqd/client_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"sync/atomic"
"time"

"github.com/bitly/nsq/util/auth"
"github.com/mreiferson/go-snappystream"
)

Expand Down Expand Up @@ -107,7 +108,7 @@ type clientV2 struct {
lenSlice []byte

AuthSecret string
AuthState *AuthState
AuthState *auth.AuthState
}

func newClientV2(id int64, conn net.Conn, context *context) *clientV2 {
Expand Down Expand Up @@ -514,18 +515,13 @@ func (c *clientV2) QueryAuthd() error {
tlsEnabled = "true"
}

// for each auth server, try to authorize. on success return authorizations, on failure try next auth server.
for _, authd := range c.context.nsqd.options.AuthHTTPAddresses {
authState, err := queryAuthd(authd, remoteIp, tlsEnabled, c.AuthSecret)
if err != nil {
log.Printf("Error: failed auth against %s %s", authd, err)
continue
}

c.AuthState = authState
return nil
authState, err := auth.QueryAnyAuthd(c.context.nsqd.options.AuthHTTPAddresses,
remoteIp, tlsEnabled, c.AuthSecret)
if err != nil {
return err
}
return errors.New("Unable to access auth server")
c.AuthState = authState
return nil
}

func (c *clientV2) Auth(secret string) error {
Expand All @@ -543,10 +539,8 @@ func (c *clientV2) IsAuthorized(topic, channel string) (bool, error) {
return false, err
}
}
for _, a := range c.AuthState.Authorizations {
if a.IsAllowed(topic, channel) {
return true, nil
}
if c.AuthState.IsAllowed(topic, channel) {
return true, nil
}
return false, nil
}
Expand Down
4 changes: 2 additions & 2 deletions nsqd/protocol_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func sub(t *testing.T, conn io.ReadWriter, topicName string, channelName string)
readValidate(t, conn, frameTypeResponse, "OK")
}

func auth(t *testing.T, conn io.ReadWriter, authSecret string, expectSuccess string) {
func authCmd(t *testing.T, conn io.ReadWriter, authSecret string, expectSuccess string) {
auth := &nsq.Command{[]byte("AUTH"), nil, []byte(authSecret)}
_, err := auth.WriteTo(conn)
assert.Equal(t, err, nil)
Expand Down Expand Up @@ -1326,7 +1326,7 @@ func runAuthTest(t *testing.T, authResponse, authSecret, authError, authSuccess
"tls_v1": false,
}, nsq.FrameTypeResponse)

auth(t, conn, authSecret, authSuccess)
authCmd(t, conn, authSecret, authSuccess)
if authError != "" {
readValidate(t, conn, nsq.FrameTypeError, authError)
} else {
Expand Down
2 changes: 1 addition & 1 deletion test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ apps/nsqlookupd/nsqlookupd >$LOOKUP_LOGFILE 2>&1 &
LOOKUPD_PID=$!

# build and run nsqd configured to use our lookupd above
NSQD_LOGFILE=$(mktemp -t nsqlookupd.XXXXXXX)
NSQD_LOGFILE=$(mktemp -t nsqd.XXXXXXX)
cmd="apps/nsqd/nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=nsqd/test/certs/cert.pem --tls-key=nsqd/test/certs/key.pem"
echo "building and starting $cmd"
echo " logging to $NSQD_LOGFILE"
Expand Down
27 changes: 25 additions & 2 deletions nsqd/authorizations.go → util/auth/authorizations.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package nsqd
package auth

import (
"errors"
"fmt"
"log"
"net/url"
"regexp"
"time"
Expand Down Expand Up @@ -58,14 +60,35 @@ func (a *Authorization) IsAllowed(topic, channel string) bool {
return false
}

func (a *AuthState) IsAllowed(topic, channel string) bool {
for _, aa := range a.Authorizations {
if aa.IsAllowed(topic, channel) {
return true
}
}
return false
}

func (a *AuthState) IsExpired() bool {
if a.Expires.Before(time.Now()) {
return true
}
return false
}

func queryAuthd(authd, remoteIp, tlsEnabled, authSecret string) (*AuthState, error) {
func QueryAnyAuthd(authd []string, remoteIp, tlsEnabled, authSecret string) (*AuthState, error) {
for _, a := range authd {
authState, err := QueryAuthd(a, remoteIp, tlsEnabled, authSecret)
if err != nil {
log.Printf("Error: failed auth against %s %s", a, err)
continue
}
return authState, nil
}
return nil, errors.New("Unable to access auth server")
}

func QueryAuthd(authd, remoteIp, tlsEnabled, authSecret string) (*AuthState, error) {

v := url.Values{}
v.Set("remote_ip", remoteIp)
Expand Down