Skip to content

Commit

Permalink
remaining inch for winograd neon3
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Jan 31, 2018
1 parent c12fab5 commit c6506d6
Showing 1 changed file with 349 additions and 0 deletions.
349 changes: 349 additions & 0 deletions src/layer/arm/convolution_3x3.h
Original file line number Diff line number Diff line change
Expand Up @@ -4019,6 +4019,355 @@ static void conv3x3s1_winograd64_neon3(const Mat& bottom_blob, Mat& top_blob, co
k01 += 8;
k10 += 8;
k11 += 8;
#endif // __ARM_NEON
}
}

// Tail loop: process the remaining input channels one at a time (the loop
// above this hunk consumes input channels in pairs).  For each channel q,
// accumulate its transformed input (r0) into the transform buffers of the
// two output channels (output0_tm / output1_tm) using the per-channel
// transformed kernels k00 / k10.  Each "tile" is 8 consecutive floats.
for (; q<inch; q++)
{
const float* r0 = bottom_blob_tm.channel(q);

const float* k00 = kernel0_tm.row(q);
const float* k10 = kernel1_tm.row(q);

float* output0_tm = out0_tm;
float* output1_tm = out1_tm;

// 8 rows of the transformed kernel; k00/k10 advance by 8 floats per row
// (explicitly at the bottom of this loop, or via asm writeback on ARMv7).
for (int r=0; r<8; r++)
{
#if __ARM_NEON
#if __aarch64__
// Load the 8 kernel coefficients for this row into two q registers each.
float32x4_t _k00 = vld1q_f32(k00);
float32x4_t _k00n = vld1q_f32(k00+4);
float32x4_t _k10 = vld1q_f32(k10);
float32x4_t _k10n = vld1q_f32(k10+4);
#else
float32x4_t _k00;
float32x4_t _k00n;
float32x4_t _k10;
float32x4_t _k10n;

// ARMv7: load 8 floats from k00 and 8 from k10 with post-increment
// writeback ("!"), so k00/k10 are already advanced to the next row here
// — hence no "k00 += 8" for this path at the bottom of the r loop.
asm volatile(
"pld [%0, #256] \n"
"vld1.f32 {%e2-%f2}, [%0 :128]! \n"
"pld [%1, #256] \n"
"vld1.f32 {%e4-%f4}, [%1 :128]! \n"
"vld1.f32 {%e3-%f3}, [%0 :128]! \n"
"vld1.f32 {%e5-%f5}, [%1 :128]! \n"
: "=r"(k00), // %0
"=r"(k10), // %1
"=w"(_k00), // %2
"=w"(_k00n), // %3
"=w"(_k10), // %4
"=w"(_k10n) // %5
: "0"(k00),
"1"(k10)
: "cc", "memory"
);
#endif // __aarch64__
#endif // __ARM_NEON

// tile
#if __ARM_NEON
// NEON path processes 4 tiles per main-loop iteration; the leftover
// (tiles & 3) tiles go through the per-tile remainder loop below.
int nn = tiles >> 2;
int remain = tiles & 3;
#else
int remain = tiles;
#endif // __ARM_NEON

#if __ARM_NEON
#if __aarch64__
// aarch64 main loop: 4x unrolled — each unrolled step is an 8-float
// load/multiply-accumulate/store for both output channels.
for (; nn>0; nn--)
{
float32x4_t _output0_tm = vld1q_f32(output0_tm);
float32x4_t _output0_tmn = vld1q_f32(output0_tm+4);

float32x4_t _output1_tm = vld1q_f32(output1_tm);
float32x4_t _output1_tmn = vld1q_f32(output1_tm+4);

float32x4_t _r0 = vld1q_f32(r0);
float32x4_t _r0n = vld1q_f32(r0+4);

r0 += 8;

// output += r0 * k (elementwise), split across low/high q registers
_output0_tm = vmlaq_f32(_output0_tm, _r0, _k00);
_output0_tmn = vmlaq_f32(_output0_tmn, _r0n, _k00n);

_output1_tm = vmlaq_f32(_output1_tm, _r0, _k10);
_output1_tmn = vmlaq_f32(_output1_tmn, _r0n, _k10n);

vst1q_f32(output0_tm, _output0_tm);
vst1q_f32(output0_tm+4, _output0_tmn);

vst1q_f32(output1_tm, _output1_tm);
vst1q_f32(output1_tm+4, _output1_tmn);

output0_tm += 8;
output1_tm += 8;

// --- unrolled tile 2 ---
_output0_tm = vld1q_f32(output0_tm);
_output0_tmn = vld1q_f32(output0_tm+4);

_output1_tm = vld1q_f32(output1_tm);
_output1_tmn = vld1q_f32(output1_tm+4);

_r0 = vld1q_f32(r0);
_r0n = vld1q_f32(r0+4);

r0 += 8;

_output0_tm = vmlaq_f32(_output0_tm, _r0, _k00);
_output0_tmn = vmlaq_f32(_output0_tmn, _r0n, _k00n);

_output1_tm = vmlaq_f32(_output1_tm, _r0, _k10);
_output1_tmn = vmlaq_f32(_output1_tmn, _r0n, _k10n);

vst1q_f32(output0_tm, _output0_tm);
vst1q_f32(output0_tm+4, _output0_tmn);

vst1q_f32(output1_tm, _output1_tm);
vst1q_f32(output1_tm+4, _output1_tmn);

output0_tm += 8;
output1_tm += 8;

// --- unrolled tile 3 ---
_output0_tm = vld1q_f32(output0_tm);
_output0_tmn = vld1q_f32(output0_tm+4);

_output1_tm = vld1q_f32(output1_tm);
_output1_tmn = vld1q_f32(output1_tm+4);

_r0 = vld1q_f32(r0);
_r0n = vld1q_f32(r0+4);

r0 += 8;

_output0_tm = vmlaq_f32(_output0_tm, _r0, _k00);
_output0_tmn = vmlaq_f32(_output0_tmn, _r0n, _k00n);

_output1_tm = vmlaq_f32(_output1_tm, _r0, _k10);
_output1_tmn = vmlaq_f32(_output1_tmn, _r0n, _k10n);

vst1q_f32(output0_tm, _output0_tm);
vst1q_f32(output0_tm+4, _output0_tmn);

vst1q_f32(output1_tm, _output1_tm);
vst1q_f32(output1_tm+4, _output1_tmn);

output0_tm += 8;
output1_tm += 8;

// --- unrolled tile 4 ---
_output0_tm = vld1q_f32(output0_tm);
_output0_tmn = vld1q_f32(output0_tm+4);

_output1_tm = vld1q_f32(output1_tm);
_output1_tmn = vld1q_f32(output1_tm+4);

_r0 = vld1q_f32(r0);
_r0n = vld1q_f32(r0+4);

r0 += 8;

_output0_tm = vmlaq_f32(_output0_tm, _r0, _k00);
_output0_tmn = vmlaq_f32(_output0_tmn, _r0n, _k00n);

_output1_tm = vmlaq_f32(_output1_tm, _r0, _k10);
_output1_tmn = vmlaq_f32(_output1_tmn, _r0n, _k10n);

vst1q_f32(output0_tm, _output0_tm);
vst1q_f32(output0_tm+4, _output0_tmn);

vst1q_f32(output1_tm, _output1_tm);
vst1q_f32(output1_tm+4, _output1_tmn);

output0_tm += 8;
output1_tm += 8;
}
#else
// ARMv7 main loop: same 4x-unrolled tile pipeline as the aarch64
// intrinsics above, with loads software-pipelined ahead of stores.
// Register map: q8/q9 = output0, q10/q11 = output1, q12/q13 = r0;
// kernel constants come in as %q8..%q11 ("w" inputs below).
if (nn > 0)
{
asm volatile(
"0: \n"

"pld [%3, #256] \n"
"vld1.f32 {d24-d27}, [%3 :128]! \n"// q12 q13 = _r0

"pld [%1, #256] \n"
"vld1.f32 {d16-d19}, [%1 :128] \n"// q8 q9 = _output0_tm

"vmla.f32 q8, q12, %q8 \n"
"vmla.f32 q9, q13, %q9 \n"

"pld [%2, #256] \n"
"vld1.f32 {d20-d23}, [%2 :128] \n"// q10 q11 = _output1_tm

"vmla.f32 q10, q12, %q10 \n"
"vmla.f32 q11, q13, %q11 \n"

"pld [%3, #256] \n"
"vld1.f32 {d24-d27}, [%3 :128]! \n"// q12 q13 = _r0

"vst1.f32 {d16-d19}, [%1 :128]! \n"

"pld [%1, #256] \n"
"vld1.f32 {d16-d19}, [%1 :128] \n"// q8 q9 = _output0_tm

"vmla.f32 q8, q12, %q8 \n"
"vmla.f32 q9, q13, %q9 \n"

"vst1.f32 {d20-d23}, [%2 :128]! \n"

"pld [%2, #256] \n"
"vld1.f32 {d20-d23}, [%2 :128] \n"// q10 q11 = _output1_tm

"vmla.f32 q10, q12, %q10 \n"
"vmla.f32 q11, q13, %q11 \n"

"pld [%3, #256] \n"
"vld1.f32 {d24-d27}, [%3 :128]! \n"// q12 q13 = _r0

"vst1.f32 {d16-d19}, [%1 :128]! \n"

"pld [%1, #256] \n"
"vld1.f32 {d16-d19}, [%1 :128] \n"// q8 q9 = _output0_tm

"vmla.f32 q8, q12, %q8 \n"
"vmla.f32 q9, q13, %q9 \n"

"vst1.f32 {d20-d23}, [%2 :128]! \n"

"pld [%2, #256] \n"
"vld1.f32 {d20-d23}, [%2 :128] \n"// q10 q11 = _output1_tm

"vmla.f32 q10, q12, %q10 \n"
"vmla.f32 q11, q13, %q11 \n"

"pld [%3, #256] \n"
"vld1.f32 {d24-d27}, [%3 :128]! \n"// q12 q13 = _r0

"vst1.f32 {d16-d19}, [%1 :128]! \n"

"pld [%1, #256] \n"
"vld1.f32 {d16-d19}, [%1 :128] \n"// q8 q9 = _output0_tm

"vmla.f32 q8, q12, %q8 \n"
"vmla.f32 q9, q13, %q9 \n"

"vst1.f32 {d20-d23}, [%2 :128]! \n"

"pld [%2, #256] \n"
"vld1.f32 {d20-d23}, [%2 :128] \n"// q10 q11 = _output1_tm

"vmla.f32 q10, q12, %q10 \n"
"vmla.f32 q11, q13, %q11 \n"

"vst1.f32 {d16-d19}, [%1 :128]! \n"
"vst1.f32 {d20-d23}, [%2 :128]! \n"

"subs %0, #1 \n"

"bne 0b \n"

: "=r"(nn), // %0
"=r"(output0_tm), // %1
"=r"(output1_tm), // %2
"=r"(r0) // %3
: "0"(nn),
"1"(output0_tm),
"2"(output1_tm),
"3"(r0),
"w"(_k00), // %8
"w"(_k00n), // %9
"w"(_k10), // %10
"w"(_k10n) // %11
: "cc", "memory", "q8", "q9", "q10", "q11", "q12", "q13"
);
}
#endif // __aarch64__
#endif // __ARM_NEON
// Remainder: one 8-float tile per iteration (tiles & 3 on NEON,
// all tiles on the scalar fallback).
for (; remain>0; remain--)
{
#if __ARM_NEON
#if __aarch64__
float32x4_t _output0_tm = vld1q_f32(output0_tm);
float32x4_t _output0_tmn = vld1q_f32(output0_tm+4);

float32x4_t _output1_tm = vld1q_f32(output1_tm);
float32x4_t _output1_tmn = vld1q_f32(output1_tm+4);

float32x4_t _r0 = vld1q_f32(r0);
float32x4_t _r0n = vld1q_f32(r0+4);

r0 += 8;

_output0_tm = vmlaq_f32(_output0_tm, _r0, _k00);
_output0_tmn = vmlaq_f32(_output0_tmn, _r0n, _k00n);

_output1_tm = vmlaq_f32(_output1_tm, _r0, _k10);
_output1_tmn = vmlaq_f32(_output1_tmn, _r0n, _k10n);

vst1q_f32(output0_tm, _output0_tm);
vst1q_f32(output0_tm+4, _output0_tmn);

vst1q_f32(output1_tm, _output1_tm);
vst1q_f32(output1_tm+4, _output1_tmn);

output0_tm += 8;
output1_tm += 8;
#else
// ARMv7 single tile: note the kernel operands are %q6..%q9 here
// (fewer pointer operands than the main loop above).
asm volatile(
"pld [%2, #256] \n"
"vld1.f32 {d24-d27}, [%2 :128]! \n"// q12 q13 = _r0

"pld [%0, #256] \n"
"vld1.f32 {d16-d19}, [%0 :128] \n"// q8 q9 = _output0_tm

"vmla.f32 q8, q12, %q6 \n"
"vmla.f32 q9, q13, %q7 \n"

"pld [%1, #256] \n"
"vld1.f32 {d20-d23}, [%1 :128] \n"// q10 q11 = _output1_tm

"vmla.f32 q10, q12, %q8 \n"
"vmla.f32 q11, q13, %q9 \n"

"vst1.f32 {d16-d19}, [%0 :128]! \n"
"vst1.f32 {d20-d23}, [%1 :128]! \n"

: "=r"(output0_tm), // %0
"=r"(output1_tm), // %1
"=r"(r0) // %2
: "0"(output0_tm),
"1"(output1_tm),
"2"(r0),
"w"(_k00), // %6
"w"(_k00n), // %7
"w"(_k10), // %8
"w"(_k10n) // %9
: "cc", "memory", "q8", "q9", "q10", "q11", "q12", "q13"
);
#endif // __aarch64__
#else
// Scalar fallback: same 8-wide multiply-accumulate per tile.
for (int m=0; m<8; m++)
{
output0_tm[m] += r0[m] * k00[m];
output1_tm[m] += r0[m] * k10[m];
}

r0 += 8;
output0_tm += 8;
output1_tm += 8;
#endif // __ARM_NEON
}

// Advance kernel pointers to the next row of 8 coefficients.
// On the ARMv7 NEON path this already happened via the "!" writeback
// in the kernel-load asm above, so only aarch64 and the scalar build
// increment here.
#if __ARM_NEON
#if __aarch64__
k00 += 8;
k10 += 8;
#endif // __aarch64__
#else
k00 += 8;
k10 += 8;
#endif // __ARM_NEON
}
}
Expand Down

0 comments on commit c6506d6

Please sign in to comment.