Skip to content

Commit

Permalink
FIX: move kernel-specific constants into each kernel function
Browse files Browse the repository at this point in the history
Move the MR/NR constants into each function, since they don't make sense
for the whole file anymore.
  • Loading branch information
bluss committed Dec 7, 2018
1 parent 9753437 commit c71f07f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 16 deletions.
15 changes: 5 additions & 10 deletions src/dgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,11 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
return selector.select(KernelFallback);
}

// File-level `MR`, compiled only for x86/x86_64 targets.
// `kernel_x86_avx` uses `MR` as the length of its `ab` accumulator array.
// NOTE(review): presumably the kernel's row block size — confirm against `KernelAvx::MR`.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
const MR: usize = 8;
// File-level `NR`, compiled only for x86/x86_64 targets.
// NOTE(review): presumably the kernel's column block size — confirm against `KernelAvx::NR`.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
const NR: usize = 4;

// `loop_m!(index, body)`: forwards both arguments unchanged to `loop8!`.
// Only compiled for x86/x86_64 targets. The test module uses it as the
// outer loop over the first dimension of an `[[_; _]; MR]` array.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
macro_rules! loop_m {
    ($index:ident, $body:expr) => {
        loop8!($index, $body)
    };
}
// `loop_n!(index, body)`: forwards both arguments unchanged to `loop4!`.
// Only compiled for x86/x86_64 targets. The test module nests it inside
// `loop_m!` to visit the inner dimension of the array.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
macro_rules! loop_n {
    ($index:ident, $body:expr) => {
        loop4!($index, $body)
    };
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelAvx {
Expand Down Expand Up @@ -235,6 +227,9 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
where MA: DMultiplyAdd
{
const MR: usize = KernelAvx::MR;
const NR: usize = KernelAvx::NR;

debug_assert_ne!(k, 0);

let mut ab = [_mm256_setzero_pd(); MR];
Expand Down Expand Up @@ -865,8 +860,8 @@ mod tests {

#[test]
fn test_loop_m_n() {
let mut m = [[0; NR]; MR];
loop_m!(i, loop_n!(j, m[i][j] += 1));
let mut m = [[0; 4]; KernelAvx::MR];
loop_m!(i, loop4!(j, m[i][j] += 1));
for arr in &m[..] {
for elt in &arr[..] {
assert_eq!(*elt, 1);
Expand Down
10 changes: 4 additions & 6 deletions src/sgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
return selector.select(KernelFallback);
}

// File-level `MR`, compiled only for x86/x86_64 targets.
// `kernel_x86_avx` uses `MR` as the length of its `ab` accumulator array.
// NOTE(review): presumably the kernel's row block size — confirm against `KernelAvx::MR`.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
const MR: usize = 8;
// File-level `NR`, compiled only for x86/x86_64 targets.
// NOTE(review): presumably the kernel's column block size — confirm against `KernelAvx::NR`.
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
const NR: usize = 8;

// `loop_m!(index, body)`: forwards both arguments unchanged to `loop8!`.
// Only compiled for x86/x86_64 targets.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
macro_rules! loop_m {
    ($index:ident, $body:expr) => {
        loop8!($index, $body)
    };
}
#[cfg(test)]
Expand Down Expand Up @@ -220,6 +215,9 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
where MA: SMultiplyAdd,
{
const MR: usize = KernelAvx::MR;
const NR: usize = KernelAvx::NR;

debug_assert_ne!(k, 0);

let mut ab = [_mm256_setzero_ps(); MR];
Expand Down Expand Up @@ -541,7 +539,7 @@ mod tests {

#[test]
fn test_loop_m_n() {
let mut m = [[0; NR]; MR];
let mut m = [[0; KernelAvx::NR]; KernelAvx::MR];
loop_m!(i, loop_n!(j, m[i][j] += 1));
for arr in &m[..] {
for elt in &arr[..] {
Expand Down

0 comments on commit c71f07f

Please sign in to comment.