Skip to content

Commit

Permalink
Prefer std::shared_mutex on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Feb 2, 2025
1 parent c7697d8 commit 7bfbde5
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CMakeCache.txt
CMakeFiles
Makefile
*.ninja
\.cache
\.ninja_*
Testing
build
Expand Down
8 changes: 4 additions & 4 deletions src/nanothread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct Worker {


static Pool *pool_default_inst = nullptr;
static std::mutex pool_default_lock;
static Lock pool_default_lock;
static uint32_t cached_core_count = 0;

uint32_t core_count() {
Expand Down Expand Up @@ -128,7 +128,7 @@ uint32_t pool_thread_id() {
}

Pool *pool_default() {
std::unique_lock<std::mutex> guard(pool_default_lock);
std::unique_lock<Lock> guard(pool_default_lock);

if (!pool_default_inst)
pool_default_inst = pool_create();
Expand Down Expand Up @@ -159,7 +159,7 @@ void pool_destroy(Pool *pool) {

uint32_t pool_size(Pool *pool) {
if (!pool) {
std::unique_lock<std::mutex> guard(pool_default_lock);
std::unique_lock<Lock> guard(pool_default_lock);
pool = pool_default_inst;
}

Expand All @@ -171,7 +171,7 @@ uint32_t pool_size(Pool *pool) {

void pool_set_size(Pool *pool, uint32_t size) {
if (!pool) {
std::unique_lock<std::mutex> guard(pool_default_lock);
std::unique_lock<Lock> guard(pool_default_lock);
pool = pool_default_inst;

if (!pool) {
Expand Down
4 changes: 2 additions & 2 deletions src/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ std::pair<Task *, uint32_t> TaskQueue::pop() {
}

void TaskQueue::wakeup() {
std::unique_lock<std::mutex> guard(sleep_mutex);
std::unique_lock<Lock> guard(sleep_mutex);
uint64_t value = sleep_state.load();
NT_TRACE("wakeup(): sleep_state := (%u, 0)", (uint32_t) (sleep_state >> 32) + 1);
sleep_state = (value + high_bit) & high_mask;
Expand Down Expand Up @@ -440,7 +440,7 @@ TaskQueue::pop_or_sleep(bool (*stopping_criterion)(void *), void *payload,
attempts++;

if (may_sleep && attempts >= NANOTHREAD_MAX_ATTEMPTS) {
std::unique_lock<std::mutex> guard(sleep_mutex);
std::unique_lock<Lock> guard(sleep_mutex);

uint64_t value = ++sleep_state, phase = value & high_mask;
NT_TRACE("pop_or_sleep(): falling asleep after %.2f milliseconds, "
Expand Down
9 changes: 7 additions & 2 deletions src/queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@

#include <atomic>
#include <vector>
#include <mutex>
#include <condition_variable>
#include <cstring>

#if defined(_WIN32)
# include <windows.h>
# include <shared_mutex>
using Lock = std::shared_mutex; // Prefer (more efficient) shared_mutex on Windows
#else
# include <mutex>
using Lock = std::mutex;
#endif


struct Pool;

constexpr uint64_t high_bit = (uint64_t) 0x0000000100000000ull;
Expand Down Expand Up @@ -241,7 +246,7 @@ struct TaskQueue {
std::atomic<uint64_t> sleep_state;

/// Mutex protecting the fields below
std::mutex sleep_mutex;
Lock sleep_mutex;

/// Condition variable used to manage workers that are asleep
std::condition_variable sleep_cv;
Expand Down

0 comments on commit 7bfbde5

Please sign in to comment.