Skip to content

Commit

Permalink
Connections now check for in-progress async operations and fail witho…
Browse files Browse the repository at this point in the history
…ut UB

connection and any_connection now check whether there is an in-progress
async operation, and fail with client_errc::operation_in_progress if
there is one. This situation no longer triggers undefined behavior.
Refactored the internal sans-io algorithms

close #405
  • Loading branch information
anarthal authored Feb 5, 2025
1 parent 91cd262 commit 829dbf7
Show file tree
Hide file tree
Showing 53 changed files with 620 additions and 384 deletions.
16 changes: 7 additions & 9 deletions doc/qbk/04_overview.qbk
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,15 @@ for more info.

[section:async Single outstanding async operation per connection]

At any given point in time, a `any_connection` may only have a single async operation outstanding.
Because MySQL sessions are stateful, and to keep the implementation simple, messages
are written to the underlying transport without any locking or queueing.
If you perform several async operations concurrently on a single connection without any
serialization, messages from different operations will be interleaved, leading to undefined behavior.
At any given point in time, an `any_connection` can only have a single async operation outstanding.
In other words, connections implement no asynchronous locking or queueing, which
keeps code simple and efficient. If you need to perform several operations in parallel,
you can open more connections or use [reflink connection_pool].

For example, doing the following is illegal and should be avoided:
Trying to run operations concurrently on a single connection is detected at
runtime and generates a `client_errc::operation_in_progress` error:

[overview_async_dont]

If you need to perform queries in parallel, open more connections to the server.
[overview_async_parallel]

[endsect]

Expand Down
5 changes: 5 additions & 0 deletions include/boost/mysql/any_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ struct any_connection_params
*
* This is a move-only type.
*
* \par Single outstanding async operation per connection
* At any given point in time, only one async operation can be outstanding
* per connection. If an async operation is initiated while another one is in progress,
* it will fail with \ref client_errc::operation_in_progress.
*
* \par Default completion tokens
* The default completion token for all async operations in this class is
* `with_diagnostics(asio::deferred)`, which allows you to use `co_await`
Expand Down
6 changes: 6 additions & 0 deletions include/boost/mysql/client_errc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ enum class client_errc : int
* size. Try increasing \ref any_connection_params::max_buffer_size.
*/
max_buffer_size_exceeded,

/**
* \brief Another operation is currently in progress for this connection. Make sure
* that a single connection does not run two asynchronous operations in parallel.
*/
operation_in_progress,
};

BOOST_MYSQL_DECL
Expand Down
5 changes: 5 additions & 0 deletions include/boost/mysql/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class static_execution_state;
* the stream using \ref connection::stream, and its executor via \ref connection::get_executor. The
* executor used by this object is always the same as the underlying stream.
*
* \par Single outstanding async operation per connection
* At any given point in time, only one async operation can be outstanding
* per connection. If an async operation is initiated while another one is in progress,
* it will fail with \ref client_errc::operation_in_progress.
*
* \par Thread safety
* Distinct objects: safe. \n
* Shared objects: unsafe. \n
Expand Down
2 changes: 2 additions & 0 deletions include/boost/mysql/detail/algo_params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ struct pipeline_request_stage;

struct connect_algo_params
{
const void* server_address; // Points to an any_address or an endpoint for the corresponding stream. For
// the templated connection, only valid until the first yield!
handshake_params hparams;
bool secure_channel; // Are we using UNIX sockets or any other secure channel?

Expand Down
24 changes: 14 additions & 10 deletions include/boost/mysql/detail/any_resumable_ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,31 @@ namespace detail {

class any_resumable_ref
{
template <class T>
static next_action do_resume(void* self, error_code ec, std::size_t bytes_transferred)
{
return static_cast<T*>(self)->resume(ec, bytes_transferred);
}

public:
using fn_t = next_action (*)(void*, error_code, std::size_t);

void* algo_{};
fn_t fn_{};

public:
template <class T, class = typename std::enable_if<!std::is_same<T, any_resumable_ref>::value>::type>
explicit any_resumable_ref(T& op) noexcept : algo_(&op), fn_(&do_resume<T>)
{
}

// Allow using standalone functions
any_resumable_ref(void* algo, fn_t fn) noexcept : algo_(algo), fn_(fn) {}

next_action resume(error_code ec, std::size_t bytes_transferred)
{
return fn_(algo_, ec, bytes_transferred);
}

private:
template <class T>
static next_action do_resume(void* self, error_code ec, std::size_t bytes_transferred)
{
return static_cast<T*>(self)->resume(ec, bytes_transferred);
}

void* algo_{};
fn_t fn_{};
};

} // namespace detail
Expand Down
19 changes: 11 additions & 8 deletions include/boost/mysql/detail/connection_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,15 @@ class connection_impl
};

