Skip to content

Commit

Permalink
implement history heuristics for move ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
brunocodutra committed Dec 23, 2024
1 parent 21bffb2 commit 6c05948
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 30 deletions.
2 changes: 2 additions & 0 deletions lib/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod control;
mod depth;
mod driver;
mod engine;
mod history;
mod killers;
mod limits;
mod options;
Expand All @@ -14,6 +15,7 @@ pub use control::*;
pub use depth::*;
pub use driver::*;
pub use engine::*;
pub use history::*;
pub use killers::*;
pub use limits::*;
pub use options::*;
Expand Down
31 changes: 22 additions & 9 deletions lib/search/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub struct Engine {
tt: TranspositionTable,
#[cfg_attr(test, strategy(LazyJust::new(Killers::default)))]
killers: Killers,
#[cfg_attr(test, strategy(LazyJust::new(History::default)))]
history: History,
}

impl Default for Engine {
Expand All @@ -41,31 +43,42 @@ impl Engine {
driver: Driver::new(options.threads),
tt: TranspositionTable::new(options.hash),
killers: Killers::default(),
history: History::default(),
}
}

/// Records a `[Transposition`].
#[allow(clippy::too_many_arguments)]
fn record(
&self,
pos: &Evaluator,
moves: &[(Move, Value)],
bounds: Range<Score>,
depth: Depth,
ply: Ply,
best: Move,
score: Score,
) {
let draft = depth - ply;
if score >= bounds.end && best.is_quiet() {
self.killers.insert(ply, pos.turn(), best);
self.history.update(best, pos.turn(), draft.get());
for &(m, _) in moves.iter().rev() {
if m == best {
break;
} else if m.is_quiet() {
self.history.update(m, pos.turn(), -draft.get());
}
}
}

self.tt.set(
pos.zobrist(),
if score >= bounds.end {
Transposition::lower(depth - ply, score.normalize(-ply), best)
Transposition::lower(draft, score.normalize(-ply), best)
} else if score <= bounds.start {
Transposition::upper(depth - ply, score.normalize(-ply), best)
Transposition::upper(draft, score.normalize(-ply), best)
} else {
Transposition::exact(depth - ply, score.normalize(-ply), best)
Transposition::exact(draft, score.normalize(-ply), best)
},
);
}
Expand Down Expand Up @@ -134,7 +147,7 @@ impl Engine {
ply: Ply,
ctrl: &Control,
) -> Result<Pv<N>, Interrupted> {
if ply < N && depth > ply && bounds.start + 1 < bounds.end {
if ply.cast::<usize>() < N && depth > ply && bounds.start + 1 < bounds.end {
self.pvs(pos, bounds, depth, ply, ctrl)
} else {
Ok(self.pvs::<0>(pos, bounds, depth, ply, ctrl)?.convert())
Expand Down Expand Up @@ -267,7 +280,7 @@ impl Engine {
} else if killers.contains(m) {
(m, Value::new(25))
} else if m.is_quiet() {
(m, Value::lower())
(m, Value::lower() / 2 + self.history.get(m, pos.turn()))
} else {
let mut next = pos.material();
let material = next.evaluate();
Expand Down Expand Up @@ -313,7 +326,7 @@ impl Engine {
};

if tail >= beta || moves.is_empty() {
self.record(pos, bounds, depth, ply, head, tail.score());
self.record(pos, &[], bounds, depth, ply, head, tail.score());
return Ok(head >> tail);
}

Expand All @@ -327,7 +340,7 @@ impl Engine {
next.play(m);

self.tt.prefetch(next.zobrist());
if gain < 0 && !pos.is_check() && !next.is_check() {
if gain <= Value::lower() / 2 && !pos.is_check() && !next.is_check() {
if let Some(d) = self.lmp(alpha + next.evaluate(), draft) {
if d <= 0 || -self.nw::<0>(&next, -alpha, d + ply, ply + 1, ctrl)? <= alpha {
#[cfg(not(test))]
Expand All @@ -345,7 +358,7 @@ impl Engine {
Ok(partial)
})?;

self.record(pos, bounds, depth, ply, head, tail.score());
self.record(pos, &moves, bounds, depth, ply, head, tail.score());
Ok(head >> tail)
}

Expand Down
57 changes: 57 additions & 0 deletions lib/search/history.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use crate::chess::{Color, Move};
use crate::util::Assume;
use std::array;
use std::sync::atomic::{AtomicI8, Ordering::Relaxed};

/// [Historical statistics] about a [`Move`].
///
/// [Historical statistics]: https://www.chessprogramming.org/History_Heuristic
#[derive(Debug)]
pub struct History([[[AtomicI8; 2]; 64]; 64]);

impl Default for History {
#[inline(always)]
fn default() -> Self {
History(array::from_fn(|_| {
array::from_fn(|_| [AtomicI8::new(0), AtomicI8::new(0)])
}))
}
}

impl History {
/// Update statistics about a [`Move`] for a side to move at a given draft.
#[inline(always)]
pub fn update(&self, m: Move, side: Color, bonus: i8) {
let bonus = bonus.max(-i8::MAX);
let slot = &self.0[m.whence() as usize][m.whither() as usize][side as usize];
let result = slot.fetch_update(Relaxed, Relaxed, |h| {
Some((bonus as i16 - bonus.abs() as i16 * h as i16 / 127 + h as i16) as i8)
});

result.assume();
}

/// Returns the history bonus for a [`Move`].
#[inline(always)]
pub fn get(&self, m: Move, side: Color) -> i8 {
self.0[m.whence() as usize][m.whither() as usize][side as usize].load(Relaxed)
}
}

#[cfg(test)]
mod tests {
use super::*;
use test_strategy::proptest;

#[proptest]
fn update_only_changes_history_of_given_move(
c: Color,
b: i8,
m: Move,
#[filter((#m.whence(), #m.whither()) != (#n.whence(), #n.whither()))] n: Move,
) {
let h = History::default();
h.update(m, c, b);
assert_eq!(h.get(n, c), 0);
}
}
4 changes: 2 additions & 2 deletions lib/search/killers.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::chess::{Color, Move};
use crate::search::Ply;
use crate::util::{Assume, Binary, Bits, Integer};
use std::sync::atomic::AtomicU32;
use std::{array, sync::atomic::Ordering::Relaxed};
use std::array;
use std::sync::atomic::{AtomicU32, Ordering::Relaxed};

/// A pair of [killer moves].
///
Expand Down
6 changes: 1 addition & 5 deletions lib/util/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ impl<T: Unsigned, const W: u32> Binary for Bits<T, W> {
}
}

impl<T> Binary for Option<T>
where
T: Binary,
T::Bits: Default + Debug + Eq + PartialEq,
{
impl<T: Binary<Bits: Default + Debug + Eq + PartialEq>> Binary for Option<T> {
type Bits = T::Bits;

#[inline(always)]
Expand Down
28 changes: 14 additions & 14 deletions lib/util/saturating.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::util::Integer;
use crate::util::{Integer, Signed};
use derive_more::{Debug, Display, Error};
use std::fmt::{self, Formatter};
use std::ops::{Add, Div, Mul, Neg, Sub};
Expand All @@ -9,19 +9,19 @@ use std::{cmp::Ordering, mem::size_of, num::Saturating as S, str::FromStr};
#[cfg_attr(test, derive(test_strategy::Arbitrary))]
#[cfg_attr(test, arbitrary(bound(T, Self: Debug)))]
#[debug("Saturating({self})")]
#[debug(bounds(T: Integer, T::Repr: Display))]
#[debug(bounds(T: Integer<Repr: Signed>, T::Repr: Display))]
#[repr(transparent)]
pub struct Saturating<T>(T);

unsafe impl<T: Integer> Integer for Saturating<T> {
unsafe impl<T: Integer<Repr: Signed>> Integer for Saturating<T> {
type Repr = T::Repr;
const MIN: Self::Repr = T::MIN;
const MAX: Self::Repr = T::MAX;
}

impl<T: Integer> Eq for Saturating<T> where Self: PartialEq<Self> {}
impl<T: Integer<Repr: Signed>> Eq for Saturating<T> where Self: PartialEq<Self> {}

impl<T: Integer, U: Integer> PartialEq<U> for Saturating<T> {
impl<T: Integer<Repr: Signed>, U: Integer<Repr: Signed>> PartialEq<U> for Saturating<T> {
#[inline(always)]
fn eq(&self, other: &U) -> bool {
if size_of::<T>() > size_of::<U>() {
Expand All @@ -32,14 +32,14 @@ impl<T: Integer, U: Integer> PartialEq<U> for Saturating<T> {
}
}

impl<T: Integer> Ord for Saturating<T> {
impl<T: Integer<Repr: Signed>> Ord for Saturating<T> {
#[inline(always)]
fn cmp(&self, other: &Self) -> Ordering {
self.get().cmp(&other.get())
}
}

impl<T: Integer, U: Integer> PartialOrd<U> for Saturating<T> {
impl<T: Integer<Repr: Signed>, U: Integer<Repr: Signed>> PartialOrd<U> for Saturating<T> {
#[inline(always)]
fn partial_cmp(&self, other: &U) -> Option<Ordering> {
if size_of::<T>() > size_of::<U>() {
Expand All @@ -50,7 +50,7 @@ impl<T: Integer, U: Integer> PartialOrd<U> for Saturating<T> {
}
}

impl<T: Integer> Neg for Saturating<T>
impl<T: Integer<Repr: Signed>> Neg for Saturating<T>
where
S<T::Repr>: Neg<Output = S<T::Repr>>,
{
Expand All @@ -62,7 +62,7 @@ where
}
}

impl<T: Integer, U: Integer> Add<U> for Saturating<T>
impl<T: Integer<Repr: Signed>, U: Integer<Repr: Signed>> Add<U> for Saturating<T>
where
S<T::Repr>: Add<Output = S<T::Repr>>,
S<U::Repr>: Add<Output = S<U::Repr>>,
Expand All @@ -79,7 +79,7 @@ where
}
}

impl<T: Integer, U: Integer> Sub<U> for Saturating<T>
impl<T: Integer<Repr: Signed>, U: Integer<Repr: Signed>> Sub<U> for Saturating<T>
where
S<T::Repr>: Sub<Output = S<T::Repr>>,
S<U::Repr>: Sub<Output = S<U::Repr>>,
Expand All @@ -96,7 +96,7 @@ where
}
}

impl<T: Integer, U: Integer> Mul<U> for Saturating<T>
impl<T: Integer<Repr: Signed>, U: Integer<Repr: Signed>> Mul<U> for Saturating<T>
where
S<T::Repr>: Mul<Output = S<T::Repr>>,
S<U::Repr>: Mul<Output = S<U::Repr>>,
Expand All @@ -113,7 +113,7 @@ where
}
}

impl<T: Integer, U: Integer> Div<U> for Saturating<T>
impl<T: Integer<Repr: Signed>, U: Integer<Repr: Signed>> Div<U> for Saturating<T>
where
S<T::Repr>: Div<Output = S<T::Repr>>,
S<U::Repr>: Div<Output = S<U::Repr>>,
Expand All @@ -130,7 +130,7 @@ where
}
}

impl<T: Integer> Display for Saturating<T>
impl<T: Integer<Repr: Signed>> Display for Saturating<T>
where
T::Repr: Display,
{
Expand All @@ -144,7 +144,7 @@ where
#[display("failed to parse saturating integer")]
pub struct ParseSaturatingIntegerError;

impl<T: Integer> FromStr for Saturating<T>
impl<T: Integer<Repr: Signed>> FromStr for Saturating<T>
where
T::Repr: FromStr,
{
Expand Down

0 comments on commit 6c05948

Please sign in to comment.