Skip to content

Commit

Permalink
Arm64 dgemmsup with extended MR&NR (#655)
Browse files Browse the repository at this point in the history
Details:
- Since NEON provides a large number of registers but each register is
  short, this commit extends both MR and NR.
- The approach is to represent the C microtile in registers optionally 
  in columns, so for sizes like 6x7m, the 'crr' kernel is the default 
  with 'rrr' supported through an in-register transpose.
- A few asm kernels are crafted for 'rv' to complete this extended size 
  support.
- For 'rd' I'm still relying heavily on C99 intrinsic kernels with 
  branching so the performance might not be optimal. (Sorry for that.)
- So far, these changes only affect the 'firestorm' subconfig.
- This commit also contains row-preferential s12x8 and d8x6 gemm
  ukernels. These microkernels are templatized versions of the existing
  s8x12 and d6x8 ukernels defined in bli_gemm_armv8a_asm_d6x8.c.
  • Loading branch information
xrq-phys authored Aug 30, 2022
1 parent 9e5594a commit dfa5413
Show file tree
Hide file tree
Showing 16 changed files with 3,020 additions and 712 deletions.
32 changes: 17 additions & 15 deletions config/firestorm/bli_cntx_init_firestorm.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ void bli_cntx_init_firestorm( cntx_t* cntx )
cntx,

// level-3
BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armv8a_asm_8x12,
BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armv8a_asm_6x8,
BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_armv8a_asm_12x8r,
BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_armv8a_asm_8x6r,

// packm
BLIS_PACKM_MRXK_KER, BLIS_FLOAT, bli_spackm_armv8a_int_8xk,
BLIS_PACKM_NRXK_KER, BLIS_FLOAT, bli_spackm_armv8a_int_12xk,
BLIS_PACKM_MRXK_KER, BLIS_DOUBLE, bli_dpackm_armv8a_int_6xk,
BLIS_PACKM_NRXK_KER, BLIS_DOUBLE, bli_dpackm_armv8a_int_8xk,
BLIS_PACKM_MRXK_KER, BLIS_FLOAT, bli_spackm_armv8a_int_12xk,
BLIS_PACKM_NRXK_KER, BLIS_FLOAT, bli_spackm_armv8a_int_8xk,
BLIS_PACKM_MRXK_KER, BLIS_DOUBLE, bli_dpackm_armv8a_int_8xk,
BLIS_PACKM_NRXK_KER, BLIS_DOUBLE, bli_dpackm_armv8a_int_6xk,

// gemmsup
BLIS_GEMMSUP_RRR_UKR, BLIS_DOUBLE, bli_dgemmsup_rv_armv8a_asm_6x8m,
Expand All @@ -77,8 +77,8 @@ void bli_cntx_init_firestorm( cntx_t* cntx )
cntx,

// level-3
BLIS_GEMM_UKR_ROW_PREF, BLIS_FLOAT, FALSE,
BLIS_GEMM_UKR_ROW_PREF, BLIS_DOUBLE, FALSE,
BLIS_GEMM_UKR_ROW_PREF, BLIS_FLOAT, TRUE,
BLIS_GEMM_UKR_ROW_PREF, BLIS_DOUBLE, TRUE,

// gemmsup
BLIS_GEMMSUP_RRR_UKR_ROW_PREF, BLIS_DOUBLE, TRUE,
Expand All @@ -95,11 +95,11 @@ void bli_cntx_init_firestorm( cntx_t* cntx )

// Initialize level-3 blocksize objects with architecture-specific values.
// s d c z
bli_blksz_init_easy( &blkszs[ BLIS_MR ], 8, 6, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 8, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_MC ], 120, 252, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_KC ], 640, 3072, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 8192, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_MR ], 12, 8, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_NR ], 8, 6, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 256, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_KC ], 4096, 3072, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_NC ], 9600, 8184, -1, -1 );

