Restored ArmSVE general storage case. (#708)
Details:
- Restored the general-storage case in the ArmSVE kernels.
- Reason: although true `g`-storage is difficult to speed up, the `g`
  codepath provides good support for transposed storage, i.e. it is at
  least useful for `GEMM_UKR_SETUP_CT_AMBI` (see the sketch below).
- In our experience this solution is only *a little* slower than an
  in-register transpose, and an in-register transpose is only possible
  for a fixed VL in our case.
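For orientation, here is a minimal scalar sketch of what the restored generic ("g") epilogue computes. All names and types below are illustrative stand-ins, not the actual BLIS macros or kernel code; the point is only that an update that honors both a row stride and a column stride of C also covers row-major (transposed) C, which is what the commit message refers to with `GEMM_UKR_SETUP_CT_AMBI`.

```c
/* Illustrative scalar sketch only -- not the actual BLIS macros or kernel.
 * It shows what the restored generic ("g") storage epilogue computes:
 * C := beta*C + AB, honoring both a row stride and a column stride of C. */

typedef long long dim_t;   /* stands in for BLIS dim_t */
typedef long long inc_t;   /* stands in for BLIS inc_t */

static void gemm_ukr_epilogue_g( dim_t m, dim_t n, double beta,
                                 const double *ab,   /* m x n accumulator, column-major */
                                 double *c, inc_t rs_c, inc_t cs_c )
{
    for ( dim_t j = 0; j < n; ++j )
        for ( dim_t i = 0; i < m; ++i )
            c[ i*rs_c + j*cs_c ] = beta * c[ i*rs_c + j*cs_c ] + ab[ i + j*m ];
}
```

With rs_c == 1 this degenerates to the contiguous-column case the kernels already optimize; with a nonunit rs_c (e.g. row-major/transposed C) the same update still applies, only with strided element access, which is what the SVE gather/scatter paths restored below implement.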
xrq-phys authored Feb 18, 2023
1 parent 0ba6e9e commit 4e18cd3
Showing 4 changed files with 129 additions and 112 deletions.
77 changes: 38 additions & 39 deletions kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c
@@ -57,8 +57,8 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
cntx_t* cntx
)
{
void* a_next = bli_auxinfo_next_a( data );
void* b_next = bli_auxinfo_next_b( data );
const void* a_next = bli_auxinfo_next_a( data );
const void* b_next = bli_auxinfo_next_b( data );

// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
@@ -68,7 +68,7 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
uint64_t cs_c = cs_c0;
uint64_t info = 0;

GEMM_UKR_SETUP_CT( c, m, 10, false );
GEMM_UKR_SETUP_CT_ANY( c, m, 10, false );

__asm__ volatile (
" whilelo p0.s, xzr, %12 \n\t"
@@ -117,8 +117,8 @@ BEQ(END_CCOL_PRFM)
GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2)
" \n\t"
LABEL(CCOL_PRFM)
// " cmp %3, #1 \n\t"
// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" cmp %3, #1 \n\t"
BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" mov x16, %2 \n\t"
" prfm PLDL1KEEP, [x16] \n\t"
" add x16, x16, %4 \n\t"
@@ -232,8 +232,8 @@ MOV_COL2(z8 ,z9 ,z10,z11,z16,z17,z18,z19)
LABEL(WRITE_MEM_EXEC)
" mov x9, %2 \n\t" // C address for loading.
" \n\t" // C address for storing is %2 itself.
// " cmp %3, #1 \n\t"
// BNE(WRITE_MEM_G)
" cmp %3, #1 \n\t"
BNE(WRITE_MEM_G)
" \n\t"
LABEL(WRITE_MEM_C)
" fmov s29, wzr \n\t"
@@ -259,38 +259,37 @@ LABEL(ZERO_BETA_C_4_5_6_7_8_9)
GEMM_CCMPLX_STORE_COL2_C(z0 ,z1 ,z2 ,z3 ,p0,%2,%4)
GEMM_CCMPLX_STORE_COL2_C(z4 ,z5 ,z6 ,z7 ,p0,%2,%4)
GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4)
// BRANCH(END_WRITE_MEM)
// " \n\t"
// LABEL(WRITE_MEM_G)
// " add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2,
// " mov x3, %3 \n\t" // s.t. 2*sizeof(float) = 2*4 = 8.
// " index z28.s, wzr, w3 \n\t"
// " fmov s29, wzr \n\t"
// " fcmp s31, #0.0 \n\t" // Whether Imag(beta) == 0.
// " fccmp s30, s29, 0, eq \n\t" // Whether Real(beta) == 0.
// BEQ(ZERO_BETA_G_0_1_2_3)
// GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16)
// GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16)
// GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31)
// GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31)
// LABEL(ZERO_BETA_G_0_1_2_3)
// GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z28,%2,%4,x16)
// GEMM_CCMPLX_STORE_COL2_G(z24,z25,z26,z27,p0,z28,%2,%4,x16)
// " \n\t"
// BEQ(ZERO_BETA_G_4_5_6_7_8_9)
// GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16)
// GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16)
// GEMM_CCMPLX_LOAD_COL2_G(z20,z21,z22,z23,p0,z28,x9,%4,x16)
// GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31)
// GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31)
// GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31)
// LABEL(ZERO_BETA_G_4_5_6_7_8_9)
// GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z28,%2,%4,x16)
// GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z28,%2,%4,x16)
// GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
// " \n\t"
// LABEL(END_WRITE_MEM)
// BRANCH(END_EXEC)
BRANCH(END_WRITE_MEM)
// General-storage case -- Mainly for Column-storage or other aligned cases.
LABEL(WRITE_MEM_G)
" add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2,
" index z28.s, wzr, %w3 \n\t" // s.t. 2*sizeof(float) = 2*4 = 8.
" fmov s29, wzr \n\t"
" fcmp s31, #0.0 \n\t" // Whether Imag(beta) == 0.
" fccmp s30, s29, 0, eq \n\t" // Whether Real(beta) == 0.
BEQ(ZERO_BETA_G_0_1_2_3)
GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16)
GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16)
GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31)
GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31)
LABEL(ZERO_BETA_G_0_1_2_3)
GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z28,%2,%4,x16)
GEMM_CCMPLX_STORE_COL2_G(z24,z25,z26,z27,p0,z28,%2,%4,x16)
" \n\t"
BEQ(ZERO_BETA_G_4_5_6_7_8_9)
GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16)
GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16)
GEMM_CCMPLX_LOAD_COL2_G(z20,z21,z22,z23,p0,z28,x9,%4,x16)
GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31)
GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31)
GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31)
LABEL(ZERO_BETA_G_4_5_6_7_8_9)
GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z28,%2,%4,x16)
GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z28,%2,%4,x16)
GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
" \n\t"
LABEL(END_WRITE_MEM)
BRANCH(END_EXEC)
" \n\t"
LABEL(END_EXEC)
" mov %11, #0 \n\t" // Return normal.
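A hedged reading of the restored WRITE_MEM_G block in the complex kernel above: the row skip is doubled (`add %3, %3, %3`) because each single-precision complex element occupies two floats, and the doubled skip is fed to `index z28.s` so the gather/scatter offsets step over interleaved real/imag pairs. In scalar complex arithmetic, one column update reduces to the sketch below; the function and parameter names are made up for illustration and are not taken from the kernel.

```c
/* Hedged scalar equivalent of one column of the cgemm WRITE_MEM_G path;
 * names are illustrative, not from the kernel. */
#include <complex.h>

typedef long long dim_t;
typedef long long inc_t;

static void update_ccol_g( dim_t m, float complex beta,
                           const float complex *ab_col,  /* one accumulator column          */
                           float complex *c_col,         /* one column of C                 */
                           inc_t rs_c )                  /* row stride, in complex elements */
{
    for ( dim_t i = 0; i < m; ++i )
        c_col[ i*rs_c ] = beta * c_col[ i*rs_c ] + ab_col[ i ];
}
```

The ZERO_BETA_G_* labels skip the loads and FMLAs entirely when beta == 0, mirroring the contiguous path.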
45 changes: 27 additions & 18 deletions kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c
@@ -57,8 +57,8 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
cntx_t* cntx
)
{
void* a_next = bli_auxinfo_next_a( data );
void* b_next = bli_auxinfo_next_b( data );
const void* a_next = bli_auxinfo_next_a( data );
const void* b_next = bli_auxinfo_next_b( data );

// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
@@ -67,7 +67,7 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;

GEMM_UKR_SETUP_CT( d, m, 10, false );
GEMM_UKR_SETUP_CT_ANY( d, m, 10, false );

__asm__ volatile (
" mov x0, xzr \n\t"
@@ -82,7 +82,7 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
" mov x3, #10 \n\t" // Row-skip of B.
" \n\t"
" ldr x5, %[c] \n\t"
// " ldr x6, %[rs_c] \n\t" // Row-skip of C.
" ldr x6, %[rs_c] \n\t" // Row-skip of C.
" ldr x7, %[cs_c] \n\t" // Column-skip of C.
#ifdef _A64FX
" mov x8, 0x3 \n\t" // Tag C address.
@@ -120,8 +120,8 @@ BEQ(END_CCOL_PRFM)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
" \n\t"
LABEL(CCOL_PRFM)
// " cmp x6, #1 \n\t"
// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" cmp x6, #1 \n\t"
BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" mov x16, x5 \n\t"
" prfm PLDL1KEEP, [x16] \n\t"
" add x16, x16, x7 \n\t"
@@ -256,8 +256,8 @@ LABEL(PREFETCH_ABNEXT)
" \n\t"
" mov x9, x5 \n\t" // C address for loading.
" \n\t" // C address for storing is x5 itself.
// " cmp x6, #1 \n\t" // Preload first half of C for contiguous case.
// BNE(WRITE_MEM)
" cmp x6, #1 \n\t" // Preload first half of C for contiguous case.
BNE(WRITE_MEM)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
" \n\t"
LABEL(WRITE_MEM)
@@ -268,8 +268,8 @@ BEQ(UNIT_ALPHA)
SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30)
" \n\t"
LABEL(UNIT_ALPHA)
// " cmp x6, #1 \n\t"
// BNE(WRITE_MEM_G)
" cmp x6, #1 \n\t"
BNE(WRITE_MEM_G)
" \n\t"
LABEL(WRITE_MEM_C)
" \n\t" // Available scratch: Z[20-30].
@@ -281,17 +281,26 @@ BEQ(BETA_ZERO_C)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
" \n\t"
LABEL(BETA_ZERO_C)
GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p1,x5,x7)
GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p1,x5,x7)
// BRANCH(END_WRITE_MEM)
// " \n\t"
// LABEL(END_WRITE_MEM)
// BRANCH(END_EXEC)
// " \n\t"
// LABEL(END_ERROR)
// " mov x0, #1 \n\t" // Return error.
BRANCH(END_EXEC)
// Generic-storage case -- Mainly for transposed storage.
LABEL(WRITE_MEM_G)
" mov x8, xzr \n\t"
" incb x8 \n\t"
" madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip.
" index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8.
" \n\t"
" fcmp d31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN.
BEQ(BETA_ZERO_G)
GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p1,x9,x7,x8,x16)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p1,x9,x7,x8,x16)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
LABEL(BETA_ZERO_G)
GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p1,x5,x7,x8,x16)
GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p1,x5,x7,x8,x16)
LABEL(END_EXEC)
" mov x0, #0 \n\t" // Return normal.
:
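For the real-domain kernels the restored WRITE_MEM_G path builds its addressing in two pieces, and the sketch below is an assumption-level reading of it with illustrative names only. `index z30.d, xzr, x6` produces the element offsets {0, rs_c, 2*rs_c, ...} (the gather/scatter addressing applies the element-size scaling itself, hence the "not multiplied by 8" comment), while `incb` plus `madd` compute x8, described in the code as the C column's "logical 1-vector skip", i.e. the distance from the first vector of a column to the second.

```c
/* Assumption-level scalar sketch of one strided C-column update in the
 * dgemm WRITE_MEM_G path (two vectors per column in the 2vx10 kernel);
 * vl and all names are illustrative. */

typedef long long dim_t;
typedef long long inc_t;

static void update_dcol_g( dim_t m, dim_t vl,            /* vl = doubles per SVE vector  */
                           double beta,
                           const double *ab_col,          /* one accumulator column       */
                           double *c_col, inc_t rs_c )    /* row stride of C, in elements */
{
    /* First "vector": elements 0 .. vl-1 of the column.                    */
    /* Second "vector": starts vl*rs_c elements further on, the "logical    */
    /* 1-vector skip" the asm derives with incb + madd.                     */
    for ( dim_t i = 0; i < m && i < 2*vl; ++i )
        c_col[ i*rs_c ] = beta * c_col[ i*rs_c ] + ab_col[ i ];
}
```

The single-precision kernel below follows the same pattern with `index z30.s, wzr, w6`.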
41 changes: 25 additions & 16 deletions kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c
@@ -57,8 +57,8 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
cntx_t* cntx
)
{
void* a_next = bli_auxinfo_next_a( data );
void* b_next = bli_auxinfo_next_b( data );
const void* a_next = bli_auxinfo_next_a( data );
const void* b_next = bli_auxinfo_next_b( data );

// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
@@ -67,7 +67,7 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;

GEMM_UKR_SETUP_CT( s, m, 10, false );
GEMM_UKR_SETUP_CT_ANY( s, m, 10, false );

__asm__ volatile (
" mov x0, xzr \n\t"
@@ -82,7 +82,7 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
" mov x3, #10 \n\t" // Row-skip of B.
" \n\t"
" ldr x5, %[c] \n\t"
// " ldr x6, %[rs_c] \n\t" // Row-skip of C.
" ldr x6, %[rs_c] \n\t" // Row-skip of C.
" ldr x7, %[cs_c] \n\t" // Column-skip of C.
#ifdef _A64FX
" mov x8, 0x3 \n\t" // Tag C address.
@@ -120,8 +120,8 @@ BEQ(END_CCOL_PRFM)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
" \n\t"
LABEL(CCOL_PRFM)
// " cmp x6, #1 \n\t"
// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" cmp x6, #1 \n\t"
BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" mov x16, x5 \n\t"
" prfm PLDL1STRM, [x16] \n\t"
" add x16, x16, x7 \n\t"
@@ -256,8 +256,8 @@ SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z1
LABEL(UNIT_ALPHA)
" mov x9, x5 \n\t" // C address for loading.
" \n\t" // C address for storing is x5 itself.
// " cmp x6, #1 \n\t"
// BNE(WRITE_MEM_G)
" cmp x6, #1 \n\t"
BNE(WRITE_MEM_G)
" \n\t"
LABEL(WRITE_MEM_C)
" \n\t" // Available scratch: Z[20-30].
@@ -268,17 +268,26 @@ BEQ(BETA_ZERO_C)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
" \n\t"
LABEL(BETA_ZERO_C)
GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p1,x5,x7)
GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p1,x5,x7)
// BRANCH(END_WRITE_MEM)
// " \n\t"
// LABEL(END_WRITE_MEM)
// BRANCH(END_EXEC)
// " \n\t"
// LABEL(END_ERROR)
// " mov x0, #1 \n\t" // Return error.
BRANCH(END_EXEC)
// Generic-storage case -- Mainly for transposed storage.
LABEL(WRITE_MEM_G)
" mov x8, xzr \n\t"
" incb x8 \n\t"
" madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip.
" index z30.s, wzr, w6 \n\t" // Skips passed to index is not multiplied by 8.
" \n\t"
" fcmp s31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN.
BEQ(BETA_ZERO_G)
GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p1,x9,x7,x8,x16)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p1,x9,x7,x8,x16)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
LABEL(BETA_ZERO_G)
GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p1,x5,x7,x8,x16)
GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p1,x5,x7,x8,x16)
LABEL(END_EXEC)
" mov x0, #0 \n\t" // Return normal.
:
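One detail shared by the store paths of all three kernels is the `fcmp ... #0.0` test paired with the "override NaN" comments: when beta is exactly zero, the kernels skip loading C altogether and simply overwrite it, so stale NaN or Inf values already present in C cannot leak into the result through a 0 * NaN product. A hedged scalar rendering of that rule, with an illustrative function name:

```c
/* Hedged scalar rendering of the beta == 0 "override NaN" logic;
 * the function name is illustrative. */
static double update_elem( double beta, double c_old, double ab )
{
    /* beta == 0: do not read C (skip the load/FMLA), just overwrite it. */
    return ( beta != 0.0 ) ? beta * c_old + ab : ab;
}
```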