Skip to content

Commit

Permalink
update NNUE architecture to HalfKAv2 -> 512x2 -> 1x8 with PSQT
Browse files Browse the repository at this point in the history
  • Loading branch information
brunocodutra committed Dec 10, 2023
1 parent f1c0b00 commit 0e58df7
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 130 deletions.
36 changes: 5 additions & 31 deletions lib/nnue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ use std::{io, mem::transmute};
use zstd::Decoder;

mod accumulator;
mod affine;
mod crelu;
mod damp;
mod evaluator;
mod fallthrough;
mod feature;
Expand All @@ -18,9 +16,7 @@ mod value;
mod vector;

pub use accumulator::*;
pub use affine::*;
pub use crelu::*;
pub use damp::*;
pub use evaluator::*;
pub use fallthrough::*;
pub use feature::*;
Expand All @@ -39,33 +35,24 @@ lazy_static::lazy_static! {
Nnue::load(include_bytes!("nnue/nn.zst")).expect("failed to load the NNUE");
}

// Layer-stack aliases. Each stage is wrapped in `CReLU`; the type parameter `N` is the type
// nested inside the stage, so `L12<L23<L3o>>` spells out a complete stack.
// `L12`: `Affine` parameterized by `Nnue::L1` and `Nnue::L2` around `Damp<N, 64>`.
type L12<N> = CReLU<Affine<Damp<N, 64>, { Nnue::L1 }, { Nnue::L2 }>>;
// `L23`: `Affine` parameterized by `Nnue::L2` and `Nnue::L3` around `Damp<N, 64>`.
type L23<N> = CReLU<Affine<Damp<N, 64>, { Nnue::L2 }, { Nnue::L3 }>>;
// `L3o`: the innermost stage, an `Output` parameterized by `Nnue::L3`.
type L3o = CReLU<Output<{ Nnue::L3 }>>;

/// An [Efficiently Updatable Neural Network][NNUE].
///
/// Holds two `Transformer`s parameterized by `Self::L0`, plus per-phase layers stored in
/// arrays of length `Self::PHASES`.
///
/// [NNUE]: https://www.chessprogramming.org/NNUE
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Nnue {
// `i16` transformer parameterized by `L0` and `L1 / 2`.
// NOTE(review): presumably the feature transformer — confirm.
ft: Transformer<i16, { Self::L0 }, { Self::L1 / 2 }>,
// `i32` transformer parameterized by `L0` and `PHASES`.
// NOTE(review): presumably the piece-square-table term, one output per phase — confirm.
psqt: Transformer<i32, { Self::L0 }, { Self::PHASES }>,
// One `L12<L23<L3o>>` layer stack per phase.
nns: [L12<L23<L3o>>; Self::PHASES],
// One `CReLU<Output<L1>>` layer per phase.
output: [CReLU<Output<{ Nnue::L1 }>>; Self::PHASES],
}