// Initialize sup thresholds with architecture-appropriate values.
// s d c z
Expand All @@ -110,8 +110,10 @@ void bli_cntx_init_firestorm( cntx_t* cntx )
// Initialize level-3 sup blocksize objects with architecture-specific
// values.
// s d c z
bli_blksz_init_easy( &blkszs[ BLIS_MR_SUP ], -1, 6, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_NR_SUP ], -1, 8, -1, -1 );
bli_blksz_init ( &blkszs[ BLIS_MR_SUP ], -1, 6, -1, -1,
-1, 9, -1, -1 );
bli_blksz_init ( &blkszs[ BLIS_NR_SUP ], -1, 8, -1, -1,
-1, 13, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_MC_SUP ], -1, 240, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_KC_SUP ], -1, 1024, -1, -1 );
bli_blksz_init_easy( &blkszs[ BLIS_NC_SUP ], -1, 3072, -1, -1 );
Expand Down
40 changes: 40 additions & 0 deletions kernels/armv8a/3/armv8a_asm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@
CLEAR4V(V4,V5,V6,V7)

// Scale vectors.
// Single-precision in-place scaling. Each macro expands to a string literal
// of fmul instructions, one per listed register: v<V>.4s is multiplied by
// element IDX of v<A> and written back to v<V>. Registers are emitted in
// argument order.
#define SSCALE1V(V,A,IDX) \
" fmul v"#V".4s, v"#V".4s, v"#A".s["#IDX"] \n\t"
#define SSCALE2V(V0,V1,A,IDX) \
SSCALE1V(V0,A,IDX) SSCALE1V(V1,A,IDX)
#define SSCALE4V(V0,V1,V2,V3,A,IDX) \
SSCALE1V(V0,A,IDX) SSCALE1V(V1,A,IDX) \
SSCALE1V(V2,A,IDX) SSCALE1V(V3,A,IDX)
#define SSCALE8V(V0,V1,V2,V3,V4,V5,V6,V7,A,IDX) \
SSCALE2V(V0,V1,A,IDX) SSCALE2V(V2,V3,A,IDX) \
SSCALE2V(V4,V5,A,IDX) SSCALE2V(V6,V7,A,IDX)

// Double-precision in-place scaling of one register: expands to a single
// fmul that multiplies v<V>.2d by element IDX of v<A> and writes the
// result back to v<V>. V, A and IDX are stringified as given.
#define DSCALE1V(V,A,IDX) \
" fmul v"#V".2d, v"#V".2d, v"#A".d["#IDX"] \n\t"
#define DSCALE2V(V0,V1,A,IDX) \
Expand All @@ -74,6 +86,18 @@
DSCALE4V(V4,V5,V6,V7,A,IDX)

// Scale-accumulate.
// Single-precision scale-accumulate. Each macro expands to a string literal
// of fmla instructions, one per (destination, source) pair:
// v<Dn>.4s += v<Sn>.4s * element IDX of v<A>. Pairs are emitted in
// argument order (D0/S0 first).
#define SSCALEA1V(D,S,A,IDX) \
" fmla v"#D".4s, v"#S".4s, v"#A".s["#IDX"] \n\t"
#define SSCALEA2V(D0,D1,S0,S1,A,IDX) \
SSCALEA1V(D0,S0,A,IDX) SSCALEA1V(D1,S1,A,IDX)
#define SSCALEA4V(D0,D1,D2,D3,S0,S1,S2,S3,A,IDX) \
SSCALEA1V(D0,S0,A,IDX) SSCALEA1V(D1,S1,A,IDX) \
SSCALEA1V(D2,S2,A,IDX) SSCALEA1V(D3,S3,A,IDX)
#define SSCALEA8V(D0,D1,D2,D3,D4,D5,D6,D7,S0,S1,S2,S3,S4,S5,S6,S7,A,IDX) \
SSCALEA2V(D0,D1,S0,S1,A,IDX) SSCALEA2V(D2,D3,S2,S3,A,IDX) \
SSCALEA2V(D4,D5,S4,S5,A,IDX) SSCALEA2V(D6,D7,S6,S7,A,IDX)

