Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ArmSVE Restore General-Store Case #708

Merged
merged 1 commit into from
Feb 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 38 additions & 39 deletions kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
cntx_t* cntx
)
{
void* a_next = bli_auxinfo_next_a( data );
void* b_next = bli_auxinfo_next_b( data );
const void* a_next = bli_auxinfo_next_a( data );
const void* b_next = bli_auxinfo_next_b( data );

// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
Expand All @@ -68,7 +68,7 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
uint64_t cs_c = cs_c0;
uint64_t info = 0;

GEMM_UKR_SETUP_CT( c, m, 10, false );
GEMM_UKR_SETUP_CT_ANY( c, m, 10, false );

__asm__ volatile (
" whilelo p0.s, xzr, %12 \n\t"
Expand Down Expand Up @@ -117,8 +117,8 @@ BEQ(END_CCOL_PRFM)
GEMM_ACOLCMPLX_CONTIGUOUS_LOAD_FWD(z28,z29,p0,%0,x2)
" \n\t"
LABEL(CCOL_PRFM)
// " cmp %3, #1 \n\t"
// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" cmp %3, #1 \n\t"
BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" mov x16, %2 \n\t"
" prfm PLDL1KEEP, [x16] \n\t"
" add x16, x16, %4 \n\t"
Expand Down Expand Up @@ -232,8 +232,8 @@ MOV_COL2(z8 ,z9 ,z10,z11,z16,z17,z18,z19)
LABEL(WRITE_MEM_EXEC)
" mov x9, %2 \n\t" // C address for loading.
" \n\t" // C address for storing is %2 itself.
// " cmp %3, #1 \n\t"
// BNE(WRITE_MEM_G)
" cmp %3, #1 \n\t"
BNE(WRITE_MEM_G)
" \n\t"
LABEL(WRITE_MEM_C)
" fmov s29, wzr \n\t"
Expand All @@ -259,38 +259,37 @@ LABEL(ZERO_BETA_C_4_5_6_7_8_9)
GEMM_CCMPLX_STORE_COL2_C(z0 ,z1 ,z2 ,z3 ,p0,%2,%4)
GEMM_CCMPLX_STORE_COL2_C(z4 ,z5 ,z6 ,z7 ,p0,%2,%4)
GEMM_CCMPLX_STORE_COL2_C(z8 ,z9 ,z10,z11,p0,%2,%4)
// BRANCH(END_WRITE_MEM)
// " \n\t"
// LABEL(WRITE_MEM_G)
// " add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2,
// " mov x3, %3 \n\t" // s.t. 2*sizeof(float) = 2*4 = 8.
// " index z28.s, wzr, w3 \n\t"
// " fmov s29, wzr \n\t"
// " fcmp s31, #0.0 \n\t" // Whether Imag(beta) == 0.
// " fccmp s30, s29, 0, eq \n\t" // Whether Real(beta) == 0.
// BEQ(ZERO_BETA_G_0_1_2_3)
// GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16)
// GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16)
// GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31)
// GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31)
// LABEL(ZERO_BETA_G_0_1_2_3)
// GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z28,%2,%4,x16)
// GEMM_CCMPLX_STORE_COL2_G(z24,z25,z26,z27,p0,z28,%2,%4,x16)
// " \n\t"
// BEQ(ZERO_BETA_G_4_5_6_7_8_9)
// GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16)
// GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16)
// GEMM_CCMPLX_LOAD_COL2_G(z20,z21,z22,z23,p0,z28,x9,%4,x16)
// GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31)
// GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31)
// GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31)
// LABEL(ZERO_BETA_G_4_5_6_7_8_9)
// GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z28,%2,%4,x16)
// GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z28,%2,%4,x16)
// GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
// " \n\t"
// LABEL(END_WRITE_MEM)
// BRANCH(END_EXEC)
BRANCH(END_WRITE_MEM)
// General-storage case -- Mainly for Column-storage or other aligned cases.
LABEL(WRITE_MEM_G)
" add %3, %3, %3 \n\t" // Skips passed to index is multiplied by 2,
" index z28.s, wzr, %w3 \n\t" // s.t. 2*sizeof(float) = 2*4 = 8.
" fmov s29, wzr \n\t"
" fcmp s31, #0.0 \n\t" // Whether Imag(beta) == 0.
" fccmp s30, s29, 0, eq \n\t" // Whether Real(beta) == 0.
BEQ(ZERO_BETA_G_0_1_2_3)
GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16)
GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16)
GEMM_FMLACMPLX_COL2(z20,z21,z22,z23,p0,z12,z13,z14,z15,z30,z31)
GEMM_FMLACMPLX_COL2(z24,z25,z26,z27,p0,z16,z17,z18,z19,z30,z31)
LABEL(ZERO_BETA_G_0_1_2_3)
GEMM_CCMPLX_STORE_COL2_G(z20,z21,z22,z23,p0,z28,%2,%4,x16)
GEMM_CCMPLX_STORE_COL2_G(z24,z25,z26,z27,p0,z28,%2,%4,x16)
" \n\t"
BEQ(ZERO_BETA_G_4_5_6_7_8_9)
GEMM_CCMPLX_LOAD_COL2_G(z12,z13,z14,z15,p0,z28,x9,%4,x16)
GEMM_CCMPLX_LOAD_COL2_G(z16,z17,z18,z19,p0,z28,x9,%4,x16)
GEMM_CCMPLX_LOAD_COL2_G(z20,z21,z22,z23,p0,z28,x9,%4,x16)
GEMM_FMLACMPLX_COL2(z0 ,z1 ,z2 ,z3 ,p0,z12,z13,z14,z15,z30,z31)
GEMM_FMLACMPLX_COL2(z4 ,z5 ,z6 ,z7 ,p0,z16,z17,z18,z19,z30,z31)
GEMM_FMLACMPLX_COL2(z8 ,z9 ,z10,z11,p0,z20,z21,z22,z23,z30,z31)
LABEL(ZERO_BETA_G_4_5_6_7_8_9)
GEMM_CCMPLX_STORE_COL2_G(z0 ,z1 ,z2 ,z3 ,p0,z28,%2,%4,x16)
GEMM_CCMPLX_STORE_COL2_G(z4 ,z5 ,z6 ,z7 ,p0,z28,%2,%4,x16)
GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
" \n\t"
LABEL(END_WRITE_MEM)
BRANCH(END_EXEC)
" \n\t"
LABEL(END_EXEC)
" mov %11, #0 \n\t" // Return normal.
Expand Down
45 changes: 27 additions & 18 deletions kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
cntx_t* cntx
)
{
void* a_next = bli_auxinfo_next_a( data );
void* b_next = bli_auxinfo_next_b( data );
const void* a_next = bli_auxinfo_next_a( data );
const void* b_next = bli_auxinfo_next_b( data );

// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
Expand All @@ -67,7 +67,7 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;

GEMM_UKR_SETUP_CT( d, m, 10, false );
GEMM_UKR_SETUP_CT_ANY( d, m, 10, false );

__asm__ volatile (
" mov x0, xzr \n\t"
Expand All @@ -82,7 +82,7 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
" mov x3, #10 \n\t" // Row-skip of B.
" \n\t"
" ldr x5, %[c] \n\t"
// " ldr x6, %[rs_c] \n\t" // Row-skip of C.
" ldr x6, %[rs_c] \n\t" // Row-skip of C.
" ldr x7, %[cs_c] \n\t" // Column-skip of C.
#ifdef _A64FX
" mov x8, 0x3 \n\t" // Tag C address.
Expand Down Expand Up @@ -120,8 +120,8 @@ BEQ(END_CCOL_PRFM)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
" \n\t"
LABEL(CCOL_PRFM)
// " cmp x6, #1 \n\t"
// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" cmp x6, #1 \n\t"
BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" mov x16, x5 \n\t"
" prfm PLDL1KEEP, [x16] \n\t"
" add x16, x16, x7 \n\t"
Expand Down Expand Up @@ -256,8 +256,8 @@ LABEL(PREFETCH_ABNEXT)
" \n\t"
" mov x9, x5 \n\t" // C address for loading.
" \n\t" // C address for storing is x5 itself.
// " cmp x6, #1 \n\t" // Preload first half of C for contiguous case.
// BNE(WRITE_MEM)
" cmp x6, #1 \n\t" // Preload first half of C for contiguous case.
BNE(WRITE_MEM)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
" \n\t"
LABEL(WRITE_MEM)
Expand All @@ -268,8 +268,8 @@ BEQ(UNIT_ALPHA)
SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z19,z30)
" \n\t"
LABEL(UNIT_ALPHA)
// " cmp x6, #1 \n\t"
// BNE(WRITE_MEM_G)
" cmp x6, #1 \n\t"
BNE(WRITE_MEM_G)
" \n\t"
LABEL(WRITE_MEM_C)
" \n\t" // Available scratch: Z[20-30].
Expand All @@ -281,17 +281,26 @@ BEQ(BETA_ZERO_C)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
" \n\t"
LABEL(BETA_ZERO_C)
GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p1,x5,x7)
GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p1,x5,x7)
// BRANCH(END_WRITE_MEM)
// " \n\t"
// LABEL(END_WRITE_MEM)
// BRANCH(END_EXEC)
// " \n\t"
// LABEL(END_ERROR)
// " mov x0, #1 \n\t" // Return error.
BRANCH(END_EXEC)
// Generic-storage case -- Mainly for transposed storage.
LABEL(WRITE_MEM_G)
" mov x8, xzr \n\t"
" incb x8 \n\t"
" madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip.
" index z30.d, xzr, x6 \n\t" // Skips passed to index is not multiplied by 8.
" \n\t"
" fcmp d31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN.
BEQ(BETA_ZERO_G)
GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p1,x9,x7,x8,x16)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p1,x9,x7,x8,x16)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
LABEL(BETA_ZERO_G)
GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p1,x5,x7,x8,x16)
GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p1,x5,x7,x8,x16)
LABEL(END_EXEC)
" mov x0, #0 \n\t" // Return normal.
:
Expand Down
41 changes: 25 additions & 16 deletions kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
cntx_t* cntx
)
{
void* a_next = bli_auxinfo_next_a( data );
void* b_next = bli_auxinfo_next_b( data );
const void* a_next = bli_auxinfo_next_a( data );
const void* b_next = bli_auxinfo_next_b( data );

// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
Expand All @@ -67,7 +67,7 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;

GEMM_UKR_SETUP_CT( s, m, 10, false );
GEMM_UKR_SETUP_CT_ANY( s, m, 10, false );

__asm__ volatile (
" mov x0, xzr \n\t"
Expand All @@ -82,7 +82,7 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
" mov x3, #10 \n\t" // Row-skip of B.
" \n\t"
" ldr x5, %[c] \n\t"
// " ldr x6, %[rs_c] \n\t" // Row-skip of C.
" ldr x6, %[rs_c] \n\t" // Row-skip of C.
" ldr x7, %[cs_c] \n\t" // Column-skip of C.
#ifdef _A64FX
" mov x8, 0x3 \n\t" // Tag C address.
Expand Down Expand Up @@ -120,8 +120,8 @@ BEQ(END_CCOL_PRFM)
GEMM_ACOL_CONTIGUOUS_LOAD(z28,z29,p0,p1,x0)
" \n\t"
LABEL(CCOL_PRFM)
// " cmp x6, #1 \n\t"
// BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" cmp x6, #1 \n\t"
BNE(END_CCOL_PRFM) // Do not prefetch for generic C storage.
" mov x16, x5 \n\t"
" prfm PLDL1STRM, [x16] \n\t"
" add x16, x16, x7 \n\t"
Expand Down Expand Up @@ -256,8 +256,8 @@ SCALE_COL20(z0,z1,z2,z3,z4,z5,z6,z7,z8,z9,z10,z11,z12,z13,z14,z15,z16,z17,z18,z1
LABEL(UNIT_ALPHA)
" mov x9, x5 \n\t" // C address for loading.
" \n\t" // C address for storing is x5 itself.
// " cmp x6, #1 \n\t"
// BNE(WRITE_MEM_G)
" cmp x6, #1 \n\t"
BNE(WRITE_MEM_G)
" \n\t"
LABEL(WRITE_MEM_C)
" \n\t" // Available scratch: Z[20-30].
Expand All @@ -268,17 +268,26 @@ GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_C(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,p0,p1,x9,x7)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
" \n\t"
LABEL(BETA_ZERO_C)
GEMM_C_STORE_UKER_C(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,p1,x5,x7)
GEMM_C_STORE_UKER_C(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,p1,x5,x7)
// BRANCH(END_WRITE_MEM)
// " \n\t"
// LABEL(END_WRITE_MEM)
// BRANCH(END_EXEC)
// " \n\t"
// LABEL(END_ERROR)
// " mov x0, #1 \n\t" // Return error.
BRANCH(END_EXEC)
// Generic-storage case -- Mainly for transposed storage.
LABEL(WRITE_MEM_G)
" mov x8, xzr \n\t"
" incb x8 \n\t"
" madd x8, x8, x6, xzr \n\t" // C-column's logical 1-vector skip.
" index z30.s, wzr, w6 \n\t" // Skips passed to index is not multiplied by 8.
" \n\t"
" fcmp s31, #0.0 \n\t" // Skip loading if *beta == 0 to override NaN.
BEQ(BETA_ZERO_G)
GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p1,x9,x7,x8,x16)
GEMM_C_FMLA_UKER(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
GEMM_C_LOAD_UKER_G(z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z30,p0,p1,x9,x7,x8,x16)
GEMM_C_FMLA_UKER(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,p0,z20,z22,z24,z26,z28,z21,z23,z25,z27,z29,z31)
LABEL(BETA_ZERO_G)
GEMM_C_STORE_UKER_G(z0,z2,z4,z6,z8,z1,z3,z5,z7,z9,z30,p0,p1,x5,x7,x8,x16)
GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p1,x5,x7,x8,x16)
LABEL(END_EXEC)
" mov x0, #0 \n\t" // Return normal.
:
Expand Down
Loading