-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3809 from pleroy/PushPullCallback
Helper classes to have C++ code effectively call into C# in a way that is compatible with journaling
- Loading branch information
Showing
5 changed files
with
261 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#pragma once | ||
|
||
#include <functional> | ||
#include <tuple> | ||
|
||
#include "absl/synchronization/mutex.h" | ||
#include "base/macros.hpp" // 🧙 For GUARDED_BY. | ||
|
||
namespace principia { | ||
namespace base { | ||
namespace _push_pull_callback { | ||
namespace internal { | ||
|
||
// A helper class for callbacks from unmanaged code (C++) to managed code (C#). | ||
// While we could support true callbacks in the generator, it is difficult to | ||
// implement journalling because a journal entry would have to include all the | ||
// parameters passed to callbacks and the returned results. | ||
// Instead we invert the flow of control (just like we do for serialization) and | ||
// expect the managed code to pull the arguments and push the result of the | ||
// computations that it is executing. | ||
// Note that calling the managed and unmanaged APIs from the same thread will | ||
// inevitably cause deadlocks. See |PushPullExecutor| below for a solution to | ||
// this. | ||
template<typename Result, typename... Arguments> | ||
class PushPullCallback { | ||
public: | ||
// The managed API, called to extract the arguments for the unmanaged callback | ||
// and return its result. |Pull| returns false if there are no more arguments | ||
// to be processed and the managed code should stop its iteration. | ||
bool Pull(Arguments&... arguments); | ||
void Push(Result result); | ||
|
||
// Used on the unmanaged side to use this object in a context that requires a | ||
// function. | ||
std::function<Result(Arguments...)> ToStdFunction(); | ||
|
||
// Used on the unmanaged side to indicate that the computation has finished. | ||
// After a call to this method, |Pull| always returns false. | ||
void Shutdown(); | ||
|
||
private: | ||
// The unmanaged API, called by the function returned by |ToStdFunction|. | ||
void Push(Arguments... arguments); | ||
Result Pull(); | ||
|
||
// These functions return |lock_| held exclusively. | ||
void WaitUntilHasArgumentsOrShuttingDownAndLock(); | ||
void WaitUntilHasResultAndLock(); | ||
|
||
absl::Mutex lock_; | ||
std::optional<std::tuple<Arguments...>> arguments_ GUARDED_BY(lock_); | ||
std::optional<Result> result_ GUARDED_BY(lock_); | ||
bool shutdown_ GUARDED_BY(lock_) = false; | ||
}; | ||
|
||
// A helper class to execute a task that takes a callback from unmanaged code to | ||
// managed code and returns a value of type |T|. The task is executed on a | ||
// separate thread, so calls to the|PushPullCallback| don't cause deadlocks. | ||
template<typename T, | ||
typename Result, typename... Arguments> | ||
class PushPullExecutor { | ||
public: | ||
using Task = std::function<T(std::function<Result(Arguments...)>)>; | ||
|
||
explicit PushPullExecutor(Task task); | ||
~PushPullExecutor(); | ||
|
||
// Returns the internal |PushPullCallback| object that is used by the managed | ||
// code to pull arguments and push results. | ||
PushPullCallback<Result, Arguments...>& callback(); | ||
|
||
// Returns the result of the task passed at construction. Must only be called | ||
// once |PushPullCallback::Pull| has indicated that the task has finished. | ||
T get(); | ||
|
||
private: | ||
PushPullCallback<Result, Arguments...> callback_; | ||
std::thread thread_; | ||
mutable absl::Mutex lock_; | ||
std::optional<T> result_ GUARDED_BY(lock_); | ||
}; | ||
|
||
} // namespace internal | ||
|
||
using internal::PushPullCallback; | ||
using internal::PushPullExecutor; | ||
|
||
} // namespace _push_pull_callback | ||
} // namespace base | ||
} // namespace principia | ||
|
||
#include "base/push_pull_callback_body.hpp" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
#pragma once | ||
|
||
#include "base/push_pull_callback.hpp" | ||
|
||
#include <utility> | ||
|
||
namespace principia { | ||
namespace base { | ||
namespace _push_pull_callback { | ||
namespace internal { | ||
|
||
template<typename Result, typename... Arguments> | ||
bool PushPullCallback<Result, Arguments...>::Pull(Arguments&... arguments) { | ||
WaitUntilHasArgumentsOrShuttingDownAndLock(); | ||
if (shutdown_) { | ||
return false; | ||
} | ||
std::tie(arguments...) = std::move(arguments_.value()); | ||
arguments_.reset(); | ||
lock_.Unlock(); | ||
return true; | ||
} | ||
|
||
template<typename Result, typename... Arguments> | ||
void PushPullCallback<Result, Arguments...>::Push(Result result) { | ||
absl::MutexLock l(&lock_); | ||
result_ = std::move(result); | ||
} | ||
|
||
template<typename Result, typename... Arguments> | ||
std::function<Result(Arguments...)> | ||
PushPullCallback<Result, Arguments...>::ToStdFunction() { | ||
return [this](Arguments const&... arguments) { | ||
Push(arguments...); | ||
return Pull(); | ||
}; | ||
} | ||
|
||
template<typename Result, typename ...Arguments> | ||
void PushPullCallback<Result, Arguments...>::Shutdown() { | ||
absl::MutexLock l(&lock_); | ||
shutdown_ = true; | ||
} | ||
|
||
template<typename Result, typename... Arguments> | ||
void PushPullCallback<Result, Arguments...>::Push( | ||
Arguments... arguments) { | ||
absl::MutexLock l(&lock_); | ||
arguments_ = std::tuple(std::move(arguments)...); | ||
} | ||
|
||
template<typename Result, typename... Arguments> | ||
Result PushPullCallback<Result, Arguments...>::Pull() { | ||
WaitUntilHasResultAndLock(); | ||
lock_.AssertHeld(); | ||
Result result = result_.value(); | ||
result_.reset(); | ||
lock_.Unlock(); | ||
return std::move(result); | ||
} | ||
|
||
template<typename Result, typename... Arguments> | ||
void PushPullCallback<Result, Arguments...>:: | ||
WaitUntilHasArgumentsOrShuttingDownAndLock() { | ||
auto has_arguments_or_shutting_down = [this]() { | ||
lock_.AssertReaderHeld(); | ||
return shutdown_ || arguments_.has_value(); | ||
}; | ||
|
||
lock_.LockWhen(absl::Condition(&has_arguments_or_shutting_down)); | ||
} | ||
|
||
template<typename Result, typename... Arguments> | ||
void PushPullCallback<Result, Arguments...>::WaitUntilHasResultAndLock() { | ||
auto has_result = [this]() { | ||
lock_.AssertReaderHeld(); | ||
return result_.has_value(); | ||
}; | ||
|
||
lock_.LockWhen(absl::Condition(&has_result)); | ||
} | ||
|
||
template<typename T, typename Result, typename... Arguments> | ||
PushPullExecutor<T, Result, Arguments...>::PushPullExecutor(Task task) | ||
: thread_([this, task = std::move(task)]() { | ||
auto const result = task(callback_.ToStdFunction()); | ||
{ | ||
absl::MutexLock l(&lock_); | ||
result_ = result; | ||
} | ||
callback_.Shutdown(); | ||
}) {} | ||
|
||
template<typename T, typename Result, typename... Arguments> | ||
PushPullExecutor<T, Result, Arguments...>::~PushPullExecutor() { | ||
thread_.join(); | ||
} | ||
|
||
template<typename T, typename Result, typename... Arguments> | ||
PushPullCallback<Result, Arguments...>& | ||
PushPullExecutor<T, Result, Arguments...>::callback() { | ||
absl::MutexLock l(&lock_); | ||
return callback_; | ||
} | ||
|
||
template<typename T, typename Result, typename... Arguments> | ||
T PushPullExecutor<T, Result, Arguments...>::get() { | ||
absl::MutexLock l(&lock_); | ||
return std::move(result_.value()); | ||
} | ||
|
||
} // namespace internal | ||
} // namespace _push_pull_callback | ||
} // namespace base | ||
} // namespace principia |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include "base/push_pull_callback.hpp" | ||
|
||
#include <utility> | ||
|
||
#include "gtest/gtest.h" | ||
|
||
namespace principia { | ||
namespace base { | ||
|
||
using namespace principia::base::_push_pull_callback; | ||
|
||
TEST(PushPullCallback, Test) { | ||
// A task that does some computation, including running its callback |f| with | ||
// various arguments. | ||
auto task = [](std::function<int(int left, int right)> const& f) { | ||
double const a = f(2, 4); | ||
double const b = f(3, 5); | ||
return a - b; | ||
}; | ||
|
||
PushPullExecutor<double, int, int, int> executor(std::move(task)); | ||
auto& callback = executor.callback(); | ||
|
||
int left; | ||
int right; | ||
|
||
EXPECT_TRUE(callback.Pull(left, right)); | ||
EXPECT_EQ(2, left); | ||
EXPECT_EQ(4, right); | ||
callback.Push(left + right); | ||
|
||
EXPECT_TRUE(callback.Pull(left, right)); | ||
EXPECT_EQ(3, left); | ||
EXPECT_EQ(5, right); | ||
callback.Push(left - right); | ||
|
||
EXPECT_FALSE(callback.Pull(left, right)); | ||
EXPECT_EQ(8, executor.get()); | ||
} | ||
|
||
} // namespace base | ||
} // namespace principia |