// Double-precision scale-accumulate for one register pair: expands to a
// single fmla computing v<D>.2d += v<S>.2d * element IDX of v<A>.
// D, S, A and IDX are stringified as given.
#define DSCALEA1V(D,S,A,IDX) \
" fmla v"#D".2d, v"#S".2d, v"#A".d["#IDX"] \n\t"
#define DSCALEA2V(D0,D1,S0,S1,A,IDX) \
Expand All @@ -95,8 +119,16 @@
// Loads four vector registers as two DLOAD2V expansions: V0,V1 at offset
// SHIFT and V2,V3 at offset SHIFT+32 from ADDR.
// NOTE(review): SHIFT is presumably a byte offset (32 = two 16-byte
// registers) -- confirm against the DLOAD2V definition.
#define DLOAD4V(V0,V1,V2,V3,ADDR,SHIFT) \
DLOAD2V(V0,V1,ADDR,SHIFT) \
DLOAD2V(V2,V3,ADDR,SHIFT+32)
// The single-precision contiguous loads are plain aliases of the
// double-precision macros.
// NOTE(review): presumably valid because the D-macros move whole vector
// registers regardless of element type -- confirm against DLOAD1V.
#define SLOAD1V DLOAD1V
#define SLOAD2V DLOAD2V
#define SLOAD4V DLOAD4V

// Generic: load one line.
// Element-wise (strided) loads into one vector register. Each ld1 fills a
// single lane of v<V> from [ADDR] and then post-increments ADDR by INC, so
// ADDR is advanced once per lane and is modified by the expansion.
// The single-precision form fills the four .s lanes 0..3.
#define SLOAD1V_GATHER_ELMFWD(V,ADDR,INC) \
" ld1 {v"#V".s}[0], ["#ADDR"], "#INC" \n\t" \
" ld1 {v"#V".s}[1], ["#ADDR"], "#INC" \n\t" \
" ld1 {v"#V".s}[2], ["#ADDR"], "#INC" \n\t" \
" ld1 {v"#V".s}[3], ["#ADDR"], "#INC" \n\t"
// The double-precision form fills the two .d lanes 0..1.
#define DLOAD1V_GATHER_ELMFWD(V,ADDR,INC) \
" ld1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \
" ld1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t"
Expand All @@ -110,8 +142,16 @@
// Stores four vector registers as two DSTORE2V expansions: V0,V1 at offset
// SHIFT and V2,V3 at offset SHIFT+32 from ADDR.
// NOTE(review): SHIFT is presumably a byte offset (32 = two 16-byte
// registers) -- confirm against the DSTORE2V definition.
#define DSTORE4V(V0,V1,V2,V3,ADDR,SHIFT) \
DSTORE2V(V0,V1,ADDR,SHIFT) \
DSTORE2V(V2,V3,ADDR,SHIFT+32)
// The single-precision contiguous stores are plain aliases of the
// double-precision macros.
// NOTE(review): presumably valid because the D-macros move whole vector
// registers regardless of element type -- confirm against DSTORE1V.
#define SSTORE1V DSTORE1V
#define SSTORE2V DSTORE2V
#define SSTORE4V DSTORE4V

// Generic: store one line.
// Element-wise (strided) stores from one vector register. Each st1 writes a
// single lane of v<V> to [ADDR] and then post-increments ADDR by INC, so
// ADDR is advanced once per lane and is modified by the expansion.
// The single-precision form writes the four .s lanes 0..3.
#define SSTORE1V_SCATTER_ELMFWD(V,ADDR,INC) \
" st1 {v"#V".s}[0], ["#ADDR"], "#INC" \n\t" \
" st1 {v"#V".s}[1], ["#ADDR"], "#INC" \n\t" \
" st1 {v"#V".s}[2], ["#ADDR"], "#INC" \n\t" \
" st1 {v"#V".s}[3], ["#ADDR"], "#INC" \n\t"
// The double-precision form writes the two .d lanes 0..1.
#define DSTORE1V_SCATTER_ELMFWD(V,ADDR,INC) \
" st1 {v"#V".d}[0], ["#ADDR"], "#INC" \n\t" \
" st1 {v"#V".d}[1], ["#ADDR"], "#INC" \n\t"
Expand Down
Loading

0 comments on commit dfa5413

Please sign in to comment.