Skip to content

Commit

Permalink
FEAT: Use type parameter for size in pack
Browse files Browse the repository at this point in the history
This function should be specialized depending on the kernel row length,
so that its copy_nonoverlapping call gets instantiated with a fixed
length inline copy.

This is necessary now that we have many different kernel configurations.
  • Loading branch information
bluss committed Dec 7, 2018
1 parent 6759f8a commit db77e4c
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ use aligned_alloc::Alloc;
use util::range_chunk;
use util::round_up_to;

use kernel::GemmKernel;
use kernel::ConstNum;
use kernel::Element;
use kernel::GemmKernel;
use kernel::GemmSelect;
use sgemm_kernel;
use dgemm_kernel;
Expand Down Expand Up @@ -189,7 +190,7 @@ unsafe fn gemm_loop<K>(
let a = a.stride_offset(csa, kkc * l4);

// Pack B -> B~
pack(kc, nc, K::NR, bpp, b, csb, rsb);
pack::<K::NRTy, _>(kc, nc, bpp, b, csb, rsb);

// LOOP 3: split m into mc parts
for (l3, mc) in range_chunk(m, kmc) {
Expand All @@ -198,7 +199,7 @@ unsafe fn gemm_loop<K>(
let c = c.stride_offset(rsc, kmc * l3);

// Pack A -> A~
pack(kc, mc, K::MR, app, a, rsa, csa);
pack::<K::MRTy, _>(kc, mc, app, a, rsa, csa);

// First time writing to C, use user's `beta`, else accumulate
let betap = if l4 == 0 { beta } else { <_>::one() };
Expand Down Expand Up @@ -305,15 +306,18 @@ unsafe fn align_ptr<U>(align_to: usize, mut ptr: *mut U) -> *mut U {
///
/// + kc: length of the micropanel
/// + mc: number of rows/columns in the matrix to be packed
/// + mr: kernel rows/columns that we round up to
/// + pack: packing buffer
/// + a: matrix,
/// + rsa: row stride
/// + csa: column stride
unsafe fn pack<T>(kc: usize, mc: usize, mr: usize, pack: *mut T,
a: *const T, rsa: isize, csa: isize)
where T: Element
///
/// + MR: kernel rows/columns that we round up to
unsafe fn pack<MR, T>(kc: usize, mc: usize, pack: *mut T,
a: *const T, rsa: isize, csa: isize)
where T: Element,
MR: ConstNum,
{
let mr = MR::VALUE;
let mut p = 0; // offset into pack

if rsa == 1 {
Expand Down

0 comments on commit db77e4c

Please sign in to comment.