Skip to content

Commit

Permalink
Merge pull request #3809 from pleroy/PushPullCallback
Browse files Browse the repository at this point in the history
Helper classes to have C++ code effectively call into C# in a way that is compatible with journaling
  • Loading branch information
pleroy authored Nov 26, 2023
2 parents c9307c7 + ba90765 commit c20f032
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 0 deletions.
3 changes: 3 additions & 0 deletions base/base.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
<ClInclude Include="pull_serializer_body.hpp" />
<ClInclude Include="push_deserializer.hpp" />
<ClInclude Include="push_deserializer_body.hpp" />
<ClInclude Include="push_pull_callback.hpp" />
<ClInclude Include="push_pull_callback_body.hpp" />
<ClInclude Include="ranges.hpp" />
<ClInclude Include="ranges_body.hpp" />
<ClInclude Include="recurring_thread.hpp" />
Expand Down Expand Up @@ -105,6 +107,7 @@
<ClCompile Include="not_null_test.cpp" />
<ClCompile Include="pull_serializer_test.cpp" />
<ClCompile Include="push_deserializer_test.cpp" />
<ClCompile Include="push_pull_callback_test.cpp" />
<ClCompile Include="recurring_thread_test.cpp" />
<ClCompile Include="thread_pool_test.cpp" />
<ClCompile Include="version.generated.cc" />
Expand Down
9 changes: 9 additions & 0 deletions base/base.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@
<ClInclude Include="for_all_of_body.hpp">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="push_pull_callback.hpp">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="push_pull_callback_body.hpp">
<Filter>Source Files</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="not_null_test.cpp">
Expand Down Expand Up @@ -292,5 +298,8 @@
<ClCompile Include="for_all_of_test.cpp">
<Filter>Test Files</Filter>
</ClCompile>
<ClCompile Include="push_pull_callback_test.cpp">
<Filter>Test Files</Filter>
</ClCompile>
</ItemGroup>
</Project>
92 changes: 92 additions & 0 deletions base/push_pull_callback.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#pragma once

#include <functional>
#include <tuple>

#include "absl/synchronization/mutex.h"
#include "base/macros.hpp" // 🧙 For GUARDED_BY.

namespace principia {
namespace base {
namespace _push_pull_callback {
namespace internal {

// A helper class for callbacks from unmanaged code (C++) to managed code (C#).
// While we could support true callbacks in the generator, it is difficult to
// implement journalling because a journal entry would have to include all the
// parameters passed to callbacks and the returned results.
// Instead we invert the flow of control (just like we do for serialization) and
// expect the managed code to pull the arguments and push the result of the
// computations that it is executing.
// Note that calling the managed and unmanaged APIs from the same thread will
// inevitably cause deadlocks. See |PushPullExecutor| below for a solution to
// this.
template<typename Result, typename... Arguments>
class PushPullCallback {
public:
// The managed API, called to extract the arguments for the unmanaged callback
// and return its result. |Pull| returns false if there are no more arguments
// to be processed and the managed code should stop its iteration.
bool Pull(Arguments&... arguments);
void Push(Result result);

// Used on the unmanaged side to use this object in a context that requires a
// function.
std::function<Result(Arguments...)> ToStdFunction();

// Used on the unmanaged side to indicate that the computation has finished.
// After a call to this method, |Pull| always returns false.
void Shutdown();

private:
// The unmanaged API, called by the function returned by |ToStdFunction|.
void Push(Arguments... arguments);
Result Pull();

// These functions return |lock_| held exclusively.
void WaitUntilHasArgumentsOrShuttingDownAndLock();
void WaitUntilHasResultAndLock();

absl::Mutex lock_;
std::optional<std::tuple<Arguments...>> arguments_ GUARDED_BY(lock_);
std::optional<Result> result_ GUARDED_BY(lock_);
bool shutdown_ GUARDED_BY(lock_) = false;
};

// A helper class to execute a task that takes a callback from unmanaged code to
// managed code and returns a value of type |T|. The task is executed on a
// separate thread, so calls to the|PushPullCallback| don't cause deadlocks.
template<typename T,
typename Result, typename... Arguments>
class PushPullExecutor {
public:
using Task = std::function<T(std::function<Result(Arguments...)>)>;

explicit PushPullExecutor(Task task);
~PushPullExecutor();

// Returns the internal |PushPullCallback| object that is used by the managed
// code to pull arguments and push results.
PushPullCallback<Result, Arguments...>& callback();

// Returns the result of the task passed at construction. Must only be called
// once |PushPullCallback::Pull| has indicated that the task has finished.
T get();

private:
PushPullCallback<Result, Arguments...> callback_;
std::thread thread_;
mutable absl::Mutex lock_;
std::optional<T> result_ GUARDED_BY(lock_);
};

} // namespace internal

using internal::PushPullCallback;
using internal::PushPullExecutor;

} // namespace _push_pull_callback
} // namespace base
} // namespace principia

