diff --git a/storage/config.go b/storage/config.go index c1c1b95..3c3745a 100644 --- a/storage/config.go +++ b/storage/config.go @@ -52,9 +52,15 @@ type Config struct { // inputs is accelerated. The size of the cache is large so understand how to configure the cache. // TODO provide some default config choices // If this section is skipped, Caching is disabled - Cache CachingConfig `json:"cache"` - Limits LimitsConfig `json:"limits" pflag:",Sets limits for stores."` - DefaultHTTPClientHeaders map[string][]string `json:"defaultHttpClientHeaders" pflag:"-,Sets http headers to set on the default http client."` + Cache CachingConfig `json:"cache"` + Limits LimitsConfig `json:"limits" pflag:",Sets limits for stores."` + DefaultHTTPClient *HTTPClientConfig `json:"defaultHttpClient" pflag:",Sets the default http client config."` +} + +// HTTPClientConfig encapsulates common settings that can be applied to an HTTP Client. +type HTTPClientConfig struct { + Headers map[string][]string `json:"headers" pflag:"-,Sets http headers to set on the http client."` + Timeout config.Duration `json:"timeout" pflag:"timeout,Sets time out on the http client."` } // Defines connection configurations. diff --git a/storage/config_flags.go b/storage/config_flags.go index 50f634e..91dcf4d 100755 --- a/storage/config_flags.go +++ b/storage/config_flags.go @@ -52,5 +52,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.max_size_mbs"), defaultConfig.Cache.MaxSizeMegabytes, "Maximum size of the cache where the Blob store data is cached in-memory. If not specified or set to 0, cache is not used") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "cache.target_gc_percent"), defaultConfig.Cache.TargetGCPercent, "Sets the garbage collection target percentage.") cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "limits.maxDownloadMBs"), defaultConfig.Limits.GetLimitMegabytes, "Maximum allowed download size (in MBs) per call.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultHttpClient.timeout"), defaultConfig.DefaultHTTPClient.Timeout.String(), "Sets time out on the http client.") return cmdFlags } diff --git a/storage/config_flags_test.go b/storage/config_flags_test.go index 4809b6b..a47ea88 100755 --- a/storage/config_flags_test.go +++ b/storage/config_flags_test.go @@ -341,4 +341,26 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_defaultHttpClient.timeout", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("defaultHttpClient.timeout"); err == nil { + assert.Equal(t, string(defaultConfig.DefaultHTTPClient.Timeout.String()), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.DefaultHTTPClient.Timeout.String() + + cmdFlags.Set("defaultHttpClient.timeout", testValue) + if vString, err := cmdFlags.GetString("defaultHttpClient.timeout"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DefaultHTTPClient.Timeout) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/storage/rawstores.go b/storage/rawstores.go index 5f17457..d65107a 100644 --- a/storage/rawstores.go +++ b/storage/rawstores.go @@ -39,11 +39,18 @@ func applyDefaultHeaders(r *http.Request, headers map[string][]string) { } } -func createHTTPClientWithDefaultHeaders(headers map[string][]string) *http.Client { - c := &http.Client{} +func createHTTPClient(cfg *HTTPClientConfig) *http.Client { + if cfg == nil { + return &http.Client{} + } + + c := &http.Client{ + Timeout: cfg.Timeout.Duration, + } + c.Transport = &proxyTransport{ RoundTripper: http.DefaultTransport, - defaultHeaders: headers, + defaultHeaders: cfg.Headers, } return c @@ -51,16 +58,13 @@ func createHTTPClientWithDefaultHeaders(headers map[string][]string) *http.Clien // Creates a new Data Store with the supplied config. func NewDataStore(cfg *Config, metricsScope promutils.Scope) (s *DataStore, err error) { - // HACK: This sets http headers to the default http client. This is because - // some underlying stores (e.g. S3 Stow Store) grabs the default http client - // and doesn't allow configuration of default headers. - if len(cfg.DefaultHTTPClientHeaders) > 0 { + if cfg.DefaultHTTPClient != nil { defaultClient := http.DefaultClient defer func() { http.DefaultClient = defaultClient }() - http.DefaultClient = createHTTPClientWithDefaultHeaders(cfg.DefaultHTTPClientHeaders) + http.DefaultClient = createHTTPClient(cfg.DefaultHTTPClient) } var rawStore RawStore diff --git a/storage/rawstores_test.go b/storage/rawstores_test.go index 5968098..b570f08 100644 --- a/storage/rawstores_test.go +++ b/storage/rawstores_test.go @@ -3,29 +3,49 @@ package storage import ( "net/http" "testing" + "time" + + "github.com/lyft/flytestdlib/config" "github.com/stretchr/testify/assert" ) -func Test_createHttpClientWithDefaultHeaders(t *testing.T) { +func Test_createHTTPClient(t *testing.T) { t.Run("nil", func(t *testing.T) { - client := createHTTPClientWithDefaultHeaders(nil) - assert.NotNil(t, client.Transport) - proxyTransport, casted := client.Transport.(*proxyTransport) - assert.True(t, casted) - assert.Nil(t, proxyTransport.defaultHeaders) + client := createHTTPClient(nil) + assert.Nil(t, client.Transport) }) t.Run("Some headers", func(t *testing.T) { m := map[string][]string{ "Header1": {"val1", "val2"}, } - client := createHTTPClientWithDefaultHeaders(m) + + client := createHTTPClient(&HTTPClientConfig{ + Headers: m, + }) + assert.NotNil(t, client.Transport) proxyTransport, casted := client.Transport.(*proxyTransport) assert.True(t, casted) assert.Equal(t, m, proxyTransport.defaultHeaders) }) + + t.Run("Set empty timeout", func(t *testing.T) { + client := createHTTPClient(&HTTPClientConfig{ + Timeout: config.Duration{}, + }) + + assert.Zero(t, client.Timeout) + }) + + t.Run("Set timeout", func(t *testing.T) { + client := createHTTPClient(&HTTPClientConfig{ + Timeout: config.Duration{Duration: 2 * time.Second}, + }) + + assert.Equal(t, 2*time.Second, client.Timeout) + }) } func Test_applyDefaultHeaders(t *testing.T) {