// Connect
static connect_algo_params make_params_connect(const handshake_params& params)
static connect_algo_params make_params_connect(const void* server_address, const handshake_params& params)
{
return connect_algo_params{params, false};
return connect_algo_params{server_address, params, false};
}

static connect_algo_params make_params_connect_v2(const connect_params& params)
{
return connect_algo_params{
&params.server_address,
make_hparams(params),
params.server_address.type() == address_type::unix_path
};
Expand All @@ -251,8 +252,13 @@ class connection_impl
handshake_params params
)
{
eng->set_endpoint(&endpoint);
async_run_impl(*eng, *st, make_params_connect(params), *diag, std::forward<Handler>(handler));
async_run_impl(
*eng,
*st,
make_params_connect(&endpoint, params),
*diag,
std::forward<Handler>(handler)
);
}
};

Expand All @@ -269,7 +275,6 @@ class connection_impl
const connect_params* params
)
{
eng->set_endpoint(&params->server_address);
async_run_impl(*eng, *st, make_params_connect_v2(*params), *diag, std::forward<Handler>(handler));
}
};
Expand Down Expand Up @@ -388,13 +393,11 @@ class connection_impl
diagnostics& diag
)
{
engine_->set_endpoint(&endpoint);
run(make_params_connect(params), err, diag);
run(make_params_connect(&endpoint, params), err, diag);
}

void connect_v2(const connect_params& params, error_code& err, diagnostics& diag)
{
engine_->set_endpoint(&params.server_address);
run(make_params_connect_v2(params), err, diag);
}

Expand Down
1 change: 0 additions & 1 deletion include/boost/mysql/detail/engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class engine
virtual ~engine() {}
virtual executor_type get_executor() = 0;
virtual bool supports_ssl() const = 0;
virtual void set_endpoint(const void* endpoint) = 0;
virtual void run(any_resumable_ref resumable, error_code& err) = 0;
virtual void async_run(any_resumable_ref resumable, asio::any_completion_handler<void(error_code)>) = 0;
};
Expand Down
15 changes: 8 additions & 7 deletions include/boost/mysql/detail/engine_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ struct run_algo_op
}
else if (act.type() == next_action_type::connect)
{
BOOST_MYSQL_YIELD(resume_point_, 6, stream_.async_connect(std::move(self)))
BOOST_MYSQL_YIELD(
resume_point_,
6,
stream_.async_connect(act.connect_endpoint(), std::move(self))
)
has_done_io_ = true;
}
else
Expand All @@ -128,7 +132,6 @@ struct run_algo_op
// using executor_type = asio::any_io_executor;
// executor_type get_executor();
// bool supports_ssl() const;
// void set_endpoint(const void* endpoint);
// std::size_t read_some(asio::mutable_buffer, bool use_ssl, error_code&);
// void async_read_some(asio::mutable_buffer, bool use_ssl, CompletinToken&&);
// std::size_t write_some(asio::const_buffer, bool use_ssl, error_code&);
Expand All @@ -137,8 +140,8 @@ struct run_algo_op
// void async_ssl_handshake(CompletionToken&&);
// void ssl_shutdown(error_code&);
// void async_ssl_shutdown(CompletionToken&&);
// void connect(error_code&);
// void async_connect(CompletionToken&&);
// void connect(const void* server_address, error_code&);
// void async_connect(const void* server_address, CompletionToken&&);
// void close(error_code&);
// Async operations are only required to support callback types
// See stream_adaptor for an implementation
Expand All @@ -161,8 +164,6 @@ class engine_impl final : public engine

bool supports_ssl() const override final { return stream_.supports_ssl(); }

void set_endpoint(const void* endpoint) override final { stream_.set_endpoint(endpoint); }

