Skip to content

Commit

Permalink
TLS config: Enable selection of min TLS version (#375)
Browse files Browse the repository at this point in the history
* TLS config: Enable selection of min TLS version

go1.18 changes the default minimum TLS version to 1.2.

Let's make the default minimum version configurable, while following go
default.

The allowed values (TLS10, ..) come from the exporter-toolkit:
https://github.com/prometheus/exporter-toolkit/blob/master/docs/web-configuration.md

TLSVersion is exported so the exporter toolkit can reuse them later.

Signed-off-by: Julien Pivotto <[email protected]>
  • Loading branch information
roidelapluie authored Apr 19, 2022
1 parent 0c7319a commit 3763a1d
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 26 deletions.
101 changes: 81 additions & 20 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,87 @@ import (
"gopkg.in/yaml.v2"
)

// DefaultHTTPClientConfig is the default HTTP client configuration.
var DefaultHTTPClientConfig = HTTPClientConfig{
FollowRedirects: true,
EnableHTTP2: true,
}
var (
// DefaultHTTPClientConfig is the default HTTP client configuration.
DefaultHTTPClientConfig = HTTPClientConfig{
FollowRedirects: true,
EnableHTTP2: true,
}

// defaultHTTPClientOptions holds the default HTTP client options.
var defaultHTTPClientOptions = httpClientOptions{
keepAlivesEnabled: true,
http2Enabled: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
idleConnTimeout: 5 * time.Minute,
}
// defaultHTTPClientOptions holds the default HTTP client options.
defaultHTTPClientOptions = httpClientOptions{
keepAlivesEnabled: true,
http2Enabled: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
idleConnTimeout: 5 * time.Minute,
}
)

type closeIdler interface {
CloseIdleConnections()
}

type TLSVersion uint16

var TLSVersions = map[string]TLSVersion{
"TLS13": (TLSVersion)(tls.VersionTLS13),
"TLS12": (TLSVersion)(tls.VersionTLS12),
"TLS11": (TLSVersion)(tls.VersionTLS11),
"TLS10": (TLSVersion)(tls.VersionTLS10),
}

func (tv *TLSVersion) UnmarshalYAML(unmarshal func(interface{}) error) error {
var s string
err := unmarshal((*string)(&s))
if err != nil {
return err
}
if v, ok := TLSVersions[s]; ok {
*tv = v
return nil
}
return fmt.Errorf("unknown TLS version: %s", s)
}

func (tv *TLSVersion) MarshalYAML() (interface{}, error) {
if tv != nil || *tv == 0 {
return []byte("null"), nil
}
for s, v := range TLSVersions {
if *tv == v {
return s, nil
}
}
return nil, fmt.Errorf("unknown TLS version: %d", tv)
}

// MarshalJSON implements the json.Unmarshaler interface for TLSVersion.
func (tv *TLSVersion) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
if v, ok := TLSVersions[s]; ok {
*tv = v
return nil
}
return fmt.Errorf("unknown TLS version: %s", s)
}

// MarshalJSON implements the json.Marshaler interface for TLSVersion.
func (tv *TLSVersion) MarshalJSON() ([]byte, error) {
if tv != nil || *tv == 0 {
return []byte("null"), nil
}
for s, v := range TLSVersions {
if *tv == v {
return []byte(s), nil
}
}
return nil, fmt.Errorf("unknown TLS version: %d", tv)
}

