From e1f180a119110fec61c2ab577b3eca457ae8c56b Mon Sep 17 00:00:00 2001 From: Zale Young Date: Tue, 14 Jan 2025 10:22:55 -0800 Subject: [PATCH] add context check during state machine initialization in client Summary: Add a check during state machine initialization to ensure that the FizzClientContext is compatible with the Factory. Only checking keyshares and ciphers for now. Sigschemes are a bit more complicated. Reviewed By: mingtaoy Differential Revision: D65295680 fbshipit-source-id: 1ed244a44b23ef964bbf3991beded1cc3cd6832f --- fizz/client/ClientProtocol.cpp | 8 ++ fizz/client/FizzClientContext.cpp | 23 ++++ fizz/client/FizzClientContext.h | 5 + fizz/client/test/ClientProtocolTest.cpp | 71 +++++++++--- fizz/client/test/FizzClientContextTest.cpp | 122 +++++++++++++++++++++ fizz/fizz-config.h.in | 11 ++ 6 files changed, 226 insertions(+), 14 deletions(-) create mode 100644 fizz/client/test/FizzClientContextTest.cpp diff --git a/fizz/client/ClientProtocol.cpp b/fizz/client/ClientProtocol.cpp index df7669ae1ec..49ae01755f5 100644 --- a/fizz/client/ClientProtocol.cpp +++ b/fizz/client/ClientProtocol.cpp @@ -801,6 +801,12 @@ static ClientHello constructEncryptedClientHello( return chloOuter; } +static void checkContext(std::shared_ptr& context) { +#if FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS + context->validate(); +#endif +} + Actions EventHandler::handle( const State& /*state*/, @@ -809,6 +815,8 @@ EventHandler::handle( auto context = std::move(connect.context); + checkContext(context); + // Set up SNI (including possible replacement ECH SNI) folly::Optional echSni; auto sni = std::move(connect.sni); diff --git a/fizz/client/FizzClientContext.cpp b/fizz/client/FizzClientContext.cpp index 2e8bc6f1322..5462bcce405 100644 --- a/fizz/client/FizzClientContext.cpp +++ b/fizz/client/FizzClientContext.cpp @@ -17,5 +17,28 @@ FizzClientContext::FizzClientContext() : factory_(std::make_shared()), clock_(std::make_shared()) {} +void FizzClientContext::validate() const { + // TODO: check supported sig schemes + for (auto& c : supportedCiphers_) { + if (!FIZZ_CONTEXT_VALIDATION_SHOULD_CHECK_CIPHER(c)) { + continue; + } + // will throw if factory doesn't support this cipher + factory_->makeAead(c); + } + + for (auto& g : supportedGroups_) { + // will throw if factory doesn't support this named group + factory_->makeKeyExchange(g, KeyExchangeRole::Client); + } + + for (auto& share : defaultShares_) { + if (std::find(supportedGroups_.begin(), supportedGroups_.end(), share) == + supportedGroups_.end()) { + throw std::runtime_error("unsupported named group in default shares"); + } + } +} + } // namespace client } // namespace fizz diff --git a/fizz/client/FizzClientContext.h b/fizz/client/FizzClientContext.h index d661df9001c..4dcb447afeb 100644 --- a/fizz/client/FizzClientContext.h +++ b/fizz/client/FizzClientContext.h @@ -208,6 +208,11 @@ class FizzClientContext { return factory_; } + /* Ensure that the TLS parameters set in this context are valid (eg. + * compatible with the factory, etc.). Will throw if invalid. + */ + virtual void validate() const; + /** * Sets the certificate decompression manager for server certs. */ diff --git a/fizz/client/test/ClientProtocolTest.cpp b/fizz/client/test/ClientProtocolTest.cpp index 012f6b0e89a..1ddf037a029 100644 --- a/fizz/client/test/ClientProtocolTest.cpp +++ b/fizz/client/test/ClientProtocolTest.cpp @@ -33,8 +33,13 @@ namespace test { class ClientProtocolTest : public ProtocolTest { public: + class ContextWithMockValidate : public FizzClientContext { + public: + MOCK_METHOD(void, validate, (), (const override)); + }; + void SetUp() override { - context_ = std::make_shared(); + context_ = std::make_shared(); context_->setSupportedVersions({ProtocolVersion::tls_1_3}); context_->setSupportedCiphers( {CipherSuite::TLS_AES_128_GCM_SHA256, @@ -279,7 +284,13 @@ class ClientProtocolTest : public ProtocolTest { void doFinishedFlow(ClientAuthType authType); - std::shared_ptr context_; + void maybeExpectValidate() { +#if FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS + EXPECT_CALL(*context_, validate()).Times(1); +#endif + } + + std::shared_ptr context_; MockPlaintextReadRecordLayer* mockRead_; MockPlaintextWriteRecordLayer* mockWrite_; MockEncryptedWriteRecordLayer* mockEarlyWrite_; @@ -355,6 +366,9 @@ TEST_F(ClientProtocolTest, TestConnectFlow) { })); return ret; })); + + maybeExpectValidate(); + MockKeyExchange* mockKex; EXPECT_CALL( *factory_, makeKeyExchange(NamedGroup::x25519, KeyExchangeRole::Client)) @@ -398,6 +412,22 @@ TEST_F(ClientProtocolTest, TestConnectFlow) { EXPECT_FALSE(state_.earlyDataParams().has_value()); } +TEST_F(ClientProtocolTest, TestConnectInvalidContext) { +#if FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS + EXPECT_CALL(*context_, validate()).Times(1).WillRepeatedly(Invoke([]() { + throw std::runtime_error("unsupported parameter"); + })); + + Connect connect; + connect.context = context_; + fizz::Param param = std::move(connect); + + auto actions = detail::processEvent(state_, param); + + expectError(actions, {}, "unsupported parameter"); +#endif +} + TEST_F(ClientProtocolTest, TestConnectPskFlow) { auto psk = getCachedPsk(); EXPECT_CALL(*factory_, makePlaintextReadRecordLayer()) @@ -420,6 +450,9 @@ TEST_F(ClientProtocolTest, TestConnectPskFlow) { })); return ret; })); + + maybeExpectValidate(); + MockKeyExchange* mockKex; EXPECT_CALL( *factory_, makeKeyExchange(NamedGroup::x25519, KeyExchangeRole::Client)) @@ -516,6 +549,9 @@ TEST_F(ClientProtocolTest, TestConnectPskEarlyFlow) { })); return ret; })); + + maybeExpectValidate(); + MockKeyExchange* mockKex; EXPECT_CALL( *factory_, makeKeyExchange(NamedGroup::x25519, KeyExchangeRole::Client)) @@ -832,6 +868,9 @@ TEST_F(ClientProtocolTest, TestConnectSniExtFirst) { TEST_F(ClientProtocolTest, TestConnectMultipleShares) { MockKeyExchange* mockKex1; MockKeyExchange* mockKex2; + + maybeExpectValidate(); + EXPECT_CALL( *factory_, makeKeyExchange(NamedGroup::x25519, KeyExchangeRole::Client)) .WillOnce(InvokeWithoutArgs([&mockKex1]() { @@ -843,6 +882,7 @@ TEST_F(ClientProtocolTest, TestConnectMultipleShares) { mockKex1 = ret.get(); return ret; })); + EXPECT_CALL( *factory_, makeKeyExchange(NamedGroup::secp256r1, KeyExchangeRole::Client)) @@ -872,6 +912,9 @@ TEST_F(ClientProtocolTest, TestConnectMultipleShares) { TEST_F(ClientProtocolTest, TestConnectCachedGroup) { context_->setDefaultShares({NamedGroup::x25519}); + + maybeExpectValidate(); + MockKeyExchange* mockKex; EXPECT_CALL( *factory_, @@ -1051,8 +1094,8 @@ TEST_F(ClientProtocolTest, TestConnectECH) { connect.sni = "www.hostname.com"; const auto& actualChlo = getDefaultClientHello(); - // Two randoms should be generated, 1 for the client hello inner and 1 for the - // client hello outer. + // Two randoms should be generated, 1 for the client hello inner and 1 for + // the client hello outer. EXPECT_CALL(*factory_, makeRandomBytes(_, 32)).Times(2); fizz::Param param = std::move(connect); @@ -1118,8 +1161,8 @@ TEST_F(ClientProtocolTest, TestConnectECHWithHybridSupportedGroup) { connect.sni = "www.hostname.com"; const auto& actualChlo = getDefaultClientHello(); - // Two randoms should be generated, 1 for the client hello inner and 1 for the - // client hello outer. + // Two randoms should be generated, 1 for the client hello inner and 1 for + // the client hello outer. EXPECT_CALL(*factory_, makeRandomBytes(_, 32)).Times(2); fizz::Param param = std::move(connect); @@ -1186,8 +1229,8 @@ TEST_F(ClientProtocolTest, TestConnectECHWithAEGIS) { connect.sni = "www.hostname.com"; const auto& actualChlo = getDefaultClientHello(); - // Two randoms should be generated, 1 for the client hello inner and 1 for the - // client hello outer. + // Two randoms should be generated, 1 for the client hello inner and 1 for + // the client hello outer. EXPECT_CALL(*factory_, makeRandomBytes(_, 32)).Times(2); fizz::Param param = std::move(connect); @@ -3063,8 +3106,8 @@ TEST_F(ClientProtocolTest, TestHelloRetryRequestECHFlow) { // Add the extension to the inner one chlo.extensions.push_back(encodeExtension(ech::InnerECHClientHello())); - // Save this one (the real one), then blank the legacy session id and emplace - // OuterExtensions for AAD construction + // Save this one (the real one), then blank the legacy session id and + // emplace OuterExtensions for AAD construction auto encodedClientHelloInner = encodeHandshake(chlo.clone()); chlo.legacy_session_id = folly::IOBuf::copyBuffer(""); @@ -3334,8 +3377,8 @@ TEST_F(ClientProtocolTest, TestHelloRetryRequestECHRejectedFlow) { // Add the extension to the inner one chlo.extensions.push_back(encodeExtension(ech::InnerECHClientHello())); - // Save this one (the real one), then blank the legacy session id and emplace - // OuterExtensions for AAD construction + // Save this one (the real one), then blank the legacy session id and + // emplace OuterExtensions for AAD construction auto encodedClientHelloInner = encodeHandshake(chlo.clone()); chlo.legacy_session_id = folly::IOBuf::copyBuffer(""); @@ -5897,8 +5940,8 @@ TEST_F(ClientProtocolTest, TestPskWithoutCerts) { // Because CachedPsks can be serialized, and because certificates may fail // to serialize for whatever reason, there may be an instance where a client // uses a deserialized cached psk that does not contain either a client or - // a server certificate, but the PSK itself is valid (and the server accepted - // the offered PSK). + // a server certificate, but the PSK itself is valid (and the server + // accepted the offered PSK). setupExpectingServerHello(); CachedPsk psk = getCachedPsk(); diff --git a/fizz/client/test/FizzClientContextTest.cpp b/fizz/client/test/FizzClientContextTest.cpp new file mode 100644 index 00000000000..bae060c8991 --- /dev/null +++ b/fizz/client/test/FizzClientContextTest.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include + +using namespace fizz::test; + +namespace fizz { +namespace client { +namespace test { + +class FizzClientContextTest : public ::testing::Test { + public: + void SetUp() override { + auto mockFactory = std::make_shared(); + mockFactory->setDefaults(); + factory_ = mockFactory.get(); + + context_ = std::make_shared(mockFactory); + } + + void expectValidateThrows(std::string msg) { + try { + context_->validate(); + } catch (const std::exception& error) { + EXPECT_THAT(error.what(), HasSubstr(msg)); + return; + } + // shouldn't reach here + ASSERT_TRUE(false); + } + + std::shared_ptr context_; + MockFactory* factory_; +}; + +TEST_F(FizzClientContextTest, TestValidateUnsupportedCipher) { + const auto unsupportedCipher = static_cast(0xFFFF); + EXPECT_CALL(*factory_, makeAead(_)).WillRepeatedly([](CipherSuite cipher) { + if (cipher == unsupportedCipher) { + throw std::runtime_error("unsupported cipher"); + } else { + return std::make_unique(); + } + }); + + context_->setSupportedCiphers({unsupportedCipher}); + + expectValidateThrows("unsupported cipher"); +} + +TEST_F(FizzClientContextTest, TestValidateUnsupportedGroup) { + const auto unsupportedGroup = static_cast(0xFFFF); + EXPECT_CALL(*factory_, makeKeyExchange(_, _)) + .WillRepeatedly([](NamedGroup group, KeyExchangeRole /*unused*/) { + if (group == unsupportedGroup) { + throw std::runtime_error("unsupported group"); + } else { + return std::make_unique(); + } + }); + + context_->setSupportedGroups({unsupportedGroup}); + + expectValidateThrows("unsupported group"); +} + +TEST_F(FizzClientContextTest, TestValidateUnsupportedDefaultShare) { + context_->setSupportedGroups( + {static_cast(0x01), + static_cast(0x02)}); + + context_->setDefaultShares( + {static_cast(0x02), + static_cast(0x03)}); + + expectValidateThrows("unsupported named group in default shares"); +} + +TEST_F(FizzClientContextTest, TestValidateSuccess) { + EXPECT_CALL(*factory_, makeAead(_)).WillRepeatedly([](CipherSuite cipher) { + if (cipher == static_cast(0xFFFF)) { + throw std::runtime_error("unsupported cipher"); + } else { + return std::make_unique(); + } + }); + EXPECT_CALL(*factory_, makeKeyExchange(_, _)) + .WillRepeatedly([](NamedGroup group, KeyExchangeRole /*unused*/) { + if (group == static_cast(0xFFFF)) { + throw std::runtime_error("unsupported group"); + } else { + return std::make_unique(); + } + }); + + context_->setSupportedCiphers( + {static_cast(0x01), + static_cast(0x02)}); + + context_->setSupportedGroups( + {static_cast(0x03), + static_cast(0x04)}); + + context_->setDefaultShares({static_cast(0x03)}); + + EXPECT_NO_THROW(context_->validate()); +} +} // namespace test +} // namespace client +} // namespace fizz diff --git a/fizz/fizz-config.h.in b/fizz/fizz-config.h.in index 72a638afd18..79e62827fc0 100644 --- a/fizz/fizz-config.h.in +++ b/fizz/fizz-config.h.in @@ -15,5 +15,16 @@ #cmakedefine01 FIZZ_CERTIFICATE_USE_OPENSSL_CERT #cmakedefine01 FIZZ_HAVE_OQS +#if !defined(FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS) +#if defined(NDEBUG) +#define FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS 0 +#else +#define FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS 1 +#endif +#endif + +#define FIZZ_CONTEXT_VALIDATION_SHOULD_CHECK_CIPHER(x) (true) + #define FIZZ_DEFAULT_FACTORY_HEADER #define FIZZ_DEFAULT_FACTORY ::fizz::MultiBackendFactory +