From d620008e0dfb76fd45908ddb9a6147f09ac25e3f Mon Sep 17 00:00:00 2001 From: Bruno Dutra Date: Sun, 26 Jan 2025 00:32:13 +0100 Subject: [PATCH] keep killers local to each thread --- benches/search.rs | 2 +- lib/search/control.rs | 2 +- lib/search/engine.rs | 230 ++++++++++++++++++++++++------------------ lib/search/killers.rs | 93 ++--------------- 4 files changed, 146 insertions(+), 181 deletions(-) diff --git a/benches/search.rs b/benches/search.rs index 9d0290e9..54dfbc33 100644 --- a/benches/search.rs +++ b/benches/search.rs @@ -13,7 +13,7 @@ fn bench(reps: u64, options: &Options, limits: &Limits) -> Duration { let mut time = Duration::ZERO; for _ in 0..reps { - let mut e = Engine::with_options(options); + let e = Engine::with_options(options); let stopper = Trigger::armed(); let pos = Evaluator::default(); let timer = Instant::now(); diff --git a/lib/search/control.rs b/lib/search/control.rs index de4f8ef2..ef7f465b 100644 --- a/lib/search/control.rs +++ b/lib/search/control.rs @@ -7,7 +7,7 @@ use derive_more::{Display, Error}; pub struct Interrupted; /// The search control. -#[derive(Debug, Default)] +#[derive(Debug, Default, Copy, Clone)] pub enum Control<'a> { #[default] Unlimited, diff --git a/lib/search/engine.rs b/lib/search/engine.rs index 68c8ca73..8955db22 100644 --- a/lib/search/engine.rs +++ b/lib/search/engine.rs @@ -1,55 +1,38 @@ -use crate::chess::{Move, Outcome}; +use crate::chess::{Move, Outcome, Position}; use crate::nnue::{Evaluator, Value}; use crate::search::*; -use crate::util::{Assume, Counter, Integer, Timer, Trigger}; +use crate::util::{AlignTo64, Assume, Counter, Integer, Timer, Trigger}; use arrayvec::ArrayVec; -use std::{ops::Range, thread, time::Duration}; - -#[cfg(test)] -use crate::search::{HashSize, ThreadCount}; +use derive_more::Deref; +use std::{mem::swap, ops::Range, thread, time::Duration}; #[cfg(test)] use proptest::strategy::LazyJust; /// A chess engine. -#[derive(Debug)] -#[cfg_attr(test, derive(test_strategy::Arbitrary))] -pub struct Engine { - threads: ThreadCount, - #[cfg_attr(test, map(|s: HashSize| TranspositionTable::new(s)))] - tt: TranspositionTable, - #[cfg_attr(test, strategy(LazyJust::new(Killers::default)))] - killers: Killers, - #[cfg_attr(test, strategy(LazyJust::new(History::default)))] - history: History, +#[derive(Debug, Clone, Deref)] +pub struct Search<'a> { + #[deref] + engine: &'a Engine, + ctrl: Control<'a>, + killers: AlignTo64<[Killers; Ply::MAX as usize + 1]>, } -impl Default for Engine { - fn default() -> Self { - Self::new() - } -} - -impl Engine { - /// Initializes the engine with the default [`Options`]. - pub fn new() -> Self { - Self::with_options(&Options::default()) - } +impl<'a> Search<'a> { + fn new(engine: &'a Engine, ctrl: Control<'a>) -> Self { + let killers = AlignTo64([Killers::default(); Ply::MAX as usize + 1]); - /// Initializes the engine with the given [`Options`]. - pub fn with_options(options: &Options) -> Self { - Engine { - threads: options.threads, - tt: TranspositionTable::new(options.hash), - killers: Killers::default(), - history: History::default(), + Search { + engine, + ctrl, + killers, } } #[allow(clippy::too_many_arguments)] fn record( - &self, - pos: &Evaluator, + &mut self, + pos: &Position, moves: &[(Move, Value)], bounds: Range, depth: Depth, @@ -60,7 +43,7 @@ impl Engine { let draft = depth - ply; if score >= bounds.end { if best.is_quiet() { - self.killers.insert(pos, ply, best); + self.killers[ply.cast::()].insert(best); } self.history.update(pos, best, draft.get()); @@ -154,46 +137,53 @@ impl Engine { /// /// [alpha-beta]: https://www.chessprogramming.org/Alpha-Beta fn ab( - &self, + &mut self, pos: &Evaluator, bounds: Range, depth: Depth, ply: Ply, - ctrl: &Control, ) -> Result, Interrupted> { if ply.cast::() < N && depth > ply && bounds.start + 1 < bounds.end { - self.pvs(pos, bounds, depth, ply, ctrl) + self.pvs(pos, bounds, depth, ply) } else { - Ok(self.pvs::<0>(pos, bounds, depth, ply, ctrl)?.convert()) + Ok(self.pvs::<0>(pos, bounds, depth, ply)?.convert()) } } + /// The full-window alpha-beta search. + fn fw( + &mut self, + pos: &Evaluator, + depth: Depth, + ply: Ply, + ) -> Result, Interrupted> { + self.ab(pos, Score::lower()..Score::upper(), depth, ply) + } + /// The [zero-window] alpha-beta search. /// /// [zero-window]: https://www.chessprogramming.org/Null_Window fn nw( - &self, + &mut self, pos: &Evaluator, beta: Score, depth: Depth, ply: Ply, - ctrl: &Control, ) -> Result, Interrupted> { - self.ab(pos, beta - 1..beta, depth, ply, ctrl) + self.ab(pos, beta - 1..beta, depth, ply) } /// An implementation of the [PVS] variation of the alpha-beta search. /// /// [PVS]: https://www.chessprogramming.org/Principal_Variation_Search fn pvs( - &self, + &mut self, pos: &Evaluator, bounds: Range, depth: Depth, ply: Ply, - ctrl: &Control, ) -> Result, Interrupted> { - ctrl.interrupted()?; + self.ctrl.interrupted()?; let is_root = ply == 0; (bounds.start < bounds.end).assume(); let (alpha, beta) = match pos.outcome() { @@ -275,7 +265,7 @@ impl Engine { let mut next = pos.clone(); next.pass(); self.tt.prefetch(next.zobrist()); - if -self.nw::<0>(&next, -beta + 1, d + ply, ply + 1, ctrl)? >= beta { + if -self.nw::<0>(&next, -beta + 1, d + ply, ply + 1)? >= beta { #[cfg(not(test))] // The null move pruning heuristic is not exact. return Ok(transposed.convert()); @@ -284,7 +274,7 @@ impl Engine { } } - let killers = self.killers.get(pos, ply); + let killer = self.killers[ply.cast::()]; let mut moves: ArrayVec<_, 255> = pos .moves() .filter(|ms| !quiesce || !ms.is_quiet()) @@ -292,7 +282,7 @@ impl Engine { .map(|m| { if Some(m) == transposed.moves().next() { return (m, Value::upper()); - } else if killers.contains(m) { + } else if killer.contains(m) { return (m, Value::new(128)); } @@ -315,7 +305,7 @@ impl Engine { let mut next = pos.clone(); next.play(*m); self.tt.prefetch(next.zobrist()); - if -self.nw::<0>(&next, -beta + 1, d + ply, ply + 1, ctrl)? >= beta { + if -self.nw::<0>(&next, -beta + 1, d + ply, ply + 1)? >= beta { #[cfg(not(test))] // The multi-cut pruning heuristic is not exact. return Ok(transposed.convert()); @@ -337,7 +327,7 @@ impl Engine { let mut next = pos.clone(); next.play(m); self.tt.prefetch(next.zobrist()); - (m, -self.ab(&next, -beta..-alpha, depth, ply + 1, ctrl)?) + (m, -self.ab(&next, -beta..-alpha, depth, ply + 1)?) } }; @@ -372,9 +362,9 @@ impl Engine { _ => 0, }; - let partial = match -self.nw(&next, -alpha, depth - lmr, ply + 1, ctrl)? { + let partial = match -self.nw(&next, -alpha, depth - lmr, ply + 1)? { partial if partial <= alpha || (partial >= beta && lmr <= 0) => partial, - _ => -self.ab(&next, -beta..-alpha, depth, ply + 1, ctrl)?, + _ => -self.ab(&next, -beta..-alpha, depth, ply + 1)?, }; if partial > tail { @@ -391,17 +381,20 @@ impl Engine { /// [aspiration windows]: https://www.chessprogramming.org/Aspiration_Windows /// [iterative deepening]: https://www.chessprogramming.org/Iterative_Deepening fn aw( - &self, + &mut self, pos: &Evaluator, limit: Depth, - nodes: &Counter, - time: &Range, - stopper: &Trigger, + time: Range, ) -> Pv { - let timer = Timer::new(time.end); - let mut pv = Pv::new(Score::lower(), []); - - 'id: for depth in Depth::iter() { + let mut ctrl = Control::Unlimited; + swap(&mut self.ctrl, &mut ctrl); + self.fw::<0>(pos, Depth::new(0), Ply::new(0)).assume(); + let mut pv = self.fw(pos, Depth::new(1), Ply::new(0)).assume(); + swap(&mut self.ctrl, &mut ctrl); + + let mut depth = Depth::new(1); + 'id: while depth < limit { + depth = depth + 1; let mut draft = depth; let mut delta = 5i16; @@ -410,21 +403,13 @@ impl Engine { _ => (pv.score() - delta, pv.score() + delta), }; - let ctrl = if N > 0 && pv.moves().next().is_none() { - Control::Unlimited - } else if depth < limit { - Control::Limited(nodes, &timer, stopper) - } else { - break 'id; - }; - 'aw: loop { delta = delta.saturating_mul(2); - if ctrl.timer().remaining() < Some(time.end - time.start) { + if self.ctrl.timer().remaining() < Some(time.end - time.start) { break 'id; } - let Ok(partial) = self.ab(pos, lower..upper, draft, Ply::new(0), &ctrl) else { + let Ok(partial) = self.ab(pos, lower..upper, draft, Ply::new(0)) else { break 'id; }; @@ -456,8 +441,41 @@ impl Engine { pv } +} + +/// A chess engine. +#[derive(Debug)] +#[cfg_attr(test, derive(test_strategy::Arbitrary))] +pub struct Engine { + threads: ThreadCount, + #[cfg_attr(test, map(|s: HashSize| TranspositionTable::new(s)))] + tt: TranspositionTable, + #[cfg_attr(test, strategy(LazyJust::new(History::default)))] + history: History, +} + +impl Default for Engine { + fn default() -> Self { + Self::new() + } +} + +impl Engine { + /// Initializes the engine with the default [`Options`]. + pub fn new() -> Self { + Self::with_options(&Options::default()) + } + + /// Initializes the engine with the given [`Options`]. + pub fn with_options(options: &Options) -> Self { + Engine { + threads: options.threads, + tt: TranspositionTable::new(options.hash), + history: History::default(), + } + } - fn time_to_search(&self, pos: &Evaluator, limits: &Limits) -> Range { + fn time_to_search(&self, pos: &Position, limits: &Limits) -> Range { let (clock, inc) = match limits { Limits::Clock(c, i) => (c, i), _ => return limits.time()..limits.time(), @@ -470,16 +488,21 @@ impl Engine { } /// Searches for the [principal variation][`Pv`]. - pub fn search(&mut self, pos: &Evaluator, limits: &Limits, stopper: &Trigger) -> Pv { - let nodes = Counter::new(limits.nodes()); + pub fn search(&self, pos: &Evaluator, limits: &Limits, stopper: &Trigger) -> Pv { let time = self.time_to_search(pos, limits); + let nodes = Counter::new(limits.nodes()); + let timer = Timer::new(time.end); + let ctrl = Control::Limited(&nodes, &timer, stopper); + let mut search = Search::new(self, ctrl); thread::scope(|s| { for _ in 1..self.threads.get() { - s.spawn(|| self.aw::<0>(pos, limits.depth(), &nodes, &time, stopper)); + let time = time.clone(); + let mut search = search.clone(); + s.spawn(move || search.aw::<0>(pos, limits.depth(), time)); } - let pv = self.aw(pos, limits.depth(), &nodes, &time, stopper); + let pv = search.aw(pos, limits.depth(), time); stopper.disarm(); pv }) @@ -544,10 +567,10 @@ mod tests { #[filter(#s.mate().is_none() && #s >= #b)] s: Score, #[map(|s: Selector| s.select(#pos.moves().flatten()))] m: Move, ) { - use Control::Unlimited; let tpos = Transposition::new(ScoreBound::Lower(s), d, m); e.tt.set(pos.zobrist(), tpos); - assert_eq!(e.nw::<1>(&pos, b, d, p, &Unlimited), Ok(Pv::new(s, []))); + let mut search = Search::new(&e, Control::Unlimited); + assert_eq!(search.nw::<1>(&pos, b, d, p), Ok(Pv::new(s, []))); } #[proptest] @@ -562,10 +585,10 @@ mod tests { #[filter(#s.mate().is_none() && #s < #b)] s: Score, #[map(|s: Selector| s.select(#pos.moves().flatten()))] m: Move, ) { - use Control::Unlimited; let tpos = Transposition::new(ScoreBound::Upper(s), d, m); e.tt.set(pos.zobrist(), tpos); - assert_eq!(e.nw::<1>(&pos, b, d, p, &Unlimited), Ok(Pv::new(s, []))); + let mut search = Search::new(&e, Control::Unlimited); + assert_eq!(search.nw::<1>(&pos, b, d, p), Ok(Pv::new(s, []))); } #[proptest] @@ -580,10 +603,10 @@ mod tests { #[filter(#s.mate().is_none())] s: Score, #[map(|s: Selector| s.select(#pos.moves().flatten()))] m: Move, ) { - use Control::Unlimited; let tpos = Transposition::new(ScoreBound::Exact(s), d, m); e.tt.set(pos.zobrist(), tpos); - assert_eq!(e.nw::<1>(&pos, b, d, p, &Unlimited), Ok(Pv::new(s, []))); + let mut search = Search::new(&e, Control::Unlimited); + assert_eq!(search.nw::<1>(&pos, b, d, p), Ok(Pv::new(s, []))); } #[proptest] @@ -594,8 +617,10 @@ mod tests { d: Depth, #[filter(#p > 0)] p: Ply, ) { + let mut search = Search::new(&e, Control::Unlimited); + assert_eq!( - e.nw::<1>(&pos, b, d, p, &Control::Unlimited)? < b, + search.nw::<1>(&pos, b, d, p)? < b, alphabeta(&pos, b - 1..b, d, p) < b ); } @@ -612,7 +637,8 @@ mod tests { let timer = Timer::infinite(); let trigger = Trigger::armed(); let ctrl = Control::Limited(&nodes, &timer, &trigger); - assert_eq!(e.ab::<1>(&pos, b, d, p, &ctrl), Err(Interrupted)); + let mut search = Search::new(&e, ctrl); + assert_eq!(search.ab::<1>(&pos, b, d, p), Err(Interrupted)); } #[proptest] @@ -627,8 +653,9 @@ mod tests { let timer = Timer::new(Duration::ZERO); let trigger = Trigger::armed(); let ctrl = Control::Limited(&nodes, &timer, &trigger); + let mut search = Search::new(&e, ctrl); std::thread::sleep(Duration::from_millis(1)); - assert_eq!(e.ab::<1>(&pos, b, d, p, &ctrl), Err(Interrupted)); + assert_eq!(search.ab::<1>(&pos, b, d, p), Err(Interrupted)); } #[proptest] @@ -643,7 +670,8 @@ mod tests { let timer = Timer::infinite(); let trigger = Trigger::disarmed(); let ctrl = Control::Limited(&nodes, &timer, &trigger); - assert_eq!(e.ab::<1>(&pos, b, d, p, &ctrl), Err(Interrupted)); + let mut search = Search::new(&e, ctrl); + assert_eq!(search.ab::<1>(&pos, b, d, p), Err(Interrupted)); } #[proptest] @@ -653,8 +681,10 @@ mod tests { #[filter(!#b.is_empty())] b: Range, d: Depth, ) { + let mut search = Search::new(&e, Control::Unlimited); + assert_eq!( - e.ab::<1>(&pos, b, d, Ply::upper(), &Control::Unlimited), + search.ab::<1>(&pos, b, d, Ply::upper()), Ok(Pv::new(pos.evaluate().saturate(), [])) ); } @@ -667,8 +697,10 @@ mod tests { d: Depth, #[filter(#p > 0 || #pos.outcome() != Some(Outcome::DrawByThreefoldRepetition))] p: Ply, ) { + let mut search = Search::new(&e, Control::Unlimited); + assert_eq!( - e.ab::<1>(&pos, b, d, p, &Control::Unlimited), + search.ab::<1>(&pos, b, d, p), Ok(Pv::new(Score::new(0), [])) ); } @@ -681,25 +713,31 @@ mod tests { d: Depth, #[filter(#p > 0)] p: Ply, ) { + let mut search = Search::new(&e, Control::Unlimited); + assert_eq!( - e.ab::<1>(&pos, b, d, p, &Control::Unlimited), + search.ab::<1>(&pos, b, d, p), Ok(Pv::new(Score::mated(p), [])) ); } #[proptest] - fn search_finds_the_minimax_score(mut e: Engine, pos: Evaluator, #[filter(#d > 1)] d: Depth) { + fn search_finds_the_minimax_score(e: Engine, pos: Evaluator, #[filter(#d > 1)] d: Depth) { let time = Duration::MAX..Duration::MAX; let nodes = Counter::new(u64::MAX); + let timer = Timer::new(time.end); + let trigger = Trigger::armed(); + let ctrl = Control::Limited(&nodes, &timer, &trigger); + let mut search = Search::new(&e, ctrl); assert_eq!( e.search(&pos, &Limits::Depth(d), &Trigger::armed()).score(), - e.aw::<0>(&pos, d, &nodes, &time, &Trigger::armed()).score() + search.aw::<0>(&pos, d, time).score() ); } #[proptest] - fn search_is_stable(mut e: Engine, pos: Evaluator, d: Depth) { + fn search_is_stable(e: Engine, pos: Evaluator, d: Depth) { let limits = Limits::Depth(d); assert_eq!( @@ -710,7 +748,7 @@ mod tests { #[proptest] fn search_extends_time_to_find_some_pv( - mut e: Engine, + e: Engine, #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Duration::ZERO.into(); @@ -720,7 +758,7 @@ mod tests { #[proptest] fn search_extends_depth_to_find_some_pv( - mut e: Engine, + e: Engine, #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Depth::lower().into(); @@ -730,7 +768,7 @@ mod tests { #[proptest] fn search_ignores_stopper_to_find_some_pv( - mut e: Engine, + e: Engine, #[filter(#pos.outcome().is_none())] pos: Evaluator, ) { let limits = Limits::None; diff --git a/lib/search/killers.rs b/lib/search/killers.rs index d6440b6d..8d1889d5 100644 --- a/lib/search/killers.rs +++ b/lib/search/killers.rs @@ -1,19 +1,14 @@ -use crate::chess::{Move, Position}; -use crate::search::Ply; -use crate::util::{AlignTo64, Assume, Binary, Bits, Integer}; -use derive_more::Debug; -use std::mem::{size_of, MaybeUninit}; -use std::sync::atomic::{AtomicU32, Ordering::Relaxed}; +use crate::chess::Move; -/// A pair of [killer moves]. +/// A set of [killer moves]. /// /// [killer moves]: https://www.chessprogramming.org/Killer_Move #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)] #[cfg_attr(test, derive(test_strategy::Arbitrary))] -pub struct Killer(Option, Option); +pub struct Killers(Option, Option); -impl Killer { - /// Adds a killer move to the pair. +impl Killers { + /// Adds a killer move to the set. #[inline(always)] pub fn insert(&mut self, m: Move) { if self.0 != Some(m) { @@ -22,79 +17,23 @@ impl Killer { } } - /// Whether a move is a killer. + /// Whether a move is in the set. #[inline(always)] pub fn contains(&self, m: Move) -> bool { self.0 == Some(m) || self.1 == Some(m) } } -impl Binary for Killer { - type Bits = Bits as Binary>::Bits::BITS }>; - - #[inline(always)] - fn encode(&self) -> Self::Bits { - let mut bits = Bits::default(); - bits.push(self.1.encode()); - bits.push(self.0.encode()); - bits - } - - #[inline(always)] - fn decode(mut bits: Self::Bits) -> Self { - Killer(Binary::decode(bits.pop()), Binary::decode(bits.pop())) - } -} - -/// A set of [killer moves] indexed by [`Ply`] and side to move. -/// -/// [killer moves]: https://www.chessprogramming.org/Killer_Move -#[derive(Debug)] -#[debug("Killers({})", size_of::())] -pub struct Killers(AlignTo64<[[AtomicU32; 2]; Ply::MAX as usize]>); - -impl Default for Killers { - #[inline(always)] - fn default() -> Self { - Killers(unsafe { MaybeUninit::zeroed().assume_init() }) - } -} - -impl Killers { - /// Adds a killer move to the set at a given ply for a given side to move. - #[inline(always)] - pub fn insert(&self, pos: &Position, ply: Ply, m: Move) { - let slot = &self.0.get(ply.cast::()).assume()[pos.turn().cast::()]; - let mut killer = Killer::decode(Bits::new(slot.load(Relaxed))); - killer.insert(m); - slot.store(killer.encode().get(), Relaxed); - } - - /// Returns the known killer moves at a given ply for a given side to move. - #[inline(always)] - pub fn get(&self, pos: &Position, ply: Ply) -> Killer { - let slot = &self.0.get(ply.cast::()).assume()[pos.turn().cast::()]; - Killer::decode(Bits::new(slot.load(Relaxed))) - } -} - #[cfg(test)] mod tests { use super::*; - use crate::util::Integer; use proptest::sample::size_range; use std::collections::HashSet; - use std::fmt::Debug; use test_strategy::proptest; - #[proptest] - fn decoding_encoded_killer(k: Killer) { - assert_eq!(Killer::decode(k.encode()), k); - } - #[proptest] fn contains_returns_true_only_if_inserted(m: Move) { - let mut k = Killer::default(); + let mut k = Killers::default(); assert!(!k.contains(m)); k.insert(m); assert!(k.contains(m)); @@ -102,17 +41,17 @@ mod tests { #[proptest] fn insert_avoids_duplicated_moves(m: Move) { - let mut k = Killer::default(); + let mut k = Killers::default(); k.insert(m); k.insert(m); - assert_eq!(k, Killer(Some(m), None)); + assert_eq!(k, Killers(Some(m), None)); } #[proptest] fn insert_keeps_most_recent(#[any(size_range(2..10).lift())] ms: HashSet, m: Move) { - let mut k = Killer::default(); + let mut k = Killers::default(); for m in ms { k.insert(m); @@ -121,16 +60,4 @@ mod tests { k.insert(m); assert_eq!(k.0, Some(m)); } - - #[proptest] - fn get_turns_killers_at_ply_for_the_side_to_move( - pos: Position, - #[filter((0..Ply::MAX).contains(&#p.get()))] p: Ply, - m: Move, - ) { - let ks = Killers::default(); - ks.insert(&pos, p, m); - let k = ks.get(&pos, p); - assert_eq!(k.0, Some(m)); - } }