// BasicAuth contains basic HTTP authentication credentials.
type BasicAuth struct {
Username string `yaml:"username" json:"username"`
Expand Down Expand Up @@ -669,7 +731,10 @@ func cloneRequest(r *http.Request) *http.Request {

// NewTLSConfig creates a new tls.Config from the given TLSConfig.
func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
tlsConfig := &tls.Config{InsecureSkipVerify: cfg.InsecureSkipVerify}
tlsConfig := &tls.Config{
InsecureSkipVerify: cfg.InsecureSkipVerify,
MinVersion: uint16(cfg.MinVersion),
}

// If a CA cert is provided then let's read it in so we can validate the
// scrape target's certificate properly.
Expand Down Expand Up @@ -714,6 +779,8 @@ type TLSConfig struct {
ServerName string `yaml:"server_name,omitempty" json:"server_name,omitempty"`
// Disable target certificate validation.
InsecureSkipVerify bool `yaml:"insecure_skip_verify" json:"insecure_skip_verify"`
// Minimum TLS version.
MinVersion TLSVersion `yaml:"min_version,omitempty" json:"min_version,omitempty"`
}

// SetDirectory joins any relative file paths with dir.
Expand All @@ -726,12 +793,6 @@ func (c *TLSConfig) SetDirectory(dir string) {
c.KeyFile = JoinDir(dir, c.KeyFile)
}

// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain TLSConfig
return unmarshal((*plain)(c))
}

// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
func (c *TLSConfig) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
Expand Down
6 changes: 4 additions & 2 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ func TestTLSConfig(t *testing.T) {
CertFile: ClientCertificatePath,
KeyFile: ClientKeyNoPassPath,
ServerName: "localhost",
InsecureSkipVerify: false}
InsecureSkipVerify: false,
}

tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath)
if err != nil {
Expand All @@ -640,7 +641,8 @@ func TestTLSConfig(t *testing.T) {
expectedTLSConfig := &tls.Config{
RootCAs: rootCAs,
ServerName: configTLSConfig.ServerName,
InsecureSkipVerify: configTLSConfig.InsecureSkipVerify}
InsecureSkipVerify: configTLSConfig.InsecureSkipVerify,
}

tlsConfig, err := NewTLSConfig(&configTLSConfig)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions config/testdata/tls_config.empty.good.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions config/testdata/tls_config.insecure.good.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"insecure_skip_verify": true}
1 change: 1 addition & 0 deletions config/testdata/tls_config.tlsversion.good.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"min_version": "TLS11"}
1 change: 1 addition & 0 deletions config/testdata/tls_config.tlsversion.good.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
min_version: TLS11
37 changes: 33 additions & 4 deletions config/tls_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,39 @@
package config

import (
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
"testing"

"encoding/json"

"gopkg.in/yaml.v2"
)

// LoadTLSConfig parses the given YAML file into a tls.Config.
// LoadTLSConfig parses the given file into a tls.Config.
func LoadTLSConfig(filename string) (*tls.Config, error) {
content, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
cfg := TLSConfig{}
if err = yaml.UnmarshalStrict(content, &cfg); err != nil {
return nil, err
switch filepath.Ext(filename) {
case ".yml":
if err = yaml.UnmarshalStrict(content, &cfg); err != nil {
return nil, err
}
case ".json":
decoder := json.NewDecoder(bytes.NewReader(content))
decoder.DisallowUnknownFields()
if err = decoder.Decode(&cfg); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("Unknown extension: %s", filepath.Ext(filename))
}
return NewTLSConfig(&cfg)
}
Expand All @@ -39,20 +55,33 @@ var expectedTLSConfigs = []struct {
filename string
config *tls.Config
}{
{
filename: "tls_config.empty.good.json",
config: &tls.Config{},
}, {
filename: "tls_config.insecure.good.json",
config: &tls.Config{InsecureSkipVerify: true},
}, {
filename: "tls_config.tlsversion.good.json",
config: &tls.Config{MinVersion: tls.VersionTLS11},
},
{
filename: "tls_config.empty.good.yml",
config: &tls.Config{},
}, {
filename: "tls_config.insecure.good.yml",
config: &tls.Config{InsecureSkipVerify: true},
}, {
filename: "tls_config.tlsversion.good.yml",
config: &tls.Config{MinVersion: tls.VersionTLS11},
},
}

func TestValidTLSConfig(t *testing.T) {
for _, cfg := range expectedTLSConfigs {
got, err := LoadTLSConfig("testdata/" + cfg.filename)
if err != nil {
t.Errorf("Error parsing %s: %s", cfg.filename, err)
t.Fatalf("Error parsing %s: %s", cfg.filename, err)
}
// non-nil functions are never equal.
got.GetClientCertificate = nil
Expand Down

0 comments on commit 3763a1d

Please sign in to comment.