From 3826740db47f5da3c227212baa777ac45823f93a Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 30 Oct 2024 22:16:42 +0100 Subject: [PATCH] Address code review comments II Signed-off-by: Yacov Manevich --- network/peer/tls_config.go | 18 +++++-------- network/peer/tls_config_test.go | 45 ++++++++++++++++++--------------- network/peer/upgrader_test.go | 22 +++++----------- staking/parse_test.go | 2 +- 4 files changed, 37 insertions(+), 50 deletions(-) diff --git a/network/peer/tls_config.go b/network/peer/tls_config.go index f3bb986fa6a3..bce5aa1e7d95 100644 --- a/network/peer/tls_config.go +++ b/network/peer/tls_config.go @@ -11,17 +11,15 @@ import ( "errors" "io" - "golang.org/x/crypto/ed25519" - "github.com/ava-labs/avalanchego/staking" ) var ( - ErrNoCertsSent = errors.New("no certificates sent by peer") - ErrEmptyCert = errors.New("certificate sent by peer is empty") - ErrEmptyPublicKey = errors.New("no public key sent by peer") - ErrCurveMismatch = errors.New("only P256 is allowed for ECDSA") - ErrForbidden25519Key = errors.New("ed25519 is not allowed in this version") + ErrNoCertsSent = errors.New("no certificates sent by peer") + ErrEmptyCert = errors.New("certificate sent by peer is empty") + ErrEmptyPublicKey = errors.New("no public key sent by peer") + ErrCurveMismatch = errors.New("only P256 is allowed for ECDSA") + ErrUnsupportedKeyType = errors.New("key type is not supported") ) // TLSConfig returns the TLS config that will allow secure connections to other @@ -59,8 +57,6 @@ func ValidateCertificate(cs tls.ConnectionState) error { pk := cs.PeerCertificates[0].PublicKey switch key := pk.(type) { - case ed25519.PublicKey: - return ErrForbidden25519Key case *ecdsa.PublicKey: if key == nil { return ErrEmptyPublicKey @@ -71,9 +67,7 @@ func ValidateCertificate(cs tls.ConnectionState) error { return nil case *rsa.PublicKey: return staking.ValidateRSAPublicKeyIsWellFormed(key) - case nil: - return ErrEmptyPublicKey default: - return nil + return ErrUnsupportedKeyType } } diff --git a/network/peer/tls_config_test.go b/network/peer/tls_config_test.go index 0a7f6e69bb50..f9f27ebc6863 100644 --- a/network/peer/tls_config_test.go +++ b/network/peer/tls_config_test.go @@ -30,8 +30,8 @@ func TestValidateCertificate(t *testing.T) { input: func(t *testing.T) tls.ConnectionState { key, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) - x509Cert := makeRSACertAndKey(t, key) - return tls.ConnectionState{PeerCertificates: []*x509.Certificate{&x509Cert.cert}} + x509Cert := makeCert(t, key, &key.PublicKey) + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{x509Cert}} }, }, { @@ -54,9 +54,9 @@ func TestValidateCertificate(t *testing.T) { key, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) - x509CertWithNilPK := makeRSACertAndKey(t, key) - x509CertWithNilPK.cert.PublicKey = (*rsa.PublicKey)(nil) - return tls.ConnectionState{PeerCertificates: []*x509.Certificate{&x509CertWithNilPK.cert}} + x509CertWithNilPK := makeCert(t, key, &key.PublicKey) + x509CertWithNilPK.PublicKey = (*rsa.PublicKey)(nil) + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{x509CertWithNilPK}} }, expectedErr: staking.ErrInvalidRSAPublicKey, }, @@ -66,11 +66,11 @@ func TestValidateCertificate(t *testing.T) { key, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) - x509CertWithNoPK := makeRSACertAndKey(t, key) - x509CertWithNoPK.cert.PublicKey = nil - return tls.ConnectionState{PeerCertificates: []*x509.Certificate{&x509CertWithNoPK.cert}} + x509CertWithNoPK := makeCert(t, key, &key.PublicKey) + x509CertWithNoPK.PublicKey = nil + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{x509CertWithNoPK}} }, - expectedErr: peer.ErrEmptyPublicKey, + expectedErr: peer.ErrUnsupportedKeyType, }, { description: "EC cert", @@ -78,37 +78,40 @@ func TestValidateCertificate(t *testing.T) { ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) - basicCert := basicCert() - certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, &ecKey.PublicKey, ecKey) - require.NoError(t, err) + ecCert := makeCert(t, ecKey, &ecKey.PublicKey) - ecCert, err := x509.ParseCertificate(certBytes) require.NoError(t, err) return tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}} }, }, { description: "EC cert with empty key", - expectedErr: peer.ErrEmptyPublicKey, + expectedErr: peer.ErrUnsupportedKeyType, input: func(t *testing.T) tls.ConnectionState { ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) - basicCert := basicCert() - certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, &ecKey.PublicKey, ecKey) - require.NoError(t, err) + ecCert := makeCert(t, ecKey, &ecKey.PublicKey) + ecCert.PublicKey = nil - ecCert, err := x509.ParseCertificate(certBytes) + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}} + }, + }, + { + description: "EC cert with P384 curve", + expectedErr: peer.ErrCurveMismatch, + input: func(t *testing.T) tls.ConnectionState { + ecKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) require.NoError(t, err) - ecCert.PublicKey = nil + ecCert := makeCert(t, ecKey, &ecKey.PublicKey) return tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}} }, }, { - description: "EC cert with ed25519 key", - expectedErr: peer.ErrForbidden25519Key, + description: "EC cert with ed25519 key not supported", + expectedErr: peer.ErrUnsupportedKeyType, input: func(t *testing.T) tls.ConnectionState { pub, priv, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) diff --git a/network/peer/upgrader_test.go b/network/peer/upgrader_test.go index b2281400e4f7..48bc8e192cd2 100644 --- a/network/peer/upgrader_test.go +++ b/network/peer/upgrader_test.go @@ -91,7 +91,6 @@ func TestBlockClientsWithIncorrectRSAKeys(t *testing.T) { // Initialize upgrader with a mock that fails when it's incremented. failOnIncrementCounter := &mockPrometheusCounter{ Counter: c, - t: t, onIncrement: func() { require.FailNow(t, "should not have invoked") }, @@ -168,10 +167,10 @@ func nonStandardRSAKey(t *testing.T) *rsa.PrivateKey { } func makeTLSCert(t *testing.T, privKey *rsa.PrivateKey) tls.Certificate { - x509Cert := makeRSACertAndKey(t, privKey) + x509Cert := makeCert(t, privKey, &privKey.PublicKey) - rawX509PEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: x509Cert.cert.Raw}) - privateKeyInDER, err := x509.MarshalPKCS8PrivateKey(x509Cert.key) + rawX509PEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: x509Cert.Raw}) + privateKeyInDER, err := x509.MarshalPKCS8PrivateKey(privKey) require.NoError(t, err) privateKeyInPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyInDER}) @@ -182,24 +181,16 @@ func makeTLSCert(t *testing.T, privKey *rsa.PrivateKey) tls.Certificate { return tlsCertServer } -type certAndKey struct { - cert x509.Certificate - key *rsa.PrivateKey -} - -func makeRSACertAndKey(t *testing.T, privKey *rsa.PrivateKey) certAndKey { +func makeCert(t *testing.T, privateKey any, publicKey any) *x509.Certificate { // Create a self-signed cert basicCert := basicCert() - certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, &privKey.PublicKey, privKey) + certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, publicKey, privateKey) require.NoError(t, err) cert, err := x509.ParseCertificate(certBytes) require.NoError(t, err) - return certAndKey{ - cert: *cert, - key: privKey, - } + return cert } func basicCert() *x509.Certificate { @@ -212,7 +203,6 @@ func basicCert() *x509.Certificate { } type mockPrometheusCounter struct { - t *testing.T prometheus.Counter onIncrement func() } diff --git a/staking/parse_test.go b/staking/parse_test.go index 1e642feb400e..b0b032b1f6ea 100644 --- a/staking/parse_test.go +++ b/staking/parse_test.go @@ -52,6 +52,7 @@ func TestValidateRSAPublicKeyIsWellFormed(t *testing.T) { }, { description: "even modulus", + expectErr: ErrRSAModulusIsEven, getPK: func() rsa.PublicKey { evenN := new(big.Int).Set(validKey.N) evenN.Add(evenN, big.NewInt(1)) @@ -60,7 +61,6 @@ func TestValidateRSAPublicKeyIsWellFormed(t *testing.T) { E: 65537, } }, - expectErr: ErrRSAModulusIsEven, }, { description: "unsupported exponent",