diff --git a/network/peer/8192RSA_test.pem b/network/peer/8192RSA_test.pem new file mode 100644 index 000000000000..77d2ce8c7826 --- /dev/null +++ b/network/peer/8192RSA_test.pem @@ -0,0 +1,100 @@ +-----BEGIN ----- +MIISRAIBADANBgkqhkiG9w0BAQEFAASCEi4wghIqAgEAAoIEAQDBND9cycCvR+iY +uKrG58jvz5VcU+V4ff9AQ4jUXAUCcUTtxz/82BCSBS0SJZLPJq3HHMHFZ8Qnjrfo +yMYACZ5wnZvpyUdOMVqAka6k1IYLkD/+sQYN92c6GZIOxNciBrox804jRTGl9Frg +5AX4cmmGBXyhJVm0waxl2PAqdayPJ/cf+CSWn3XyzkE/irBHFLFny9BnH5h111lO +4AMpl8yojVWPqPVuj/S5JJ7z9Mnut6nMHEumy184+C952NhGMmlYhX4OTW/S/Ylp +b+vURMk44/iMCJ2XanSD85lEZOJKIq05e9vNdcIn23F//e//oD4zMpAfiwWOkcNP +tva8OLpAZcREbqHn/2skTkl25hLOtUubChF/Zfj5xAnbcD18iyIug/KryyQsA0HK +e2xzdBOkj5axRzOTIgg8bwz+eEoGLM0kmSxd5TQXA130bP7zhAyZ2Bb4pJsPC03/ +h8yc4lBhsagc1s2+tj/tdRk/fAwadwp2V1+hPI70dM13jt7xeNh6yyqNcP13Ky1Y +ajfYYsXSkSVOsovJmUyO7EqeHlP5bxUirljSG+RXqcbvdzvEaFTfM5iGMaHxn7IC +i8pJPUUCDseqsT0VFXD5UYc/G4rxSFPrecEQN14wHcMyUzO4r8CLjidBUJn6HQcQ +RiIP+uFSF70UAz5XvdmyqLeXS/ZhsMqLBH1qvsnd3cvGEStDjsujZf/u4VhOKe8f +SV9U2c7YNu22ShX1BnRzuZgDavnNr/Be+eyxwZ8z1XtGTqOMvL18kPMo9RGxaoKL +if+0yahFBGkY8TsmbB578Vxn0EYFx6neNw6H1HHpRS5osraG4iJTpBfgwVLyRLBC +aLB6mnsTie3bjUbIJCaFeOVPNrlyzNbuG22SJYCrOXIXzBNjh+GHNT//ubP/DWoz +wkVAHhooz315KMFiOyKfrlaoVa9IOFBEFJ2VmedA+j3VkJSfocBBGrdDGT0XijXX +ZCUSwn5EY7Ks5/gb3jGcpts028mTwkaLLRqL5e40ZF5FLmJHih7sbl6YyrUbg/Lv +ihe49CXO4DqWyCgq5+Zk1nq2eWalgsHpGchdMOPPkJHY/8HdAiBt/4trD12PstF3 +qy8rmzTh8CH+RzWTcnb5A0bwgVcTZErEbKcE938pfsFgenHQmAs2ryRIMmMnu2ck +GlrJzowYVmWFBPIIoT5BsytWYLluErs9p8FjhHhTA1z9AE56cyu5qwz46i4TAysF +TJ+LITvQLeVt1lYooXNfUXsdxRbBqKXb2kzkACWNhtVeofLsqFIxRM4z9F7X9r0h +skXG+42WZRrLK/3hW5NtQ2BX3WvEjBBlGwR4n6IcCJeX6plL0exiYMUuPqwogJBv +o3h/gCD5AgMBAAECggQASQ1EWAVBAgWigPxyNjs10tceloZyYZjihp4Cgqk4i6/g +bDfGjgf0XAHxBMeINyNc2ciZy9ZsaLih+TbRBvqcGeC+LyuX9ozat3peGpzxAjZM +vDSbIXTGZ0V74HG1FnyMso5YoSVsnF9EbXxKdaJtG+u/L/87aAlC8k+Qn71WvdpS +qpfc3cb1hhVOvoPmGzpLyf9akWN09jmy3wv8piFrlN+71lIAWwm7crXSFFQedlCj +tzWLtUl4e8X7zYqcXA57nqj6/NVyzshmyKM0/FH187jfJbOsQrBR1gKplR7AIV/z +N6UJeypne0KSK98MfA9O9XTM4eBi/YFH5EA+EvUwF2FjUKy0M1B0ZonjZT2hJt+N +8tVfwFgCSA5D2+EYnprNFeF2RFbPGoUwvyrj2tOtCa/xPp65dYyMqK0ksKMy+hq+ +hnQUPnyHsZvoTp9X1yO60ADQzrsOliWkHFZwm3FHC2ltM1pU+SNYEKUSItr4iJky +L4Th98k6FFyFxAsVaSBUWjmvoUNz0zdUMfYXn43ZVsDi5lrEWDnKpM/bduXowoup +5i8eDnPVZwAe5DSlOKJqVOrhZPwnS4EigavxlLfB/AEypevWOL6etOaKyOXVJ149 +vO+QfF0zE+ZtA/5JtC9gEmRxm1Sqo9ON9C1Qe9JUmAG50HNZgzuZsN/yaxah1lWl +0It7akaxJgeEl747FIsi8Q7+/6hrhKapEFrJjH/1KNy2yZc9S6e7N+hQkAO5iN0G +B3R/bVe+DFqK0iXas4Eac9/mnfqDNbPkxFX7OZkAFqDhbj+4j8JsVTTmCiPXiZQc +FvzHsF3WL+Gp+/iui80530Wds9hQ/z7rwIDsMRX/+TxzhkIHTl3BVZ1IbvWcf6Nh +Ieyp7Umr9+oaJr5sHuWwh+aa5vkIz+2l+wcDE9gfkxclWSxKhumb+SYH404JdORs +ZPm86sBmF1ePcOQ+T6oNVTPnM8pDWyO4E0bYsv/qso/ZEd0tIQiwQOoTyJ1kLe5X +5TdXYmV3oby2bp7FM3Tb1AX2e7LFOboxMImU9z+X9+orXy6IEky7QfE3w2SGzgmv +JrZ18vtpAVgFefqqj3PnmJlc8q6YbU31wHwnyiSZDSe0TEPr4hOPUOe/Qn7KJVy6 +fs+K6PWdcIDNZx9jgnTWPhkWYM0pGznAAy+V8PlCfQMc57I47HhnFzR4YGWWfox0 +K3TI7yUJJpa17TSTKUwsuFNTTahxC3439u1wj6xBdDf+vR0tq6svCmzBcJ15vS4r +yrm8J00OtPnlA1zoYlfWwDvNIFvQ5xgxNEcpnD2jUoe08FaYcZ6Xxz+92OhGNYPm +67d7P55mST60fa8fNWM82BW+tV5oo5+1n/uphjxQcQKCAgEA8EYIqh+ij5bL1+af +zf8BgB/ovdwKOma6MzHop1+fW+wRr92u7mYhkaI1bA+/X9FrC844Lb7oJ0w8E3gg +WMyw9DGHEqk+/FCiEfb/U8D9h0/zV8YTPlV1/sZmmwQgItoiaIZP6LOiNTevhjL6 +JNDM+Uh8Wd15AT9RA2vdPWdWR0DvEbRXPx2RkFpHoBDiI6+CuJ+vgpBbSSU7PTZ5 ++C8PTuFMh0V7fp16rufs8HNPt/CE/5z4KhVIArfgh6yF8VzW7Fz+RHL6Ao3n5khj +0/14nblKBEfeNpDcCMDRKPYxPi2iTvvMmk8ZZg1XsbdJQZbpUBfp66kDRlGuamlG +92Nqjl2+fPOJhmKQYls1oFL1Gp0RIQeYmJvVRwQlazo2a94QJN5A8MkhxXjLjHGz +IKKrwbJ4sRAS0GPs3hruRggDGoJpXMKWXa98cUp9KOYVvBnNoQLA6V8WDfqRCV58 +BY/ogpPGeJcoDE6h//SPL4vp2r2rGy9GGPnCbNZGfYuCOHlJvwcddsuM86ApSxOs +rsGMHMz85xtINlhIQgTyr13+EcZEkfB0j9Yv1mdlpTs5tuDDu09TGUEXAFUiD2Rs +WepKpT7rX+KoIuMjMwuUaLdgXP6OPvsPHqLc3oFKkMY4H0yHOq2bDse80jrEg++U +JbhKDABA0xUSmSCHxGNMTvD4x28CggIBAM3Zh2AwWiLsHCHqPY5/NWenLENeWDhs +Omirl5OkTRGYxrzvELDDnMmyh0QfbdkUn6aoXird8sPGHvjYIgKVxiK/q8Hbz1x+ +a9lqfLV2DBSQZGq3OvpiLubPcHCnJm3+MXL7foer49HatAK7qimcq8SKEMGXx2Fy +pzIHtoBm7JXZqrhr6o6K6GCpPm57YrkqEAz6bu+oCgHzmm4QN5aVcFr3/JxSZbdV +hccRBP87BkzyQrfgDqra2oxvKNV1KL7ByhdQxWq76/YkCI+U59U2rJM6UhRNPizG +/VGrhAahSuZkFkWhwKvg9QxBjixljRwOrFsJW5lmkzX26hnsyv0czE4O+nxtP4Ug +BMQ/QdrwQmjMNp95rYXhbCeirjuc3O3G58uHAgjx+UIopE4SNmRgqQb0tXM2DEN9 +8Ppx1aklcyBt+17qh2ToTkZuW6KK39eYXZi8ZtYl0C908vJ7RUQGlL+wwlvVswCv +KuQOXWl4Kv1oVKf8k/0lqtZh7LQwTGUPspndGyBxuN5f9UFvOxLpy5k6yA/J7Idp +k/0nzQei++UltQ//LiyMSb/JpA7GAEE68HoOhDzsablNoLaiIPrUfo7SLFUw7n/H +m2CiSF21KBEPuq5waXpe8aeYtbCcXjCZI34Swkt892mws01XY4AiJ5xUizv+gcJD +hADngf/XcioXAoICAQDH89BEG12CFyD+RCubF2sdP/DFB3fvkAvGjPMrTpVkvvkd +HOP1+0JWWuIQUq6VQ8bMpUn1L9ks0vFv1lk87OMZ5Jmeuv/yo/ur7ZwgDAwwbiV5 +VxoulppCcsNyn6VKu7NEvvmDEvKbTQMiMAwhVS4vCdaKRpfrpNB7g2kzL2sKkwwg +9K5ilO3NboQKveIjhmzHzgQWKKH/Jh+9Wjd4hVk88JtqOzWBcfZl1hZFKAEgduWH +fw66nsk1keYlojo5WWR2gREMz44lUAi7iGSjR134C/l/xHs1d6nVEvk9GFx0fS+E +gWGMzOS7G8Ft4LTzA26YO75sYlOaUmFOptvrBm3nmjXq8BTzo9S6NWNUT5UwF6Po +k9S2s4BywA2PxXsCm2Nd+yOZ/he/qT3jW7+RGi7LXAW6fEDb8TxuvYSq/QHwLrUV +/814m5B5C19LCObviZ2pL4xw6bOF4I6QeHPHgTIicG4LbudiDpIcWl5KWCo94fei +AN5Z7IeTYWJ6Gf49lxn7AiXP9acQG6ohk3byW5mJYkHY5chbiW5gmpOHwzWrfw8T +UEMAbGOVDqj1L2thOH1KxMHH03Ybzb0xiAXvcd261Li2K/52QgXJ9goEdw6XdTPV +T8MOYMRj2r696mdMDLjA6TaPv0LwxP1DOr5UAaCFijRoNTIsAnlZwrT/QOQXuwKC +AgEAxqtiF3izFa9Q+36KSIQHc/GJK7/bXyE9QhYR5aGV7BzJ+kC0mBVCtfuCx0Ga +D//ykbM/pxmsmjwVWk+mi14n6xOX3jKaMAenaR94Gt5CjHpLIB+VYV/vKj4co+z+ +jvvcl7+X/7Lq3ne4ckbS1PRrZvVldKJbAHbaXNPK1KQBRCLevL0SlN4FpnzRT2nv +/wtUkGIHPW+tsPJ+IimurLuvw2xBtlFj8AwvX8/SRc6epxbNQ4+QOF+evBjwjQtU +9r4roFMJJZkXA+kFBiZNlZ798d5Ap21hS3AFvpPNiWST2EXSpQOW44vqlRiT8c9U +4DZdLEOczzGLdHLIv5qk0qK/n7qfEAWUX5RmZU0z7u0g+unU8hdKXMMSUjKU+93J +8AafYfP8B8wZqDt3UA4NxtTvbVIx6W7JaT4cnGnPLz+AnFTpXVL2t3HpUdpiwD5O +CVL5SlbS3W2DProdW9+TGzNKzrL28hEOgOOOfqpKh2c9/nJ5+eMwpQp8lgnOnJ1c +rdD3q74U1zxKkvyDxNJobjmMkWeE/JACozJHbPXD0NIBUMgStsyusLn414vxtXxt +dIdA3lwyTmZRJ1F/gaR6Nftt5cN8m//svxBTqnEVbLNRZx4KKx88/aiyi/E7sadI +1JiIA75xHNAQLUYn1sY3tsu/9QY3lwBsFaR5uzG0aspxWaMCggIBAM10S3/g8Zbi +DGX70XlgNWBQMtlKhQWK64I3VdP79pJ8LZu2mPnOmwV6oa+MAW2pKWJ96ZuARj1k +jOSzi5022qfBQOx2zD1jDOWZa2FMz7gyJWiLKSBG3b0UVWa+2OxxIkIK2pEWYURX +qvbVy/6Kh34xupzHYAvyYraac2NKIpZHmxVxYZjazm7Ot6bn7oXG09D11oXHq4/r +u9hGCkOVD9pKu7it/8IQMyNNCm8Sw8ShYLrA6PYtGOqV6ByuUp7EgcJzbPNGHlXX +p7fwIsxInWhZXqFz8ARji+za4G65vr+tjQzBMGL6V/QWNzM8CRgXQJwjr50bk79t +U64t/I/bHTnQiEUMqkdE4ly1A3QxttUaCN6s63Rj+Pfy7cJZZOehSFn0YAH16BYH +NUrjQxhKy3U5UdWGxp9V8wZ0YW9ThGJM8g/n3PZjhQK/LilwUkDk4eGwMspyh09z +3Azsfvl+nsiNBu3ft3dLggX729cWKkGg1Kdv7OTcmxLTipkYESXH708SsGlJhqSq +YW9sl2Ankve2apwrLTLQUE06o4KF5F9KP2MC3uXiwYAmoHuyYMnZ9fcfH2P1zbQL +czHlpnSg1/TDbJfAU9E9LauPuha63KrqFaPNOpjDFNyotPxbe+Oy35B0UX5ZVamN +Jk2UczPy2rNK8CZ6iIuChYZ5f3gujYHb +-----END ----- diff --git a/network/peer/tls_config.go b/network/peer/tls_config.go index 9673b98dc8f1..bce5aa1e7d95 100644 --- a/network/peer/tls_config.go +++ b/network/peer/tls_config.go @@ -4,8 +4,22 @@ package peer import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" "crypto/tls" + "errors" "io" + + "github.com/ava-labs/avalanchego/staking" +) + +var ( + ErrNoCertsSent = errors.New("no certificates sent by peer") + ErrEmptyCert = errors.New("certificate sent by peer is empty") + ErrEmptyPublicKey = errors.New("no public key sent by peer") + ErrCurveMismatch = errors.New("only P256 is allowed for ECDSA") + ErrUnsupportedKeyType = errors.New("key type is not supported") ) // TLSConfig returns the TLS config that will allow secure connections to other @@ -26,5 +40,34 @@ func TLSConfig(cert tls.Certificate, keyLogWriter io.Writer) *tls.Config { InsecureSkipVerify: true, //#nosec G402 MinVersion: tls.VersionTLS13, KeyLogWriter: keyLogWriter, + VerifyConnection: ValidateCertificate, + } +} + +// ValidateCertificate validates TLS certificates according their public keys on the leaf certificate in the certification chain. +func ValidateCertificate(cs tls.ConnectionState) error { + if len(cs.PeerCertificates) == 0 { + return ErrNoCertsSent + } + + if cs.PeerCertificates[0] == nil { + return ErrEmptyCert + } + + pk := cs.PeerCertificates[0].PublicKey + + switch key := pk.(type) { + case *ecdsa.PublicKey: + if key == nil { + return ErrEmptyPublicKey + } + if key.Curve != elliptic.P256() { + return ErrCurveMismatch + } + return nil + case *rsa.PublicKey: + return staking.ValidateRSAPublicKeyIsWellFormed(key) + default: + return ErrUnsupportedKeyType } } diff --git a/network/peer/tls_config_test.go b/network/peer/tls_config_test.go new file mode 100644 index 000000000000..41cf6a4dca8d --- /dev/null +++ b/network/peer/tls_config_test.go @@ -0,0 +1,134 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package peer_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ed25519" + + "github.com/ava-labs/avalanchego/network/peer" + "github.com/ava-labs/avalanchego/staking" +) + +func TestValidateCertificate(t *testing.T) { + for _, testCase := range []struct { + description string + input func(t *testing.T) tls.ConnectionState + expectedErr error + }{ + { + description: "Valid TLS cert", + input: func(t *testing.T) tls.ConnectionState { + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + x509Cert := makeCert(t, key, &key.PublicKey) + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{x509Cert}} + }, + }, + { + description: "No TLS certs given", + input: func(*testing.T) tls.ConnectionState { + return tls.ConnectionState{} + }, + expectedErr: peer.ErrNoCertsSent, + }, + { + description: "Empty certificate given by peer", + input: func(*testing.T) tls.ConnectionState { + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{nil}} + }, + expectedErr: peer.ErrEmptyCert, + }, + { + description: "nil RSA key", + input: func(t *testing.T) tls.ConnectionState { + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + x509CertWithNilPK := makeCert(t, key, &key.PublicKey) + x509CertWithNilPK.PublicKey = (*rsa.PublicKey)(nil) + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{x509CertWithNilPK}} + }, + expectedErr: staking.ErrInvalidRSAPublicKey, + }, + { + description: "No public key in the cert given", + input: func(t *testing.T) tls.ConnectionState { + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + x509CertWithNoPK := makeCert(t, key, &key.PublicKey) + x509CertWithNoPK.PublicKey = nil + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{x509CertWithNoPK}} + }, + expectedErr: peer.ErrUnsupportedKeyType, + }, + { + description: "EC cert", + input: func(t *testing.T) tls.ConnectionState { + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + ecCert := makeCert(t, ecKey, &ecKey.PublicKey) + + require.NoError(t, err) + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}} + }, + }, + { + description: "EC cert with empty key", + expectedErr: peer.ErrEmptyPublicKey, + input: func(t *testing.T) tls.ConnectionState { + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + ecCert := makeCert(t, ecKey, &ecKey.PublicKey) + ecCert.PublicKey = (*ecdsa.PublicKey)(nil) + + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}} + }, + }, + { + description: "EC cert with P384 curve", + expectedErr: peer.ErrCurveMismatch, + input: func(t *testing.T) tls.ConnectionState { + ecKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(t, err) + + ecCert := makeCert(t, ecKey, &ecKey.PublicKey) + + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}} + }, + }, + { + description: "EC cert with ed25519 key not supported", + expectedErr: peer.ErrUnsupportedKeyType, + input: func(t *testing.T) tls.ConnectionState { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + basicCert := basicCert() + certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, pub, priv) + require.NoError(t, err) + + ecCert, err := x509.ParseCertificate(certBytes) + require.NoError(t, err) + + return tls.ConnectionState{PeerCertificates: []*x509.Certificate{ecCert}} + }, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + require.Equal(t, testCase.expectedErr, peer.ValidateCertificate(testCase.input(t))) + }) + } +} diff --git a/network/peer/upgrader_test.go b/network/peer/upgrader_test.go new file mode 100644 index 000000000000..48bc8e192cd2 --- /dev/null +++ b/network/peer/upgrader_test.go @@ -0,0 +1,212 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package peer_test + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" + "net" + "sync" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + + _ "embed" + + "github.com/ava-labs/avalanchego/network/peer" + "github.com/ava-labs/avalanchego/staking" +) + +// 8192RSA_test.pem is used here because it's too expensive +// to generate an 8K bit RSA key under the time constraint of the weak Github CI runners. + +//go:embed 8192RSA_test.pem +var fat8192BitRSAKey []byte + +func TestBlockClientsWithIncorrectRSAKeys(t *testing.T) { + for _, testCase := range []struct { + description string + genClientTLSCert func() tls.Certificate + expectedErr error + }{ + { + description: "Proper key size and private key - 2048", + genClientTLSCert: func() tls.Certificate { + privKey2048, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + clientCert2048 := makeTLSCert(t, privKey2048) + return clientCert2048 + }, + }, + { + description: "Proper key size and private key - 4096", + genClientTLSCert: func() tls.Certificate { + privKey4096, err := rsa.GenerateKey(rand.Reader, 4096) + require.NoError(t, err) + clientCert4096 := makeTLSCert(t, privKey4096) + return clientCert4096 + }, + }, + { + description: "Too big key", + genClientTLSCert: func() tls.Certificate { + block, _ := pem.Decode(fat8192BitRSAKey) + require.NotNil(t, block) + rsaFatKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + require.NoError(t, err) + privKey8192 := rsaFatKey.(*rsa.PrivateKey) + // Sanity check - ensure privKey8192 is indeed an 8192 RSA key + require.Equal(t, 8192, privKey8192.N.BitLen()) + clientCert8192 := makeTLSCert(t, privKey8192) + return clientCert8192 + }, + expectedErr: staking.ErrUnsupportedRSAModulusBitLen, + }, + { + description: "Improper public exponent", + genClientTLSCert: func() tls.Certificate { + clientCertBad := makeTLSCert(t, nonStandardRSAKey(t)) + return clientCertBad + }, + expectedErr: staking.ErrUnsupportedRSAPublicExponent, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + serverKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + serverCert := makeTLSCert(t, serverKey) + + config := peer.TLSConfig(serverCert, nil) + + c := prometheus.NewCounter(prometheus.CounterOpts{}) + + // Initialize upgrader with a mock that fails when it's incremented. + failOnIncrementCounter := &mockPrometheusCounter{ + Counter: c, + onIncrement: func() { + require.FailNow(t, "should not have invoked") + }, + } + upgrader := peer.NewTLSServerUpgrader(config, failOnIncrementCounter) + + clientConfig := tls.Config{ + ClientAuth: tls.RequireAnyClientCert, + InsecureSkipVerify: true, //#nosec G402 + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{testCase.genClientTLSCert()}, + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + conn, err := listener.Accept() + require.NoError(t, err) + + _, _, _, err = upgrader.Upgrade(conn) + + require.ErrorIs(t, err, testCase.expectedErr) + }() + + conn, err := tls.Dial("tcp", listener.Addr().String(), &clientConfig) + require.NoError(t, err) + + require.NoError(t, conn.Handshake()) + + wg.Wait() + }) + } +} + +func nonStandardRSAKey(t *testing.T) *rsa.PrivateKey { + for { + sk, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + // This speeds up RSA operations, and was initialized during the key-gen. + // If we wish to override the key parameters we need to nullify this, + // otherwise the signer will use these values and the verifier will use + // the values we override, and verification will fail. + sk.Precomputed = rsa.PrecomputedValues{} + + // We want a non-standard E, so let's use E = 257 and derive D again. + e := 257 + sk.PublicKey.E = e + sk.E = e + + p := sk.Primes[0] + q := sk.Primes[1] + + pminus1 := new(big.Int).Sub(p, big.NewInt(1)) + qminus1 := new(big.Int).Sub(q, big.NewInt(1)) + + phiN := big.NewInt(0).Mul(pminus1, qminus1) + + sk.D = big.NewInt(0).ModInverse(big.NewInt(int64(e)), phiN) + + if sk.D == nil { + // If we ended up picking a bad starting modulus, try again. + continue + } + + return sk + } +} + +func makeTLSCert(t *testing.T, privKey *rsa.PrivateKey) tls.Certificate { + x509Cert := makeCert(t, privKey, &privKey.PublicKey) + + rawX509PEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: x509Cert.Raw}) + privateKeyInDER, err := x509.MarshalPKCS8PrivateKey(privKey) + require.NoError(t, err) + + privateKeyInPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyInDER}) + + tlsCertServer, err := tls.X509KeyPair(rawX509PEM, privateKeyInPEM) + require.NoError(t, err) + + return tlsCertServer +} + +func makeCert(t *testing.T, privateKey any, publicKey any) *x509.Certificate { + // Create a self-signed cert + basicCert := basicCert() + certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, publicKey, privateKey) + require.NoError(t, err) + + cert, err := x509.ParseCertificate(certBytes) + require.NoError(t, err) + + return cert +} + +func basicCert() *x509.Certificate { + return &x509.Certificate{ + SerialNumber: big.NewInt(0).SetInt64(100), + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour).UTC(), + BasicConstraintsValid: true, + } +} + +type mockPrometheusCounter struct { + prometheus.Counter + onIncrement func() +} + +func (m *mockPrometheusCounter) Inc() { + m.onIncrement() +} diff --git a/staking/parse.go b/staking/parse.go index 28ae02cbc644..6a544ea6332d 100644 --- a/staking/parse.go +++ b/staking/parse.go @@ -136,18 +136,8 @@ func parsePublicKey(oid asn1.ObjectIdentifier, publicKey asn1.BitString) (crypto return nil, ErrInvalidRSAPublicExponent } - if pub.N.Sign() <= 0 { - return nil, ErrRSAModulusNotPositive - } - - if bitLen := pub.N.BitLen(); bitLen != allowedRSALargeModulusLen && bitLen != allowedRSASmallModulusLen { - return nil, fmt.Errorf("%w: %d", ErrUnsupportedRSAModulusBitLen, bitLen) - } - if pub.N.Bit(0) == 0 { - return nil, ErrRSAModulusIsEven - } - if pub.E != allowedRSAPublicExponentValue { - return nil, fmt.Errorf("%w: %d", ErrUnsupportedRSAPublicExponent, pub.E) + if err := ValidateRSAPublicKeyIsWellFormed(pub); err != nil { + return nil, err } return pub, nil case oid.Equal(oidPublicKeyECDSA): @@ -165,3 +155,23 @@ func parsePublicKey(oid asn1.ObjectIdentifier, publicKey asn1.BitString) (crypto return nil, ErrUnknownPublicKeyAlgorithm } } + +// ValidateRSAPublicKeyIsWellFormed validates the given RSA public key +func ValidateRSAPublicKeyIsWellFormed(pub *rsa.PublicKey) error { + if pub == nil { + return ErrInvalidRSAPublicKey + } + if pub.N.Sign() <= 0 { + return ErrRSAModulusNotPositive + } + if bitLen := pub.N.BitLen(); bitLen != allowedRSALargeModulusLen && bitLen != allowedRSASmallModulusLen { + return fmt.Errorf("%w: %d", ErrUnsupportedRSAModulusBitLen, bitLen) + } + if pub.N.Bit(0) == 0 { + return ErrRSAModulusIsEven + } + if pub.E != allowedRSAPublicExponentValue { + return fmt.Errorf("%w: %d", ErrUnsupportedRSAPublicExponent, pub.E) + } + return nil +} diff --git a/staking/parse_test.go b/staking/parse_test.go index 41704e3b71dc..b0b032b1f6ea 100644 --- a/staking/parse_test.go +++ b/staking/parse_test.go @@ -4,6 +4,9 @@ package staking import ( + "crypto/rand" + "crypto/rsa" + "math/big" "testing" "github.com/stretchr/testify/require" @@ -31,3 +34,73 @@ func BenchmarkParse(b *testing.B) { require.NoError(b, err) } } + +func TestValidateRSAPublicKeyIsWellFormed(t *testing.T) { + validKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + for _, testCase := range []struct { + description string + expectErr error + getPK func() rsa.PublicKey + }{ + { + description: "valid public key", + getPK: func() rsa.PublicKey { + return validKey.PublicKey + }, + }, + { + description: "even modulus", + expectErr: ErrRSAModulusIsEven, + getPK: func() rsa.PublicKey { + evenN := new(big.Int).Set(validKey.N) + evenN.Add(evenN, big.NewInt(1)) + return rsa.PublicKey{ + N: evenN, + E: 65537, + } + }, + }, + { + description: "unsupported exponent", + expectErr: ErrUnsupportedRSAPublicExponent, + getPK: func() rsa.PublicKey { + return rsa.PublicKey{ + N: validKey.N, + E: 3, + } + }, + }, + { + description: "unsupported modulus bit len", + expectErr: ErrUnsupportedRSAModulusBitLen, + getPK: func() rsa.PublicKey { + bigMod := new(big.Int).Set(validKey.N) + bigMod.Lsh(bigMod, 2049) + return rsa.PublicKey{ + N: bigMod, + E: 65537, + } + }, + }, + { + description: "not positive modulus", + expectErr: ErrRSAModulusNotPositive, + getPK: func() rsa.PublicKey { + minusN := new(big.Int).Set(validKey.N) + minusN.Neg(minusN) + return rsa.PublicKey{ + N: minusN, + E: 65537, + } + }, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + pk := testCase.getPK() + err := ValidateRSAPublicKeyIsWellFormed(&pk) + require.ErrorIs(t, err, testCase.expectErr) + }) + } +}