void run(any_resumable_ref resumable, error_code& ec) override final
{
ec.clear();
Expand Down Expand Up @@ -207,7 +208,7 @@ class engine_impl final : public engine
}
else if (act.type() == next_action_type::connect)
{
stream_.connect(io_ec);
stream_.connect(act.connect_endpoint(), io_ec);
}
else
{
Expand Down
59 changes: 20 additions & 39 deletions include/boost/mysql/detail/engine_stream_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,61 +31,48 @@ namespace mysql {
namespace detail {

// Connect and close helpers
template <class Stream, class = void>
struct endpoint_storage // prevent build errors for non socket streams
{
void store(const void*) { BOOST_ASSERT(false); } // LCOV_EXCL_LINE
};

template <class Stream>
struct endpoint_storage<Stream, void_t<typename Stream::lowest_layer_type::endpoint_type>>
{
using endpoint_type = typename Stream::lowest_layer_type::endpoint_type;
endpoint_type value;
void store(const void* v) { value = *static_cast<const endpoint_type*>(v); }
};

// LCOV_EXCL_START
template <class Stream>
void do_connect_impl(Stream&, const endpoint_storage<Stream>&, error_code&, std::false_type)
void do_connect_impl(Stream&, const void*, error_code&, std::false_type)
{
BOOST_ASSERT(false);
}
// LCOV_EXCL_STOP

template <class Stream>
void do_connect_impl(Stream& stream, const endpoint_storage<Stream>& ep, error_code& ec, std::true_type)
void do_connect_impl(Stream& stream, const void* ep, error_code& ec, std::true_type)
{
stream.lowest_layer().connect(ep.value, ec);
stream.lowest_layer().connect(
*static_cast<const typename Stream::lowest_layer_type::endpoint_type*>(ep),
ec
);
}

template <class Stream>
void do_connect(Stream& stream, const endpoint_storage<Stream>& ep, error_code& ec)
void do_connect(Stream& stream, const void* ep, error_code& ec)
{
do_connect_impl(stream, ep, ec, is_socket_stream<Stream>{});
}

// LCOV_EXCL_START
template <class Stream, class CompletionToken>
void do_async_connect_impl(Stream&, const endpoint_storage<Stream>&, CompletionToken&&, std::false_type)
void do_async_connect_impl(Stream&, const void*, CompletionToken&&, std::false_type)
{
BOOST_ASSERT(false);
}
// LCOV_EXCL_STOP

template <class Stream, class CompletionToken>
void do_async_connect_impl(
Stream& stream,
const endpoint_storage<Stream>& ep,
CompletionToken&& token,
std::true_type
)
void do_async_connect_impl(Stream& stream, const void* ep, CompletionToken&& token, std::true_type)
{
stream.lowest_layer().async_connect(ep.value, std::forward<CompletionToken>(token));
stream.lowest_layer().async_connect(
*static_cast<const typename Stream::lowest_layer_type::endpoint_type*>(ep),
std::forward<CompletionToken>(token)
);
}

template <class Stream, class CompletionToken>
void do_async_connect(Stream& stream, const endpoint_storage<Stream>& ep, CompletionToken&& token)
void do_async_connect(Stream& stream, const void* ep, CompletionToken&& token)
{
do_async_connect_impl(stream, ep, std::forward<CompletionToken>(token), is_socket_stream<Stream>{});
}
Expand Down Expand Up @@ -115,7 +102,6 @@ template <class Stream>
class engine_stream_adaptor
{
Stream stream_;
endpoint_storage<Stream> endpoint_;

public:
template <class... Args>
Expand All @@ -128,8 +114,6 @@ class engine_stream_adaptor

bool supports_ssl() const { return false; }

void set_endpoint(const void* val) { endpoint_.store(val); }

using executor_type = asio::any_io_executor;
executor_type get_executor() { return stream_.get_executor(); }

Expand Down Expand Up @@ -185,12 +169,12 @@ class engine_stream_adaptor
}

// Connect and close
void connect(error_code& ec) { do_connect(stream_, endpoint_, ec); }
void connect(const void* endpoint, error_code& ec) { do_connect(stream_, endpoint, ec); }

template <class CompletionToken>
void async_connect(CompletionToken&& token)
void async_connect(const void* endpoint, CompletionToken&& token)
{
do_async_connect(stream_, endpoint_, std::forward<CompletionToken>(token));
do_async_connect(stream_, endpoint, std::forward<CompletionToken>(token));
}

void close(error_code& ec) { do_close(stream_, ec); }
Expand All @@ -200,7 +184,6 @@ template <class Stream>
class engine_stream_adaptor<asio::ssl::stream<Stream>>
{
asio::ssl::stream<Stream> stream_;
endpoint_storage<asio::ssl::stream<Stream>> endpoint_;

public:
template <class... Args>
Expand All @@ -213,8 +196,6 @@ class engine_stream_adaptor<asio::ssl::stream<Stream>>

bool supports_ssl() const { return true; }

void set_endpoint(const void* val) { endpoint_.store(val); }

using executor_type = asio::any_io_executor;
executor_type get_executor() { return stream_.get_executor(); }

Expand Down Expand Up @@ -288,12 +269,12 @@ class engine_stream_adaptor<asio::ssl::stream<Stream>>
}

// Connect and close
void connect(error_code& ec) { do_connect(stream_, endpoint_, ec); }
void connect(const void* endpoint, error_code& ec) { do_connect(stream_, endpoint, ec); }

template <class CompletionToken>
void async_connect(CompletionToken&& token)
void async_connect(const void* endpoint, CompletionToken&& token)
{
do_async_connect(stream_, endpoint_, std::forward<CompletionToken>(token));
do_async_connect(stream_, endpoint, std::forward<CompletionToken>(token));
}

void close(error_code& ec) { do_close(stream_, ec); }
Expand Down
Loading

0 comments on commit 829dbf7

Please sign in to comment.