Skip to content

Commit

Permalink
Move edge cases to gemm ukr; more user-custom mods. (#583)
Browse files Browse the repository at this point in the history
Details:
- Moved edge-case handling into the gemm microkernel. This required
  changing the microkernel API to take m and n dimension parameters.
  This required updating all existing gemm microkernel function pointer
  types, function signatures, and related definitions to take m and n
  dimensions. We also updated all existing kernels in the 'kernels' 
  directory to take m and n dimensions, and implemented edge-case 
  handling within those microkernels via a collection of new C 
  preprocessor macros defined within bli_edge_case_macro_defs.h. Also
  removed the assembly code that formerly would handle general stride 
  IO on the microtile, since this can now be handled by the same code
  that does edge cases.
- Pass the obj_t.ker_fn (of matrix C) into bli_gemm_cntl_create() and
  bli_trsm_cntl_create(), where this function pointer is used in lieu of 
  the default macrokernel when it is non-NULL, and ignored when it is
  NULL.
- Re-implemented macrokernel in bli_gemm_ker_var2.c to be a single
  function using byte pointers rather than one function for each
  floating-point datatype. Also, obtain the microkernel function pointer
  from the .ukr field of the params struct embedded within the obj_t
  for matrix C (assuming params is non-NULL and contains a non-NULL
  value in the .ukr field). Communicate both the gemm microkernel
  pointer to use and the params struct to the microkernel via
  the auxinfo_t struct.
- Defined gemm_ker_params_t type (for the aforementioned obj_t.params 
  struct) in bli_gemm_var.h.
- Retired the separate _md macrokernel for mixed datatype computation.
  We now use the reimplemented bli_gemm_ker_var2() instead.
- Updated gemmt macrokernels to pass m and n dimensions into microkernel
  calls.
- Removed edge-case handling from trmm and trsm macrokernels.
- Moved most of bli_packm_alloc() code into a new helper function,
  bli_packm_alloc_ex().
- Fixed a typo bug in bli_gemmtrsm_u_template_noopt_mxn.c.
- Added test/syrk_diagonal and test/tensor_contraction directories with
  associated code to test those operations.
  • Loading branch information
devinamatthews authored Dec 24, 2021
1 parent 961d9d5 commit 54fa28b
Show file tree
Hide file tree
Showing 87 changed files with 10,458 additions and 13,506 deletions.
13 changes: 7 additions & 6 deletions config/template/kernels/3/bli_gemm_template_noopt_mxn.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

void bli_zgemm_template_noopt
(
dim_t m,
dim_t n,
dim_t k,
dcomplex* restrict alpha,
dcomplex* restrict a1,
Expand Down Expand Up @@ -88,8 +90,7 @@ void bli_zgemm_template_noopt

dim_t l, j, i;

dcomplex ab[ bli_zmr *
bli_znr ];
dcomplex ab[ mr * nr ];
dcomplex* abij;
dcomplex ai, bj;

Expand Down Expand Up @@ -137,16 +138,16 @@ void bli_zgemm_template_noopt
if ( bli_zeq0( *beta ) )
{
/* c11 := ab */
bli_zcopys_mxn( mr,
nr,
bli_zcopys_mxn( m,
n,
ab, rs_ab, cs_ab,
c11, rs_c, cs_c );
}
else
{
/* c11 := beta * c11 + ab */
bli_zxpbys_mxn( mr,
nr,
bli_zxpbys_mxn( m,
n,
ab, rs_ab, cs_ab,
beta,
c11, rs_c, cs_c );
Expand Down
4 changes: 4 additions & 0 deletions config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ void bli_zgemmtrsm_l_template_noopt
*/
const num_t dt = BLIS_DCOMPLEX;

const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );

const inc_t rs_b = packnr;
Expand All @@ -84,6 +86,8 @@ void bli_zgemmtrsm_l_template_noopt
/* b11 = alpha * b11 - a10 * b01; */
bli_zgemm_template_noopt
(
mr,
nr,
k,
minus_one,
a10,
Expand Down
8 changes: 6 additions & 2 deletions config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ void bli_zgemmtrsm_u_template_noopt
*/
const num_t dt = BLIS_DCOMPLEX;

const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );

const inc_t rs_b = packnr;
Expand All @@ -84,10 +86,12 @@ void bli_zgemmtrsm_u_template_noopt
/* b11 = alpha * b11 - a12 * b21; */
bli_zgemm_template_noopt
(
mr,
nr,
k,
minus_one,
a12,
b21,
a10,
b01,
alpha,
b11, rs_b, cs_b,
data
Expand Down
58 changes: 39 additions & 19 deletions frame/1m/packm/bli_packm_alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,35 @@
#include "blis.h"

void* bli_packm_alloc
(
siz_t size_needed,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
(
siz_t size_needed,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
// Query the pack buffer type from the control tree node.
packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl );

return bli_packm_alloc_ex
(
size_needed,
pack_buf_type,
rntm,
cntl,
thread
);
}

void* bli_packm_alloc_ex
(
siz_t size_needed,
packbuf_t pack_buf_type,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
// Query the address of the mem_t entry within the control tree node.
mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl );

Expand All @@ -55,7 +74,7 @@ void* bli_packm_alloc
siz_t cntl_mem_size = 0;

if ( bli_mem_is_alloc( cntl_mem_p ) )
cntl_mem_size = bli_mem_size( cntl_mem_p );
cntl_mem_size = bli_mem_size( cntl_mem_p );

if ( cntl_mem_size < size_needed )
{
Expand All @@ -64,14 +83,15 @@ void* bli_packm_alloc
// The chief thread releases the existing block associated with
// the mem_t entry in the control tree, and then re-acquires a
// new block, saving the associated mem_t entry to local_mem_s.
if ( bli_mem_is_alloc( cntl_mem_p ) )
{
bli_pba_release
(
rntm,
cntl_mem_p
);
}
if ( bli_mem_is_alloc( cntl_mem_p ) )
{
bli_pba_release
(
rntm,
cntl_mem_p
);
}

bli_pba_acquire_m
(
rntm,
Expand All @@ -89,11 +109,11 @@ void* bli_packm_alloc
// this thread's control tree node.
*cntl_mem_p = *local_mem_p;

// Barrier so that the master thread doesn't return from the function
// before we are done reading.
bli_thread_barrier( thread );
// Barrier so that the master thread doesn't return from the function
// before we are done reading.
bli_thread_barrier( thread );
}

return bli_mem_buffer( cntl_mem_p );
return bli_mem_buffer( cntl_mem_p );
}

23 changes: 16 additions & 7 deletions frame/1m/packm/bli_packm_alloc.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,20 @@
*/

BLIS_EXPORT_BLIS void* bli_packm_alloc
(
siz_t size_needed,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
);
BLIS_EXPORT_BLIS void* bli_packm_alloc
(
siz_t size_needed,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
);

BLIS_EXPORT_BLIS void* bli_packm_alloc_ex
(
siz_t size_needed,
packbuf_t pack_buf_type,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
);

18 changes: 16 additions & 2 deletions frame/3/bli_l3_cntl.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ void bli_l3_cntl_create_if
family == BLIS_GEMMT ||
family == BLIS_TRMM )
{
*cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b );
*cntl_use = bli_gemm_cntl_create
(
rntm,
family,
schema_a,
schema_b,
bli_obj_ker_fn( c )
);
}
else // if ( family == BLIS_TRSM )
{
Expand All @@ -66,7 +73,14 @@ void bli_l3_cntl_create_if
if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT;
else side = BLIS_RIGHT;

*cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b );
*cntl_use = bli_trsm_cntl_create
(
rntm,
side,
schema_a,
schema_b,
bli_obj_ker_fn( c )
);
}
}
else
Expand Down
2 changes: 2 additions & 0 deletions frame/3/bli_l3_ft_ukr.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
\
typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \
( \
dim_t m, \
dim_t n, \
dim_t k, \
ctype* restrict alpha, \
ctype* restrict a, \
Expand Down
4 changes: 4 additions & 0 deletions frame/3/bli_l3_ukr_oapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ void PASTEMAC0(opname) \
\
num_t dt = bli_obj_dt( c ); \
\
dim_t m = bli_obj_length( c ); \
dim_t n = bli_obj_width( c ); \
dim_t k = bli_obj_width( a ); \
void* buf_a = bli_obj_buffer_at_off( a ); \
void* buf_b = bli_obj_buffer_at_off( b ); \
Expand All @@ -75,6 +77,8 @@ void PASTEMAC0(opname) \
\
f \
( \
m, \
n, \
k, \
buf_alpha, \
buf_a, \
Expand Down
2 changes: 2 additions & 0 deletions frame/3/bli_l3_ukr_prot.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
\
void PASTEMAC(ch,opname) \
( \
dim_t m, \
dim_t n, \
dim_t k, \
ctype_out* restrict alpha, \
ctype_in* restrict a, \
Expand Down
63 changes: 35 additions & 28 deletions frame/3/bli_l3_ukr_tapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
\
void PASTEMAC(ch,opname) \
( \
dim_t m, \
dim_t n, \
dim_t k, \
ctype* restrict alpha, \
ctype* restrict a, \
Expand All @@ -58,16 +60,19 @@ void PASTEMAC(ch,opname) \
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
\
/* Invoke the typed function for the given datatype. */ \
f( \
k, \
alpha, \
a, \
b, \
beta, \
c, rs_c, cs_c, \
data, \
cntx \
); \
f \
( \
m, \
n, \
k, \
alpha, \
a, \
b, \
beta, \
c, rs_c, cs_c, \
data, \
cntx \
); \
} \

INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR )
Expand Down Expand Up @@ -98,17 +103,18 @@ void PASTEMAC(ch,opname) \
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
\
/* Invoke the typed function for the given datatype. */ \
f( \
k, \
alpha, \
a1x, \
a11, \
bx1, \
b11, \
c11, rs_c, cs_c, \
data, \
cntx \
); \
f \
( \
k, \
alpha, \
a1x, \
a11, \
bx1, \
b11, \
c11, rs_c, cs_c, \
data, \
cntx \
); \
} \

INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR )
Expand Down Expand Up @@ -136,13 +142,14 @@ void PASTEMAC(ch,opname) \
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
\
/* Invoke the typed function for the given datatype. */ \
f( \
a, \
b, \
c, rs_c, cs_c, \
data, \
cntx \
); \
f \
( \
a, \
b, \
c, rs_c, cs_c, \
data, \
cntx \
); \
} \

INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR )
Expand Down
15 changes: 10 additions & 5 deletions frame/3/gemm/bli_gemm_cntl.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ cntl_t* bli_gemm_cntl_create
rntm_t* rntm,
opid_t family,
pack_t schema_a,
pack_t schema_b
pack_t schema_b,
void_fp ker
)
{
return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b );
return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b, ker );
}

// -----------------------------------------------------------------------------
Expand All @@ -53,18 +54,22 @@ cntl_t* bli_gemmbp_cntl_create
rntm_t* rntm,
opid_t family,
pack_t schema_a,
pack_t schema_b
pack_t schema_b,
void_fp ker
)
{
void_fp macro_kernel_fp;

// Use the function pointers to the macrokernels that use slab
// assignment of micropanels to threads in the jr and ir loops.
// Choose the default macrokernel based on the operation family...
if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2;
else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2;
else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2;
else /* should never execute */ macro_kernel_fp = NULL;

// ...unless a non-NULL kernel function pointer is passed in, in which
// case we use that instead.
if ( ker ) macro_kernel_fp = ker;

// Create two nodes for the macro-kernel.
cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node
(
Expand Down
Loading

0 comments on commit 54fa28b

Please sign in to comment.