From 8d6ca3cde1902c443da02e2fa95335b47c64b240 Mon Sep 17 00:00:00 2001 From: Monis Khan Date: Wed, 24 Aug 2022 16:04:19 +0000 Subject: [PATCH] exec auth: support TLS config caching This change updates the transport.Config .Dial and .TLS.GetCert fields to use a struct wrapper. This indirection via a pointer allows the functions to be compared and thus makes them valid to use as map keys. This change is then leveraged by the existing global exec auth and TLS config caches to return the same authenticator and TLS config even when distinct but identical rest configs were used to create distinct clientsets. Signed-off-by: Monis Khan Kubernetes-commit: e3bffcd28922b24e54cc89c4356c2dc16e778e67 --- plugin/pkg/client/auth/exec/exec.go | 30 ++-- .../pkg/client/auth/exec/exec_cache_test.go | 106 +++++++++++++ transport/cache.go | 25 ++- transport/cache_test.go | 16 ++ transport/config.go | 21 ++- transport/transport.go | 25 +++ transport/transport_test.go | 146 ++++++++++++++++++ 7 files changed, 354 insertions(+), 15 deletions(-) create mode 100644 plugin/pkg/client/auth/exec/exec_cache_test.go diff --git a/plugin/pkg/client/auth/exec/exec.go b/plugin/pkg/client/auth/exec/exec.go index d37dfbf732..73876f6887 100644 --- a/plugin/pkg/client/auth/exec/exec.go +++ b/plugin/pkg/client/auth/exec/exec.go @@ -199,14 +199,18 @@ func newAuthenticator(c *cache, isTerminalFunc func(int) bool, config *api.ExecC now: time.Now, environ: os.Environ, - defaultDialer: defaultDialer, - connTracker: connTracker, + connTracker: connTracker, } for _, env := range config.Env { a.env = append(a.env, env.Name+"="+env.Value) } + // these functions are made comparable and stored in the cache so that repeated clientset + // construction with the same rest.Config results in a single TLS cache and Authenticator + a.getCert = &transport.GetCertHolder{GetCert: a.cert} + a.dial = &transport.DialHolder{Dial: defaultDialer.DialContext} + return c.put(key, a), nil } @@ -261,8 +265,6 @@ type Authenticator struct { now func() time.Time environ func() []string - // defaultDialer is used for clients which don't specify a custom dialer - defaultDialer *connrotation.Dialer // connTracker tracks all connections opened that we need to close when rotating a client certificate connTracker *connrotation.ConnectionTracker @@ -273,6 +275,12 @@ type Authenticator struct { mu sync.Mutex cachedCreds *credentials exp time.Time + + // getCert makes Authenticator.cert comparable to support TLS config caching + getCert *transport.GetCertHolder + // dial is used for clients which do not specify a custom dialer + // it is comparable to support TLS config caching + dial *transport.DialHolder } type credentials struct { @@ -300,18 +308,20 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error { if c.HasCertCallback() { return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set") } - c.TLS.GetCert = a.cert + c.TLS.GetCert = a.getCert.GetCert + c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching - var d *connrotation.Dialer if c.Dial != nil { // if c has a custom dialer, we have to wrap it - d = connrotation.NewDialerWithTracker(c.Dial, a.connTracker) + // TLS config caching is not supported for this config + d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker) + c.Dial = d.DialContext + c.DialHolder = nil } else { - d = a.defaultDialer + c.Dial = a.dial.Dial + c.DialHolder = a.dial // comparable for TLS config caching } - c.Dial = d.DialContext - return nil } diff --git a/plugin/pkg/client/auth/exec/exec_cache_test.go b/plugin/pkg/client/auth/exec/exec_cache_test.go new file mode 100644 index 0000000000..ecf84262a1 --- /dev/null +++ b/plugin/pkg/client/auth/exec/exec_cache_test.go @@ -0,0 +1,106 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package exec_test // separate package to prevent circular import + +import ( + "context" + "testing" + "time" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + utilnet "k8s.io/apimachinery/pkg/util/net" + clientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +// TestExecTLSCache asserts the semantics of the TLS cache when exec auth is used. +// +// In particular, when: +// - multiple identical rest configs exist as distinct objects, and +// - these rest configs use exec auth, and +// - these rest configs are used to create distinct clientsets, then +// +// the underlying TLS config is shared between those clientsets. +func TestExecTLSCache(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + t.Cleanup(cancel) + + config1 := &rest.Config{ + Host: "https://localhost", + ExecProvider: &clientcmdapi.ExecConfig{ + Command: "./testdata/test-plugin.sh", + APIVersion: "client.authentication.k8s.io/v1", + InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode, + }, + } + client1 := clientset.NewForConfigOrDie(config1) + + config2 := &rest.Config{ + Host: "https://localhost", + ExecProvider: &clientcmdapi.ExecConfig{ + Command: "./testdata/test-plugin.sh", + APIVersion: "client.authentication.k8s.io/v1", + InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode, + }, + } + client2 := clientset.NewForConfigOrDie(config2) + + config3 := &rest.Config{ + Host: "https://localhost", + ExecProvider: &clientcmdapi.ExecConfig{ + Command: "./testdata/test-plugin.sh", + Args: []string{"make this exec auth different"}, + APIVersion: "client.authentication.k8s.io/v1", + InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode, + }, + } + client3 := clientset.NewForConfigOrDie(config3) + + _, _ = client1.CoreV1().Nodes().List(ctx, metav1.ListOptions{}) + _, _ = client2.CoreV1().Namespaces().List(ctx, metav1.ListOptions{}) + _, _ = client3.CoreV1().PersistentVolumes().List(ctx, metav1.ListOptions{}) + + rt1 := client1.RESTClient().(*rest.RESTClient).Client.Transport + rt2 := client2.RESTClient().(*rest.RESTClient).Client.Transport + rt3 := client3.RESTClient().(*rest.RESTClient).Client.Transport + + tlsConfig1, err := utilnet.TLSClientConfig(rt1) + if err != nil { + t.Fatal(err) + } + tlsConfig2, err := utilnet.TLSClientConfig(rt2) + if err != nil { + t.Fatal(err) + } + tlsConfig3, err := utilnet.TLSClientConfig(rt3) + if err != nil { + t.Fatal(err) + } + + if tlsConfig1 == nil || tlsConfig2 == nil || tlsConfig3 == nil { + t.Fatal("expected non-nil TLS configs") + } + + if tlsConfig1 != tlsConfig2 { + t.Fatal("expected the same TLS config for matching exec config via rest config") + } + + if tlsConfig1 == tlsConfig3 { + t.Fatal("expected different TLS config for non-matching exec config via rest config") + } +} diff --git a/transport/cache.go b/transport/cache.go index 5fe768ed5e..f4a864d053 100644 --- a/transport/cache.go +++ b/transport/cache.go @@ -17,6 +17,7 @@ limitations under the License. package transport import ( + "context" "fmt" "net" "net/http" @@ -50,6 +51,9 @@ type tlsCacheKey struct { serverName string nextProtos string disableCompression bool + // these functions are wrapped to allow them to be used as map keys + getCert *GetCertHolder + dial *DialHolder } func (t tlsCacheKey) String() string { @@ -57,7 +61,8 @@ func (t tlsCacheKey) String() string { if len(t.keyData) > 0 { keyText = "" } - return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression) + return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p", + t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial) } func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { @@ -87,8 +92,10 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { return http.DefaultTransport, nil } - dial := config.Dial - if dial == nil { + var dial func(ctx context.Context, network, address string) (net.Conn, error) + if config.Dial != nil { + dial = config.Dial + } else { dial = (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, @@ -133,10 +140,18 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { return tlsCacheKey{}, false, err } - if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil { + if c.Proxy != nil { // cannot determine equality for functions return tlsCacheKey{}, false, nil } + if c.Dial != nil && c.DialHolder == nil { + // cannot determine equality for dial function that doesn't have non-nil DialHolder set as well + return tlsCacheKey{}, false, nil + } + if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil { + // cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well + return tlsCacheKey{}, false, nil + } k := tlsCacheKey{ insecure: c.TLS.Insecure, @@ -144,6 +159,8 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { serverName: c.TLS.ServerName, nextProtos: strings.Join(c.TLS.NextProtos, ","), disableCompression: c.DisableCompression, + getCert: c.TLS.GetCertHolder, + dial: c.DialHolder, } if c.TLS.ReloadTLSFiles { diff --git a/transport/cache_test.go b/transport/cache_test.go index c6d06fcab3..87d070bb01 100644 --- a/transport/cache_test.go +++ b/transport/cache_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "net" "net/http" + "net/url" "testing" ) @@ -58,16 +59,24 @@ func TestTLSConfigKey(t *testing.T) { t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) continue } + if keyA != (tlsCacheKey{}) { + t.Errorf("Expected empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) + continue + } } } // Make sure config fields that affect the tls config affect the cache key dialer := net.Dialer{} getCert := func() (*tls.Certificate, error) { return nil, nil } + getCertHolder := &GetCertHolder{GetCert: getCert} uniqueConfigurations := map[string]*Config{ + "proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }}, "no tls": {}, "dialer": {Dial: dialer.DialContext}, "dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}, + "dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}}, + "dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}}, "insecure": {TLS: TLSConfig{Insecure: true}}, "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, @@ -128,6 +137,13 @@ func TestTLSConfigKey(t *testing.T) { GetCert: func() (*tls.Certificate, error) { return nil, nil }, }, }, + "getCert3": { + TLS: TLSConfig{ + KeyData: []byte{1}, + GetCert: getCert, + GetCertHolder: getCertHolder, + }, + }, "getCert1, key 2": { TLS: TLSConfig{ KeyData: []byte{2}, diff --git a/transport/config.go b/transport/config.go index 89de798f60..fd853c0b39 100644 --- a/transport/config.go +++ b/transport/config.go @@ -68,7 +68,11 @@ type Config struct { WrapTransport WrapperFunc // Dial specifies the dial function for creating unencrypted TCP connections. + // If specified, this transport will be non-cacheable unless DialHolder is also set. Dial func(ctx context.Context, network, address string) (net.Conn, error) + // DialHolder can be populated to make transport configs cacheable. + // If specified, DialHolder.Dial must be equal to Dial. + DialHolder *DialHolder // Proxy is the proxy func to be used for all requests made by this // transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy @@ -78,6 +82,11 @@ type Config struct { Proxy func(*http.Request) (*url.URL, error) } +// DialHolder is used to make the wrapped function comparable so that it can be used as a map key. +type DialHolder struct { + Dial func(ctx context.Context, network, address string) (net.Conn, error) +} + // ImpersonationConfig has all the available impersonation options type ImpersonationConfig struct { // UserName matches user.Info.GetName() @@ -143,5 +152,15 @@ type TLSConfig struct { // To use only http/1.1, set to ["http/1.1"]. NextProtos []string - GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field. + // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field. + // If specified, this transport is non-cacheable unless CertHolder is populated. + GetCert func() (*tls.Certificate, error) + // CertHolder can be populated to make transport configs that set GetCert cacheable. + // If set, CertHolder.GetCert must be equal to GetCert. + GetCertHolder *GetCertHolder +} + +// GetCertHolder is used to make the wrapped function comparable so that it can be used as a map key. +type GetCertHolder struct { + GetCert func() (*tls.Certificate, error) } diff --git a/transport/transport.go b/transport/transport.go index b4a7bfa67c..eabfce72d0 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -24,6 +24,7 @@ import ( "fmt" "io/ioutil" "net/http" + "reflect" "sync" "time" @@ -39,6 +40,10 @@ func New(config *Config) (http.RoundTripper, error) { return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed") } + if !isValidHolders(config) { + return nil, fmt.Errorf("misconfigured holder for dialer or cert callback") + } + var ( rt http.RoundTripper err error @@ -56,6 +61,26 @@ func New(config *Config) (http.RoundTripper, error) { return HTTPWrappersForConfig(config, rt) } +func isValidHolders(config *Config) bool { + if config.TLS.GetCertHolder != nil { + if config.TLS.GetCertHolder.GetCert == nil || + config.TLS.GetCert == nil || + reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() { + return false + } + } + + if config.DialHolder != nil { + if config.DialHolder.Dial == nil || + config.Dial == nil || + reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() { + return false + } + } + + return true +} + // TLSConfigFor returns a tls.Config that will provide the transport level security defined // by the provided Config. Will return nil if no transport level security is requested. func TLSConfigFor(c *Config) (*tls.Config, error) { diff --git a/transport/transport_test.go b/transport/transport_test.go index c439c96f81..e0fd2679a5 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "testing" ) @@ -94,6 +95,13 @@ stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa ) func TestNew(t *testing.T) { + globalGetCert := &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + } + globalDial := &DialHolder{ + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + } + testCases := map[string]struct { Config *Config Err bool @@ -255,6 +263,144 @@ func TestNew(t *testing.T) { }, }, }, + "nil holders and nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: nil, + DialHolder: nil, + }, + Err: false, + TLS: false, + TLSCert: false, + TLSErr: false, + Default: true, + Insecure: false, + DefaultRoots: false, + }, + "nil holders and non-nil regular get cert": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + GetCertHolder: nil, + }, + Dial: nil, + DialHolder: nil, + }, + Err: false, + TLS: true, + TLSCert: true, + TLSErr: false, + Default: false, + Insecure: false, + DefaultRoots: true, + }, + "nil holders and non-nil regular dial": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + DialHolder: nil, + }, + Err: false, + TLS: true, + TLSCert: false, + TLSErr: false, + Default: false, + Insecure: false, + DefaultRoots: true, + }, + "non-nil dial holder and nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: nil, + DialHolder: &DialHolder{}, + }, + Err: true, + }, + "non-nil cert holder and nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: &GetCertHolder{}, + }, + Dial: nil, + DialHolder: nil, + }, + Err: true, + }, + "non-nil dial holder and non-nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + DialHolder: &DialHolder{}, + }, + Err: true, + }, + "non-nil cert holder and non-nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + GetCertHolder: &GetCertHolder{}, + }, + Dial: nil, + DialHolder: nil, + }, + Err: true, + }, + "non-nil dial holder+internal and non-nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + DialHolder: &DialHolder{ + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + }, + }, + Err: true, + }, + "non-nil cert holder+internal and non-nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + GetCertHolder: &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + }, + }, + Dial: nil, + DialHolder: nil, + }, + Err: true, + }, + "non-nil holders+internal and non-nil regular with correct address": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: globalGetCert.GetCert, + GetCertHolder: globalGetCert, + }, + Dial: globalDial.Dial, + DialHolder: globalDial, + }, + Err: false, + TLS: true, + TLSCert: true, + TLSErr: false, + Default: false, + Insecure: false, + DefaultRoots: true, + }, } for k, testCase := range testCases { t.Run(k, func(t *testing.T) {