diff --git a/src/locking.rs b/src/locking.rs index aa019a1..f0499f8 100644 --- a/src/locking.rs +++ b/src/locking.rs @@ -40,21 +40,55 @@ impl<'a, T> LockGuard<'a, T> { #[cfg(test)] mod tests { use super::LockGuard; - use std::sync::RwLock; - - static MY_MUTEX: RwLock<()> = RwLock::new(()); + use std::{ + collections::VecDeque, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Barrier, RwLock, + }, + thread, + }; #[test] - fn test_shared() { - let guard = LockGuard::acquire(&MY_MUTEX, false); - assert!(!guard.is_exclusive()); - assert!(matches!(guard, LockGuard::Shared(_read_guard))); + fn test_lock() { + static TEST_RWLOCK: RwLock<()> = RwLock::new(()); + { + let guard = LockGuard::acquire(&TEST_RWLOCK, false); + assert!(!guard.is_exclusive()); + assert!(matches!(guard, LockGuard::Shared(_read_guard))); + } + { + let guard = LockGuard::acquire(&TEST_RWLOCK, true); + assert!(guard.is_exclusive()); + assert!(matches!(guard, LockGuard::Exclusive(_write_guard))); + } } #[test] - fn test_exclusive() { - let guard = LockGuard::acquire(&MY_MUTEX, true); - assert!(guard.is_exclusive()); - assert!(matches!(guard, LockGuard::Exclusive(_write_guard))); + fn test_threads() { + static TEST_RWLOCK: RwLock<()> = RwLock::new(()); + const THREAD_COUNT: usize = 32usize; + let barrier = Arc::new(Barrier::new(THREAD_COUNT)); + let mut thread_list = VecDeque::with_capacity(THREAD_COUNT); + let counter = Arc::new(AtomicUsize::new(0usize)); + for _tid in 0..THREAD_COUNT { + let thread_barrier = Arc::clone(&barrier); + let thread_counter = Arc::clone(&counter); + thread_list.push_back(thread::spawn(move || { + let is_leader = thread_barrier.wait().is_leader(); + for _iteration in 0..100000 { + let guard = LockGuard::acquire(&TEST_RWLOCK, is_leader); + let value = thread_counter.fetch_add(1usize, Ordering::Relaxed); + if guard.is_exclusive() { + assert_eq!(value, 0usize, "Invalid counter value!"); + } + thread::yield_now(); + thread_counter.fetch_sub(1usize, Ordering::Relaxed); + } + })); + } + for thread in thread_list.drain(..) { + thread.join().unwrap(); + } } }