From 6759f8a4fac4be4e849794e1ecfbb44d3a79d5b4 Mon Sep 17 00:00:00 2001 From: bluss Date: Thu, 6 Dec 2018 18:28:36 +0100 Subject: [PATCH] FEAT: Add ConstNum for simple type level integers This stands in for const generics, we only need a few sizes anyway. Use it for kernel sizes. --- src/dgemm_kernel.rs | 41 +++++++++++++---------------------------- src/gemm.rs | 20 ++++++++++---------- src/kernel.rs | 23 ++++++++++++++++------- src/sgemm_kernel.rs | 40 ++++++++++++---------------------------- 4 files changed, 51 insertions(+), 73 deletions(-) diff --git a/src/dgemm_kernel.rs b/src/dgemm_kernel.rs index 09d6197..984ac45 100644 --- a/src/dgemm_kernel.rs +++ b/src/dgemm_kernel.rs @@ -8,6 +8,7 @@ use kernel::GemmKernel; use kernel::GemmSelect; +use kernel::{U4, U8}; use archparam; #[cfg(target_arch="x86")] @@ -65,16 +66,12 @@ macro_rules! loop_n { #[cfg(any(target_arch="x86", target_arch="x86_64"))] impl GemmKernel for KernelAvx { type Elem = T; - const MR: usize = MR; - const NR: usize = NR; - #[inline(always)] - fn align_to() -> usize { 32 } + type MRTy = U8; + type NRTy = U4; #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } + fn align_to() -> usize { 32 } #[inline(always)] fn always_masked() -> bool { false } @@ -104,16 +101,12 @@ impl GemmKernel for KernelAvx { #[cfg(any(target_arch="x86", target_arch="x86_64"))] impl GemmKernel for KernelFma { type Elem = T; - const MR: usize = KernelAvx::MR; - const NR: usize = KernelAvx::NR; - #[inline(always)] - fn align_to() -> usize { KernelAvx::align_to() } + type MRTy = ::MRTy; + type NRTy = ::NRTy; #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } + fn align_to() -> usize { KernelAvx::align_to() } #[inline(always)] fn always_masked() -> bool { KernelAvx::always_masked() } @@ -143,16 +136,12 @@ impl GemmKernel for KernelFma { #[cfg(any(target_arch="x86", target_arch="x86_64"))] impl GemmKernel for KernelSse2 { type Elem = T; - const MR: usize = 4; - const NR: usize = 4; - #[inline(always)] - fn align_to() -> usize { 16 } + type MRTy = U4; + type NRTy = U4; #[inline(always)] - fn mr() -> usize { Self::MR } - #[inline(always)] - fn nr() -> usize { Self::NR } + fn align_to() -> usize { 16 } #[inline(always)] fn always_masked() -> bool { true } @@ -181,16 +170,12 @@ impl GemmKernel for KernelSse2 { impl GemmKernel for KernelFallback { type Elem = T; - const MR: usize = 4; - const NR: usize = 4; - #[inline(always)] - fn align_to() -> usize { 0 } + type MRTy = U4; + type NRTy = U4; #[inline(always)] - fn mr() -> usize { Self::MR } - #[inline(always)] - fn nr() -> usize { Self::NR } + fn align_to() -> usize { 0 } #[inline(always)] fn always_masked() -> bool { true } diff --git a/src/gemm.rs b/src/gemm.rs index e82275b..39ad2bc 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -137,8 +137,8 @@ impl GemmSelect for GemmParameters { fn ensure_kernel_params() where K: GemmKernel { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; assert!(mr > 0 && mr <= 8); assert!(nr > 0 && nr <= 8); assert!(mr * nr * size_of::() <= 8 * 4 * 8); @@ -189,7 +189,7 @@ unsafe fn gemm_loop( let a = a.stride_offset(csa, kkc * l4); // Pack B -> B~ - pack(kc, nc, K::nr(), bpp, b, csb, rsb); + pack(kc, nc, K::NR, bpp, b, csb, rsb); // LOOP 3: split m into mc parts for (l3, mc) in range_chunk(m, kmc) { @@ -198,7 +198,7 @@ unsafe fn gemm_loop( let c = c.stride_offset(rsc, kmc * l3); // Pack A -> A~ - pack(kc, mc, K::mr(), app, a, rsa, csa); + pack(kc, mc, K::MR, app, a, rsa, csa); // First time writing to C, use user's `beta`, else accumulate let betap = if l4 == 0 { beta } else { <_>::one() }; @@ -228,8 +228,8 @@ unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, c: *mut K::Elem, rsc: isize, csc: isize) where K: GemmKernel, { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; // make a mask buffer that fits 8 x 8 f32 and 8 x 4 f64 kernels and alignment assert!(mr * nr * size_of::() <= 256 && K::align_to() <= 32); let mut mask_buf = [0u8; 256 + 31]; @@ -278,8 +278,8 @@ unsafe fn make_packing_buffer(m: usize, k: usize, n: usize) -> (Alloc(k: usize, alpha: T, mask_buf: *mut T) where K: GemmKernel, T: Element, { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; // use column major order for `mask_buf` K::kernel(k, T::one(), a, b, T::zero(), mask_buf, 1, mr as isize); let mut ab = mask_buf; diff --git a/src/kernel.rs b/src/kernel.rs index b39adfa..cae587e 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -11,18 +11,17 @@ pub trait GemmKernel { type Elem: Element; /// Kernel rows - const MR: usize; + const MR: usize = Self::MRTy::VALUE; /// Kernel cols - const NR: usize; + const NR: usize = Self::NRTy::VALUE; + /// Kernel rows as static num + type MRTy: ConstNum; + /// Kernel cols as static num + type NRTy: ConstNum; /// align inputs to this fn align_to() -> usize; - /// Kernel rows - fn mr() -> usize; - /// Kernel cols - fn nr() -> usize; - /// Whether to always use the masked wrapper around the kernel. /// /// If masked, the kernel is always called with α=1, β=0 @@ -93,3 +92,13 @@ pub(crate) trait GemmSelect { T: Element; } + +pub trait ConstNum { + const VALUE: usize; +} + +pub struct U4; +pub struct U8; + +impl ConstNum for U4 { const VALUE: usize = 4; } +impl ConstNum for U8 { const VALUE: usize = 8; } diff --git a/src/sgemm_kernel.rs b/src/sgemm_kernel.rs index 2d708bc..5f30cf9 100644 --- a/src/sgemm_kernel.rs +++ b/src/sgemm_kernel.rs @@ -8,6 +8,7 @@ use kernel::GemmKernel; use kernel::GemmSelect; +use kernel::{U4, U8}; use archparam; @@ -63,15 +64,11 @@ macro_rules! loop_n { ($j:ident, $e:expr) => { loop8!($j, $e) }; } impl GemmKernel for KernelAvx { type Elem = T; - const MR: usize = MR; - const NR: usize = NR; - #[inline(always)] - fn align_to() -> usize { 32 } + type MRTy = U8; + type NRTy = U8; #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } + fn align_to() -> usize { 32 } #[inline(always)] fn always_masked() -> bool { false } @@ -99,17 +96,12 @@ impl GemmKernel for KernelAvx { impl GemmKernel for KernelFma { type Elem = T; - const MR: usize = KernelAvx::MR; - const NR: usize = KernelAvx::NR; + type MRTy = ::MRTy; + type NRTy = ::NRTy; #[inline(always)] fn align_to() -> usize { KernelAvx::align_to() } - #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } - #[inline(always)] fn always_masked() -> bool { KernelAvx::always_masked() } @@ -136,15 +128,11 @@ impl GemmKernel for KernelFma { impl GemmKernel for KernelSse2 { type Elem = T; - const MR: usize = KernelFallback::MR; - const NR: usize = KernelFallback::NR; - #[inline(always)] - fn align_to() -> usize { 16 } + type MRTy = ::MRTy; + type NRTy = ::NRTy; #[inline(always)] - fn mr() -> usize { Self::MR } - #[inline(always)] - fn nr() -> usize { Self::NR } + fn align_to() -> usize { 16 } #[inline(always)] fn always_masked() -> bool { KernelFallback::always_masked() } @@ -171,15 +159,11 @@ impl GemmKernel for KernelSse2 { impl GemmKernel for KernelFallback { type Elem = T; - const MR: usize = 8; - const NR: usize = 4; - #[inline(always)] - fn align_to() -> usize { 0 } + type MRTy = U8; + type NRTy = U4; #[inline(always)] - fn mr() -> usize { Self::MR } - #[inline(always)] - fn nr() -> usize { Self::NR } + fn align_to() -> usize { 0 } #[inline(always)] fn always_masked() -> bool { true }