impl Nnue {
const PHASES: usize = 8;
const L0: usize = 64 * 64 * 11;
const L1: usize = 1024;
const L2: usize = 16;
const L3: usize = 32;

fn load(bytes: &[u8]) -> io::Result<Box<Self>> {
let mut buffer = Decoder::new(bytes)?;
let mut nnue: Box<Self> = unsafe { Box::new_zeroed().assume_init() };

assert_eq!(buffer.read_u32::<LittleEndian>()?, 0xffffffff);
assert_eq!(buffer.read_u32::<LittleEndian>()?, 0x3c103e72);
let mut buffer = Decoder::new(bytes)?;

buffer.read_i16_into::<LittleEndian>(&mut nnue.ft.bias)?;
buffer.read_i16_into::<LittleEndian>(unsafe {
Expand All @@ -76,22 +63,9 @@ impl Nnue {
transmute::<_, &mut [_; Self::L0 * Self::PHASES]>(&mut nnue.psqt.weight)
})?;

for nn in &mut nnue.nns {
let l12 = &mut nn.next;
buffer.read_i32_into::<LittleEndian>(&mut l12.bias)?;
buffer.read_i8_into(unsafe {
transmute::<_, &mut [_; Self::L1 * Self::L2]>(&mut l12.weight)
})?;

let l23 = &mut l12.next.next.next;
buffer.read_i32_into::<LittleEndian>(&mut l23.bias)?;
buffer.read_i8_into(unsafe {
transmute::<_, &mut [_; Self::L2 * Self::L3]>(&mut l23.weight)
})?;

let l3o = &mut l23.next.next.next;
l3o.bias = buffer.read_i32::<LittleEndian>()?;
buffer.read_i8_into(&mut l3o.weight)?;
for nn in &mut nnue.output {
nn.next.bias = buffer.read_i32::<LittleEndian>()?;
buffer.read_i8_into(&mut nn.next.weight)?;
}

debug_assert!(buffer.read_u8().is_err());
Expand Down
46 changes: 0 additions & 46 deletions lib/nnue/affine.rs

This file was deleted.

31 changes: 0 additions & 31 deletions lib/nnue/damp.rs

This file was deleted.

Binary file modified lib/nnue/nn.zst
Binary file not shown.
2 changes: 1 addition & 1 deletion lib/nnue/positional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Accumulator for Positional {

/// Evaluates this accumulator with the network selected by `phase`.
///
/// The inner storage is reinterpreted as `Nnue::L1` values of type `i16`, run through
/// `forward`, and the result is divided by 16.
///
/// Indexing by `phase` panics if it is out of range for the per-phase array.
fn evaluate(&self, phase: usize) -> i32 {
// NOTE(review): assumes `self.0` has exactly the layout of `[i16; Nnue::L1]`; `transmute`
// between references does not verify this — confirm against the accumulator's definition.
let l1: &[i16; Nnue::L1] = unsafe { transmute(&self.0) };
NNUE.nns[phase].forward(l1) / 16
NNUE.output[phase].forward(l1) / 16
}
}

Expand Down
21 changes: 0 additions & 21 deletions lib/nnue/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,6 @@ impl Axpy<i8x16, i8x16> for i32x4 {
}
}

impl<const I: usize, const O: usize> Axpy<[[i8; I]; O], [i8; I]> for [i32; O] {
    /// Applies the element-level `axpy` to every output element, pairing element `o` of
    /// `self` with row `o` of `a` and the shared vector `x`.
    fn axpy(&mut self, a: &[[i8; I]; O], x: &[i8; I]) {
        // `self` and `a` both hold exactly `O` entries, so the zip visits every pair.
        self.iter_mut()
            .zip(a)
            .for_each(|(acc, row)| acc.axpy(row, x));
    }
}

impl<T, const I: usize, const O: usize> Axpy<[[T; O]; I], [u16]> for [T; O]
where
T: Copy + AddAssign,
Expand Down Expand Up @@ -231,19 +223,6 @@ mod tests {
assert_eq!(y, a.iter().zip(x).map(|(&a, x)| a as i32 * x as i32).sum());
}

#[proptest]
fn axpy_computes_inner_product_of_matrix_and_vector(a: [[i8; 50]; 10], x: [i8; 50]) {
    // Negative components are clamped to zero before the vector is used.
    let x = x.map(|v| v.max(0));

    let mut y = [0; 10];
    y.axpy(&a, &x);

    // Reference result: the dot product of each row of `a` with `x`, widened to `i32`.
    let expected = a.map(|row| {
        row.iter()
            .zip(x)
            .map(|(&w, v)| i32::from(w) * i32::from(v))
            .sum::<i32>()
    });

    assert_eq!(y, expected);
}

#[proptest]
fn axpy_swizzles_matrix(
#[strategy([[-10..10i8, -10..10i8], [-10..10i8, -10..10i8], [-10..10i8, -10..10i8]])]
Expand Down

0 comments on commit 0e58df7

Please sign in to comment.