Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cpu: x64: pool: enable optimized pooling for pad > ur_w #112

Open
wants to merge 1 commit into
base: v2.4_for_ie_master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions src/cpu/x64/jit_uni_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,6 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
jpp.ur_bc = 1;
jpp.ur_bc_tail = 0;
}
auto ur_w = nstl::min(jpp.ow, jpp.ur / jpp.ur_bc);
if (utils::div_up(jpp.l_pad, jpp.stride_w) > ur_w)
return status::unimplemented;
if (utils::div_up(right_pad, jpp.stride_w) > ur_w)
return status::unimplemented;

// scratchpad for c_block slice of input and/or output
using namespace memory_tracking::names;
Expand Down Expand Up @@ -1301,7 +1296,8 @@ void jit_uni_pool_kernel<isa>::generate() {

auto dt_size = jpp.dt_size;
auto shift = (isa == sse41) ? vlen : 0;
add(reg_input, dt_size * (ur_w * stride_w - lpad) * c_off - shift);
add(reg_input,
dt_size * nstl::max(0, ur_w * stride_w - lpad) * c_off - shift);
add(reg_output, dt_size * ur_w * c_off - shift);
if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
auto ishift = (isa == sse41) ? jpp.c_block / 2 : 0;
Expand Down Expand Up @@ -1344,18 +1340,25 @@ void jit_uni_pool_kernel<isa>::generate() {
auto ur_w = nstl::min(jpp.ow, jpp.ur / jpp.ur_bc);
auto ur_w_tail = jpp.ow % ur_w;

int n_oi = ow / ur_w;
const int n_oi_iterations = ow / ur_w;
int n_oi = n_oi_iterations;

int r_pad1
const int r_pad1
= calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w, kw);
if (r_pad1 > 0) n_oi--;
const int ur_stride_w = ur_w * stride_w;
const int l_pad_iterations = utils::div_up(l_pad, ur_stride_w);
const int r_pad_iterations = utils::div_up(r_pad1, ur_stride_w);

if (l_pad > 0) {
n_oi -= nstl::max(0, r_pad_iterations);

for (int i = 0; i < l_pad_iterations; ++i) {
n_oi--;
const int cur_l_pad = l_pad - i * ur_stride_w;
if (n_oi < 0 && r_pad1 > 0)
process_oi(ur_w, ur_bc, l_pad, r_pad1, with_c_tail_processing);
else
process_oi(ur_w, ur_bc, l_pad, 0, with_c_tail_processing);
process_oi(
ur_w, ur_bc, cur_l_pad, r_pad1, with_c_tail_processing);
else if (n_oi >= 0)
process_oi(ur_w, ur_bc, cur_l_pad, 0, with_c_tail_processing);
}

xor_(oi_iter, oi_iter);
Expand All @@ -1371,12 +1374,23 @@ void jit_uni_pool_kernel<isa>::generate() {
}
}

if (r_pad1 > 0 && n_oi >= 0)
process_oi(ur_w, ur_bc, 0, r_pad1, with_c_tail_processing);
if (n_oi >= 0) {
const int r_pad1_tail = r_pad1 % ur_stride_w != 0
? r_pad1 % ur_stride_w
: ur_stride_w;
for (int i = 0; i < r_pad_iterations; ++i) {
const int cur_r_pad = r_pad1_tail + ur_stride_w * i;
process_oi(ur_w, ur_bc, 0, cur_r_pad, with_c_tail_processing);
}
}

if (ur_w_tail != 0)
process_oi(
ur_w_tail, ur_bc, 0, r_pad, with_c_tail_processing, false);
if (ur_w_tail != 0) {
const int l_pad_tail = n_oi_iterations < l_pad_iterations
? l_pad % ur_stride_w
: 0;
process_oi(ur_w_tail, ur_bc, l_pad_tail, r_pad,
with_c_tail_processing, false);
}
};
Label ur_bc_tail_label, c_tail_processing_label, finish_label;

Expand Down
5 changes: 5 additions & 0 deletions tests/benchdnn/inputs/pool/shapes_2d
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ ic35_iw30ih37_ow14oh17_kw3kh4_sw2sh2
ic35_iw30ih36_ow14oh17_pw0ph1_kw3kh4_sw2sh2
ic35_iw33ih37_ow14oh17_kw6kh4_sw2sh2
ic35_iw33ih36_ow14oh17_pw0ph1_kw6kh4_sw2sh2

# Padding is bigger than ur_w
mb1ic8_ih19oh10kh15dh0sh2ph14_iw19ow10kw15dw0sw2pw14
mb1ic8_ih19oh10kh14dh0sh2ph13_iw19ow10kw14dw0sw2pw13

# With dilation
mb1ic8_ih3oh3_kh3ph1_dh2dw2
mb122ic32_ih32iw2_oh32ow2_kh3kw3_ph1pw1_dh4dw1
Expand Down