Skip to content

Commit

Permalink
FEAT: Add ConstNum for simple type level integers
Browse files Browse the repository at this point in the history
This stands in for const generics, we only need a few sizes anyway. Use
it for kernel sizes.
  • Loading branch information
bluss committed Dec 7, 2018
1 parent 805ec95 commit 6759f8a
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 73 deletions.
41 changes: 13 additions & 28 deletions src/dgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

use kernel::GemmKernel;
use kernel::GemmSelect;
use kernel::{U4, U8};
use archparam;

#[cfg(target_arch="x86")]
Expand Down Expand Up @@ -65,16 +66,12 @@ macro_rules! loop_n {
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelAvx {
type Elem = T;
const MR: usize = MR;
const NR: usize = NR;

#[inline(always)]
fn align_to() -> usize { 32 }
type MRTy = U8;
type NRTy = U4;

#[inline(always)]
fn mr() -> usize { MR }
#[inline(always)]
fn nr() -> usize { NR }
fn align_to() -> usize { 32 }

#[inline(always)]
fn always_masked() -> bool { false }
Expand Down Expand Up @@ -104,16 +101,12 @@ impl GemmKernel for KernelAvx {
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFma {
type Elem = T;
const MR: usize = KernelAvx::MR;
const NR: usize = KernelAvx::NR;

#[inline(always)]
fn align_to() -> usize { KernelAvx::align_to() }
type MRTy = <KernelAvx as GemmKernel>::MRTy;
type NRTy = <KernelAvx as GemmKernel>::NRTy;

#[inline(always)]
fn mr() -> usize { MR }
#[inline(always)]
fn nr() -> usize { NR }
fn align_to() -> usize { KernelAvx::align_to() }

#[inline(always)]
fn always_masked() -> bool { KernelAvx::always_masked() }
Expand Down Expand Up @@ -143,16 +136,12 @@ impl GemmKernel for KernelFma {
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelSse2 {
type Elem = T;
const MR: usize = 4;
const NR: usize = 4;

#[inline(always)]
fn align_to() -> usize { 16 }
type MRTy = U4;
type NRTy = U4;

#[inline(always)]
fn mr() -> usize { Self::MR }
#[inline(always)]
fn nr() -> usize { Self::NR }
fn align_to() -> usize { 16 }

#[inline(always)]
fn always_masked() -> bool { true }
Expand Down Expand Up @@ -181,16 +170,12 @@ impl GemmKernel for KernelSse2 {

impl GemmKernel for KernelFallback {
type Elem = T;
const MR: usize = 4;
const NR: usize = 4;

#[inline(always)]
fn align_to() -> usize { 0 }
type MRTy = U4;
type NRTy = U4;

#[inline(always)]
fn mr() -> usize { Self::MR }
#[inline(always)]
fn nr() -> usize { Self::NR }
fn align_to() -> usize { 0 }

#[inline(always)]
fn always_masked() -> bool { true }
Expand Down
20 changes: 10 additions & 10 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ impl<T> GemmSelect<T> for GemmParameters<T> {
fn ensure_kernel_params<K>()
where K: GemmKernel
{
let mr = K::mr();
let nr = K::nr();
let mr = K::MR;
let nr = K::NR;
assert!(mr > 0 && mr <= 8);
assert!(nr > 0 && nr <= 8);
assert!(mr * nr * size_of::<K::Elem>() <= 8 * 4 * 8);
Expand Down Expand Up @@ -189,7 +189,7 @@ unsafe fn gemm_loop<K>(
let a = a.stride_offset(csa, kkc * l4);

// Pack B -> B~
pack(kc, nc, K::nr(), bpp, b, csb, rsb);
pack(kc, nc, K::NR, bpp, b, csb, rsb);

// LOOP 3: split m into mc parts
for (l3, mc) in range_chunk(m, kmc) {
Expand All @@ -198,7 +198,7 @@ unsafe fn gemm_loop<K>(
let c = c.stride_offset(rsc, kmc * l3);

// Pack A -> A~
pack(kc, mc, K::mr(), app, a, rsa, csa);
pack(kc, mc, K::MR, app, a, rsa, csa);

// First time writing to C, use user's `beta`, else accumulate
let betap = if l4 == 0 { beta } else { <_>::one() };
Expand Down Expand Up @@ -228,8 +228,8 @@ unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
c: *mut K::Elem, rsc: isize, csc: isize)
where K: GemmKernel,
{
let mr = K::mr();
let nr = K::nr();
let mr = K::MR;
let nr = K::NR;
// make a mask buffer that fits 8 x 8 f32 and 8 x 4 f64 kernels and alignment
assert!(mr * nr * size_of::<K::Elem>() <= 256 && K::align_to() <= 32);
let mut mask_buf = [0u8; 256 + 31];
Expand Down Expand Up @@ -278,8 +278,8 @@ unsafe fn make_packing_buffer<K>(m: usize, k: usize, n: usize) -> (Alloc<K::Elem
let n = min(n, K::nc());
// round up k, n to multiples of mr, nr
// round up to multiple of kc
let apack_size = k * round_up_to(m, K::mr());
let bpack_size = k * round_up_to(n, K::nr());
let apack_size = k * round_up_to(m, K::MR);
let bpack_size = k * round_up_to(n, K::NR);
let nelem = apack_size + bpack_size;

dprint!("packed nelem={}, apack={}, bpack={},
Expand Down Expand Up @@ -382,8 +382,8 @@ unsafe fn masked_kernel<T, K>(k: usize, alpha: T,
mask_buf: *mut T)
where K: GemmKernel<Elem=T>, T: Element,
{
let mr = K::mr();
let nr = K::nr();
let mr = K::MR;
let nr = K::NR;
// use column major order for `mask_buf`
K::kernel(k, T::one(), a, b, T::zero(), mask_buf, 1, mr as isize);
let mut ab = mask_buf;
Expand Down
23 changes: 16 additions & 7 deletions src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,17 @@ pub trait GemmKernel {
type Elem: Element;

/// Kernel rows
const MR: usize;
const MR: usize = Self::MRTy::VALUE;
/// Kernel cols
const NR: usize;
const NR: usize = Self::NRTy::VALUE;
/// Kernel rows as static num
type MRTy: ConstNum;
/// Kernel cols as static num
type NRTy: ConstNum;

/// align inputs to this
fn align_to() -> usize;

/// Kernel rows
fn mr() -> usize;
/// Kernel cols
fn nr() -> usize;

/// Whether to always use the masked wrapper around the kernel.
///
/// If masked, the kernel is always called with α=1, β=0
Expand Down Expand Up @@ -93,3 +92,13 @@ pub(crate) trait GemmSelect<T> {
T: Element;
}


pub trait ConstNum {
const VALUE: usize;
}

pub struct U4;
pub struct U8;

impl ConstNum for U4 { const VALUE: usize = 4; }
impl ConstNum for U8 { const VALUE: usize = 8; }
40 changes: 12 additions & 28 deletions src/sgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

use kernel::GemmKernel;
use kernel::GemmSelect;
use kernel::{U4, U8};
use archparam;


Expand Down Expand Up @@ -63,15 +64,11 @@ macro_rules! loop_n { ($j:ident, $e:expr) => { loop8!($j, $e) }; }
impl GemmKernel for KernelAvx {
type Elem = T;

const MR: usize = MR;
const NR: usize = NR;
#[inline(always)]
fn align_to() -> usize { 32 }
type MRTy = U8;
type NRTy = U8;

#[inline(always)]
fn mr() -> usize { MR }
#[inline(always)]
fn nr() -> usize { NR }
fn align_to() -> usize { 32 }

#[inline(always)]
fn always_masked() -> bool { false }
Expand Down Expand Up @@ -99,17 +96,12 @@ impl GemmKernel for KernelAvx {
impl GemmKernel for KernelFma {
type Elem = T;

const MR: usize = KernelAvx::MR;
const NR: usize = KernelAvx::NR;
type MRTy = <KernelAvx as GemmKernel>::MRTy;
type NRTy = <KernelAvx as GemmKernel>::NRTy;

#[inline(always)]
fn align_to() -> usize { KernelAvx::align_to() }

#[inline(always)]
fn mr() -> usize { MR }
#[inline(always)]
fn nr() -> usize { NR }

#[inline(always)]
fn always_masked() -> bool { KernelAvx::always_masked() }

Expand All @@ -136,15 +128,11 @@ impl GemmKernel for KernelFma {
impl GemmKernel for KernelSse2 {
type Elem = T;

const MR: usize = KernelFallback::MR;
const NR: usize = KernelFallback::NR;
#[inline(always)]
fn align_to() -> usize { 16 }
type MRTy = <KernelFallback as GemmKernel>::MRTy;
type NRTy = <KernelFallback as GemmKernel>::NRTy;

#[inline(always)]
fn mr() -> usize { Self::MR }
#[inline(always)]
fn nr() -> usize { Self::NR }
fn align_to() -> usize { 16 }

#[inline(always)]
fn always_masked() -> bool { KernelFallback::always_masked() }
Expand All @@ -171,15 +159,11 @@ impl GemmKernel for KernelSse2 {
impl GemmKernel for KernelFallback {
type Elem = T;

const MR: usize = 8;
const NR: usize = 4;
#[inline(always)]
fn align_to() -> usize { 0 }
type MRTy = U8;
type NRTy = U4;

#[inline(always)]
fn mr() -> usize { Self::MR }
#[inline(always)]
fn nr() -> usize { Self::NR }
fn align_to() -> usize { 0 }

#[inline(always)]
fn always_masked() -> bool { true }
Expand Down

0 comments on commit 6759f8a

Please sign in to comment.