diff --git a/example_test.go b/example_test.go index 9651785b8..f8930dbaa 100644 --- a/example_test.go +++ b/example_test.go @@ -17,6 +17,7 @@ import ( "context" "fmt" "log" + "net" "time" "github.com/nats-io/nats.go" @@ -44,6 +45,40 @@ func ExampleConnect() { nc.Close() } +type skipTLSDialer struct { + dialer *net.Dialer + skipTLS bool +} + +func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) { + return sd.dialer.Dial(network, address) +} + +func (sd *skipTLSDialer) SkipTLSHandshake() bool { + return sd.skipTLS +} + +func ExampleCustomDialer() { + // Given the following CustomDialer implementation: + // + // type skipTLSDialer struct { + // dialer *net.Dialer + // skipTLS bool + // } + // + // func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) { + // return sd.dialer.Dial(network, address) + // } + // + // func (sd *skipTLSDialer) SkipTLSHandshake() bool { + // return true + // } + // + sd := &skipTLSDialer{dialer: &net.Dialer{Timeout: 2 * time.Second}, skipTLS: true} + nc, _ := nats.Connect("demo.nats.io", nats.SetCustomDialer(sd)) + defer nc.Close() +} + // This Example shows an asynchronous subscriber. func ExampleConn_Subscribe() { nc, _ := nats.Connect(nats.DefaultURL) diff --git a/nats.go b/nats.go index d608d924b..b1cd48fa1 100644 --- a/nats.go +++ b/nats.go @@ -247,8 +247,9 @@ type asyncCallbacksHandler struct { // Option is a function on the options for a connection. type Option func(*Options) error -// CustomDialer can be used to specify any dialer, not necessarily -// a *net.Dialer. +// CustomDialer can be used to specify any dialer, not necessarily a +// *net.Dialer. A CustomDialer may also implement `SkipTLSHandshake() bool` +// in order to skip the TLS handshake in case not required. type CustomDialer interface { Dial(network, address string) (net.Conn, error) } @@ -303,10 +304,6 @@ type Options struct { // transports. TLSConfig *tls.Config - // SkipTLSWrapper does not upgrade the connection to TLS and is - // meant to be used if the custom dialer does handle TLS itself - SkipTLSWrapper bool - // AllowReconnect enables reconnection logic to be used when we // encounter a disconnect from the current server. AllowReconnect bool @@ -1189,16 +1186,6 @@ func SetCustomDialer(dialer CustomDialer) Option { } } -// SetSkipTLSWrapper is an Option to be used with the CustomDialer which -// will not wrap the connection with TLS. Use it if the CustomDialer did -// already handle TLS -func SetSkipTLSWrapper(skip bool) Option { - return func(o *Options) error { - o.SkipTLSWrapper = skip - return nil - } -} - // UseOldRequestStyle is an Option to force usage of the old Request style. func UseOldRequestStyle() Option { return func(o *Options) error { @@ -1906,11 +1893,18 @@ func (nc *Conn) createConn() (err error) { return nil } +type skipTLSDialer interface { + SkipTLSHandshake() bool +} + // makeTLSConn will wrap an existing Conn using TLS func (nc *Conn) makeTLSConn() error { - if nc.Opts.SkipTLSWrapper { + if nc.Opts.CustomDialer != nil { // we do nothing when asked to skip the TLS wrapper - return nil + sd, ok := nc.Opts.CustomDialer.(skipTLSDialer) + if ok && sd.SkipTLSHandshake() { + return nil + } } // Allow the user to configure their own tls.Config structure. var tlsCopy *tls.Config diff --git a/services/service.go b/services/service.go index 0e4b61508..e9e3a652e 100644 --- a/services/service.go +++ b/services/service.go @@ -31,7 +31,7 @@ import ( type ( - // Service is an interface for sevice management. + // Service is an interface for service management. // It exposes methods to stop/reset a service, as well as get information on a service. Service interface { ID() string diff --git a/ws_test.go b/ws_test.go index 3b8a37442..eafe7b67f 100644 --- a/ws_test.go +++ b/ws_test.go @@ -868,6 +868,60 @@ func TestWSWithTLS(t *testing.T) { } } +type testSkipTLSDialer struct { + dialer *net.Dialer + skipTLS bool +} + +func (sd *testSkipTLSDialer) Dial(network, address string) (net.Conn, error) { + return sd.dialer.Dial(network, address) +} + +func (sd *testSkipTLSDialer) SkipTLSHandshake() bool { + return sd.skipTLS +} + +func TestWSWithTLSCustomDialer(t *testing.T) { + sopts := testWSGetDefaultOptions(t, true) + s := RunServerWithOptions(sopts) + defer s.Shutdown() + + sd := &testSkipTLSDialer{ + dialer: &net.Dialer{ + Timeout: 2 * time.Second, + }, + skipTLS: true, + } + + // Connect with CustomDialer that fails since TLSHandshake is disabled. + copts := make([]Option, 0) + copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true})) + copts = append(copts, SetCustomDialer(sd)) + _, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...) + if err == nil { + t.Fatalf("Expected error on connect: %v", err) + } + if err.Error() != `invalid websocket connection` { + t.Logf("Expected invalid websocket connection: %v", err) + } + + // Retry with the dialer. + copts = make([]Option, 0) + sd = &testSkipTLSDialer{ + dialer: &net.Dialer{ + Timeout: 2 * time.Second, + }, + skipTLS: false, + } + copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true})) + copts = append(copts, SetCustomDialer(sd)) + nc, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...) + if err != nil { + t.Fatalf("Unexpected error on connect: %v", err) + } + defer nc.Close() +} + func TestWSTlsNoConfig(t *testing.T) { opts := GetDefaultOptions() opts.Servers = []string{"wss://localhost:443"}