Skip to content

Commit

Permalink
Merge pull request #54 from jehiah/tls_configs_54
Browse files Browse the repository at this point in the history
Improve TLS support
  • Loading branch information
mreiferson committed Jun 26, 2014
2 parents 71f6f6e + 77f8b9b commit f4ae369
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 7 deletions.
64 changes: 63 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package nsq

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"reflect"
Expand All @@ -13,6 +15,11 @@ import (
"unsafe"
)

type configHandler interface {
Handles(option string) bool
Set(c *Config, option string, value interface{}) error
}

// Config is a struct of NSQ options
//
// The only valid way to create a Config is via NewConfig, using a struct literal will panic.
Expand All @@ -21,7 +28,8 @@ import (
//
// Use Set(key string, value interface{}) as an alternate way to set parameters
type Config struct {
initialized bool
initialized bool
configHandlers []configHandler

// Deadlines for network reads and writes
ReadTimeout time.Duration `opt:"read_timeout" min:"100ms" max:"5m" default:"60s"`
Expand Down Expand Up @@ -57,6 +65,7 @@ type Config struct {
SampleRate int32 `opt:"sample_rate" min:"0" max:"99"`

// TLS Settings
// use tls-root-ca-file and tls-insecure-skip-verify to set tls config options
TlsV1 bool `opt:"tls_v1"`
TlsConfig *tls.Config `opt:"tls_config"`

Expand Down Expand Up @@ -89,6 +98,7 @@ type Config struct {
// This must be used to initialize Config structs. Values can be set directly, or through Config.Set()
func NewConfig() *Config {
c := &Config{}
c.configHandlers = append(c.configHandlers, &tlsHandler{})
c.initialized = true
if err := c.setDefaults(); err != nil {
panic(err.Error())
Expand Down Expand Up @@ -119,6 +129,12 @@ func (c *Config) Set(option string, value interface{}) error {

c.assertInitialized()

for _, h := range c.configHandlers {
if h.Handles(option) {
return h.Set(c, option, value)
}
}

val := reflect.ValueOf(c).Elem()
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
Expand Down Expand Up @@ -234,6 +250,52 @@ func (c *Config) setDefaults() error {
return nil
}

type tlsHandler struct {
}

func (t *tlsHandler) Handles(option string) bool {
switch option {
case "tls-root-ca-file", "tls-insecure-skip-verify":
return true
}
return false
}
func (t *tlsHandler) Set(c *Config, option string, value interface{}) error {
if c.TlsConfig == nil {
c.TlsConfig = &tls.Config{}
}
val := reflect.ValueOf(c.TlsConfig).Elem()

switch option {
case "tls-root-ca-file":
filename, ok := value.(string)
if !ok {
return fmt.Errorf("ERROR: %v is not a string", value)
}
tlsCertPool := x509.NewCertPool()
ca_cert_file, err := ioutil.ReadFile(filename)
if err != nil {
return fmt.Errorf("ERROR: failed to read custom Certificate Authority file %s", err)
}
if !tlsCertPool.AppendCertsFromPEM(ca_cert_file) {
return fmt.Errorf("ERROR: failed to append certificates from Certificate Authority file")
}
c.TlsConfig.ClientCAs = tlsCertPool
return nil
case "tls-insecure-skip-verify":
fieldVal := val.FieldByName("InsecureSkipVerify")
dest := unsafeValueOf(fieldVal)
coercedVal, err := coerce(value, fieldVal.Type())
if err != nil {
return fmt.Errorf("failed to coerce option %s (%v) - %s",
option, value, err)
}
dest.Set(coercedVal)
return nil
}
return fmt.Errorf("unknown option %s", option)
}

// because Config contains private structs we can't use reflect.Value
// directly, instead we need to "unsafely" address the variable
func unsafeValueOf(val reflect.Value) reflect.Value {
Expand Down
9 changes: 8 additions & 1 deletion config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ func TestConfigSet(t *testing.T) {
t.Error("No error when setting `tls_v1` to an invalid value")
}
if err := c.Set("tls_v1", true); err != nil {
t.Errorf("Error setting `tls_v1` config: %v", err)
t.Errorf("Error setting `tls_v1` config. %v", err)
}

if err := c.Set("tls-insecure-skip-verify", true); err != nil {
t.Errorf("Error setting `tls-insecure-skip-verify` config. %v", err)
}
if c.TlsConfig.InsecureSkipVerify != true {
t.Errorf("Error setting `tls-insecure-skip-verify` config: %v", c.TlsConfig)
}
}

Expand Down
19 changes: 15 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn {
return &Conn{
addr: addr,

config: config,
config: config,
delegate: delegate,

maxRdyCount: 2500,
Expand Down Expand Up @@ -348,9 +348,20 @@ func (c *Conn) identify() (*IdentifyResponse, error) {
return resp, nil
}

func (c *Conn) upgradeTLS(conf *tls.Config) error {
c.tlsConn = tls.Client(c.conn, conf)
err := c.tlsConn.Handshake()
func (c *Conn) upgradeTLS(tlsConf *tls.Config) error {
// create a local copy of the config to set ServerName for this connection
var conf tls.Config
if tlsConf != nil {
conf = *tlsConf
}
host, _, err := net.SplitHostPort(c.addr)
if err != nil {
return err
}
conf.ServerName = host

c.tlsConn = tls.Client(c.conn, &conf)
err = c.tlsConn.Handshake()
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (r *Consumer) getMaxInFlight() int32 {
// will allow in-flight, and updates all existing connections as appropriate.
//
// For example, ChangeMaxInFlight(0) would pause message flow
//
//
// If already connected, it updates the reader RDY state for each connection.
func (r *Consumer) ChangeMaxInFlight(maxInFlight int) {
if r.getMaxInFlight() == int32(maxInFlight) {
Expand Down

0 comments on commit f4ae369

Please sign in to comment.