Skip to content

Commit

Permalink
replace crate lazy_static by crate ctor
Browse files Browse the repository at this point in the history
  • Loading branch information
brunocodutra committed Jan 28, 2024
1 parent 265b3e6 commit 0a8b7e5
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 42 deletions.
12 changes: 11 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ keywords = ["chess"]
arrayvec = { version = "0.7.3", default-features = false, features = ["std"] }
byteorder = { version = "1.5.0", default-features = false, features = ["std"] }
cozy-chess = { version = "0.3.3", default-features = false, features = ["std"] }
ctor = { version = "0.2.6", default-features = false }
derive_more = { version = "1.0.0-beta.6", default-features = false, features = [
"add",
"add_assign",
Expand All @@ -28,7 +29,6 @@ derive_more = { version = "1.0.0-beta.6", default-features = false, features = [
"mul_assign",
"not",
] }
lazy_static = { version = "1.4.0", default-features = false }
num-traits = { version = "0.2.17", default-features = false, features = ["std"] }
rayon = { version = "1.8.1", default-features = false }
ruzstd = { version = "0.5.0", default-features = false, features = ["std"] }
Expand Down
2 changes: 1 addition & 1 deletion lib/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![feature(array_chunks, new_uninit, stdarch)]
#![feature(array_chunks, stdarch)]

/// Chess domain types.
pub mod chess;
Expand Down
38 changes: 19 additions & 19 deletions lib/nnue.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use byteorder::{LittleEndian, ReadBytesExt};
use ruzstd::StreamingDecoder;
use std::io::{self, Read};
use std::mem::transmute;
use std::mem::{transmute, MaybeUninit};

mod accumulator;
mod evaluator;
Expand All @@ -23,49 +23,49 @@ pub use positional::*;
pub use transformer::*;
pub use value::*;

lazy_static::lazy_static! {
/// A trained [`Nnue`].
pub static ref NNUE: Box<Nnue> = {
let encoded = include_bytes!("nnue/nn.zst").as_slice();
let decoder = StreamingDecoder::new(encoded).expect("failed to initialize zstd decoder");
Nnue::load(decoder).expect("failed to load the NNUE")
};
}

/// An [Efficiently Updatable Neural Network][NNUE].
///
/// [NNUE]: https://www.chessprogramming.org/NNUE
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Nnue {
struct Nnue {
ft: Transformer<i16, { Self::L0 }, { Self::L1 / 2 }>,
psqt: Transformer<i32, { Self::L0 }, { Self::PHASES }>,
output: [Output<{ Nnue::L1 }>; Self::PHASES],
}

static mut NNUE: Nnue = unsafe { MaybeUninit::zeroed().assume_init() };

#[cold]
#[ctor::ctor]
#[inline(never)]
unsafe fn init() {
let encoded = include_bytes!("nnue/nn.zst").as_slice();
let decoder = StreamingDecoder::new(encoded).expect("failed to initialize zstd decoder");
NNUE.load(decoder).expect("failed to load the NNUE");
}

impl Nnue {
const PHASES: usize = 8;
const L0: usize = 64 * 64 * 11;
const L1: usize = 1024;

fn load<T: Read>(mut reader: T) -> io::Result<Box<Self>> {
let mut nnue: Box<Self> = unsafe { Box::new_zeroed().assume_init() };

reader.read_i16_into::<LittleEndian>(&mut *nnue.ft.bias)?;
fn load<T: Read>(&mut self, mut reader: T) -> io::Result<()> {
reader.read_i16_into::<LittleEndian>(&mut *self.ft.bias)?;
reader.read_i16_into::<LittleEndian>(unsafe {
transmute::<_, &mut [_; Self::L0 * Self::L1 / 2]>(&mut *nnue.ft.weight)
transmute::<_, &mut [_; Self::L0 * Self::L1 / 2]>(&mut *self.ft.weight)
})?;

reader.read_i32_into::<LittleEndian>(unsafe {
transmute::<_, &mut [_; Self::L0 * Self::PHASES]>(&mut *nnue.psqt.weight)
transmute::<_, &mut [_; Self::L0 * Self::PHASES]>(&mut *self.psqt.weight)
})?;

for nn in &mut nnue.output {
for nn in &mut self.output {
nn.bias = reader.read_i32::<LittleEndian>()?;
reader.read_i8_into(&mut *nn.weight)?;
}

debug_assert!(reader.read_u8().is_err());

Ok(nnue)
Ok(())
}
}
22 changes: 15 additions & 7 deletions lib/nnue/material.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::nnue::{Accumulator, Nnue, NNUE};
use crate::nnue::{Accumulator, Nnue, Transformer, NNUE};
use crate::util::AlignTo64;
use std::ops::Deref;

/// An accumulator for the psqt transformer.
#[derive(Debug, Default, Clone, Eq, PartialEq, Hash)]
Expand All @@ -9,6 +10,13 @@ pub struct Material(
AlignTo64<[[i32; Nnue::PHASES]; 2]>,
);

impl Material {
#[inline(always)]
fn transformer(&self) -> impl Deref<Target = Transformer<i32, { Nnue::L0 }, { Nnue::PHASES }>> {
unsafe { &NNUE.psqt }
}
}

impl Accumulator for Material {
#[inline(always)]
fn mirror(&mut self) {
Expand All @@ -17,20 +25,20 @@ impl Accumulator for Material {

#[inline(always)]
fn refresh(&mut self, us: &[u16], them: &[u16]) {
NNUE.psqt.refresh(us, &mut self.0[0]);
NNUE.psqt.refresh(them, &mut self.0[1]);
self.transformer().refresh(us, &mut self.0[0]);
self.transformer().refresh(them, &mut self.0[1]);
}

#[inline(always)]
fn add(&mut self, us: u16, them: u16) {
NNUE.psqt.add(us, &mut self.0[0]);
NNUE.psqt.add(them, &mut self.0[1]);
self.transformer().add(us, &mut self.0[0]);
self.transformer().add(them, &mut self.0[1]);
}

#[inline(always)]
fn remove(&mut self, us: u16, them: u16) {
NNUE.psqt.remove(us, &mut self.0[0]);
NNUE.psqt.remove(them, &mut self.0[1]);
self.transformer().remove(us, &mut self.0[0]);
self.transformer().remove(them, &mut self.0[1]);
}

#[inline(always)]
Expand Down
30 changes: 20 additions & 10 deletions lib/nnue/positional.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::nnue::{Accumulator, Layer, Nnue, NNUE};
use crate::nnue::{Accumulator, Layer, Nnue, Transformer, NNUE};
use crate::util::AlignTo64;
use std::mem::transmute;
use std::{mem::transmute, ops::Deref};

/// An accumulator for the feature transformer.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
Expand All @@ -10,7 +10,15 @@ pub struct Positional(
AlignTo64<[[i16; Nnue::L1 / 2]; 2]>,
);

impl Positional {
#[inline(always)]
fn transformer(&self) -> impl Deref<Target = Transformer<i16, { Nnue::L0 }, { Nnue::L1 / 2 }>> {
unsafe { &NNUE.ft }
}
}

impl Default for Positional {
#[inline(always)]
fn default() -> Self {
Positional(AlignTo64([[0; Nnue::L1 / 2]; 2]))
}
Expand All @@ -24,26 +32,28 @@ impl Accumulator for Positional {

#[inline(always)]
fn refresh(&mut self, us: &[u16], them: &[u16]) {
NNUE.ft.refresh(us, &mut self.0[0]);
NNUE.ft.refresh(them, &mut self.0[1]);
self.transformer().refresh(us, &mut self.0[0]);
self.transformer().refresh(them, &mut self.0[1]);
}

#[inline(always)]
fn add(&mut self, us: u16, them: u16) {
NNUE.ft.add(us, &mut self.0[0]);
NNUE.ft.add(them, &mut self.0[1]);
self.transformer().add(us, &mut self.0[0]);
self.transformer().add(them, &mut self.0[1]);
}

#[inline(always)]
fn remove(&mut self, us: u16, them: u16) {
NNUE.ft.remove(us, &mut self.0[0]);
NNUE.ft.remove(them, &mut self.0[1]);
self.transformer().remove(us, &mut self.0[0]);
self.transformer().remove(them, &mut self.0[1]);
}

#[inline(always)]
fn evaluate(&self, phase: usize) -> i32 {
let l1: &[i16; Nnue::L1] = unsafe { transmute(&self.0) };
NNUE.output[phase].forward(l1) / 16
unsafe {
let l1: &[i16; Nnue::L1] = transmute(&self.0);
NNUE.output[phase].forward(l1) / 16
}
}
}

Expand Down
4 changes: 1 addition & 3 deletions lib/search/engine.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::chess::{Move, Piece, Position, Role};
use crate::nnue::{Evaluator, Value, NNUE};
use crate::nnue::{Evaluator, Value};
use crate::search::{Depth, DepthBounds, Killers, Limits, Options, Ply, Pv, Score};
use crate::search::{Transposition, TranspositionTable};
use crate::util::{Assume, Bounds, Buffer, Counter, Timer};
Expand Down Expand Up @@ -56,8 +56,6 @@ impl Engine {

/// Initializes the engine with the given [`Options`].
pub fn with_options(options: Options) -> Self {
lazy_static::initialize(&NNUE);

Engine {
tt: TranspositionTable::new(options.hash),
executor: ThreadPoolBuilder::new()
Expand Down

0 comments on commit 0a8b7e5

Please sign in to comment.