#include "base/push_pull_callback_body.hpp"
115 changes: 115 additions & 0 deletions base/push_pull_callback_body.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#pragma once

#include "base/push_pull_callback.hpp"

#include <utility>

namespace principia {
namespace base {
namespace _push_pull_callback {
namespace internal {

template<typename Result, typename... Arguments>
bool PushPullCallback<Result, Arguments...>::Pull(Arguments&... arguments) {
WaitUntilHasArgumentsOrShuttingDownAndLock();
if (shutdown_) {
return false;
}
std::tie(arguments...) = std::move(arguments_.value());
arguments_.reset();
lock_.Unlock();
return true;
}

template<typename Result, typename... Arguments>
void PushPullCallback<Result, Arguments...>::Push(Result result) {
absl::MutexLock l(&lock_);
result_ = std::move(result);
}

template<typename Result, typename... Arguments>
std::function<Result(Arguments...)>
PushPullCallback<Result, Arguments...>::ToStdFunction() {
return [this](Arguments const&... arguments) {
Push(arguments...);
return Pull();
};
}

template<typename Result, typename ...Arguments>
void PushPullCallback<Result, Arguments...>::Shutdown() {
absl::MutexLock l(&lock_);
shutdown_ = true;
}

template<typename Result, typename... Arguments>
void PushPullCallback<Result, Arguments...>::Push(
Arguments... arguments) {
absl::MutexLock l(&lock_);
arguments_ = std::tuple(std::move(arguments)...);
}

template<typename Result, typename... Arguments>
Result PushPullCallback<Result, Arguments...>::Pull() {
WaitUntilHasResultAndLock();
lock_.AssertHeld();
Result result = result_.value();
result_.reset();
lock_.Unlock();
return std::move(result);
}

template<typename Result, typename... Arguments>
void PushPullCallback<Result, Arguments...>::
WaitUntilHasArgumentsOrShuttingDownAndLock() {
auto has_arguments_or_shutting_down = [this]() {
lock_.AssertReaderHeld();
return shutdown_ || arguments_.has_value();
};

lock_.LockWhen(absl::Condition(&has_arguments_or_shutting_down));
}

template<typename Result, typename... Arguments>
void PushPullCallback<Result, Arguments...>::WaitUntilHasResultAndLock() {
auto has_result = [this]() {
lock_.AssertReaderHeld();
return result_.has_value();
};

lock_.LockWhen(absl::Condition(&has_result));
}

template<typename T, typename Result, typename... Arguments>
PushPullExecutor<T, Result, Arguments...>::PushPullExecutor(Task task)
: thread_([this, task = std::move(task)]() {
auto const result = task(callback_.ToStdFunction());
{
absl::MutexLock l(&lock_);
result_ = result;
}
callback_.Shutdown();
}) {}

template<typename T, typename Result, typename... Arguments>
PushPullExecutor<T, Result, Arguments...>::~PushPullExecutor() {
thread_.join();
}

template<typename T, typename Result, typename... Arguments>
PushPullCallback<Result, Arguments...>&
PushPullExecutor<T, Result, Arguments...>::callback() {
absl::MutexLock l(&lock_);
return callback_;
}

template<typename T, typename Result, typename... Arguments>
T PushPullExecutor<T, Result, Arguments...>::get() {
absl::MutexLock l(&lock_);
return std::move(result_.value());
}

} // namespace internal
} // namespace _push_pull_callback
} // namespace base
} // namespace principia
42 changes: 42 additions & 0 deletions base/push_pull_callback_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "base/push_pull_callback.hpp"

#include <utility>

#include "gtest/gtest.h"

namespace principia {
namespace base {

using namespace principia::base::_push_pull_callback;

TEST(PushPullCallback, Test) {
// A task that does some computation, including running its callback |f| with
// various arguments.
auto task = [](std::function<int(int left, int right)> const& f) {
double const a = f(2, 4);
double const b = f(3, 5);
return a - b;
};

PushPullExecutor<double, int, int, int> executor(std::move(task));
auto& callback = executor.callback();

int left;
int right;

EXPECT_TRUE(callback.Pull(left, right));
EXPECT_EQ(2, left);
EXPECT_EQ(4, right);
callback.Push(left + right);

EXPECT_TRUE(callback.Pull(left, right));
EXPECT_EQ(3, left);
EXPECT_EQ(5, right);
callback.Push(left - right);

EXPECT_FALSE(callback.Pull(left, right));
EXPECT_EQ(8, executor.get());
}

} // namespace base
} // namespace principia

0 comments on commit c20f032

Please sign in to comment.