Skip to content

Commit

Permalink
Revert "x64: brgemm convolution: update req_cal_comp_pad condition"
Browse files Browse the repository at this point in the history
This reverts commit 05d68df.
  • Loading branch information
xczhai committed Jan 2, 2025
1 parent 557d2df commit ec3d689
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 13 deletions.
4 changes: 0 additions & 4 deletions src/cpu/x64/jit_brgemm_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1526,10 +1526,6 @@ status_t brgemm_convolution_fwd_t<isa>::cal_compensation(

const int max_ker_sz = adjusted_k.size();
const auto comp_buffer_ow = jcp.exec_type != exec_vpad ? jcp.ow : 1;
// TODO: revise the thread distribution here because the work_amount may be
// insufficient
// TODO: revise comp_vpad_pbuffer_ generator to avoid huge code for cases
// with big ow
const auto work_amount
= static_cast<dim_t>(jcp.ngroups) * jcp.nb_oc * max_ker_sz;
const auto is_small_shape = work_amount <= jcp.nthr
Expand Down
12 changes: 3 additions & 9 deletions src/cpu/x64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2301,18 +2301,12 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,

// For padding shapes, we calculate the comp along with the computation
// inside brgemm kernel when output size is small to get optimal perf
// For shapes with large ow we calculate the comp inside brgemm kernel too
// because current implementation of brgemm_comp_pad kernel unrolled by ow
// so not optimal for large ow.
// Otherwise we calculate the comp using brgemm_comp_pad kernel
// Or we calculate the comp using brgemm_comp_pad kernel
const auto output_sz = static_cast<dim_t>(jcp.mb) * jcp.ngroups * jcp.oc
* jcp.od * jcp.oh * jcp.ow;
// TODO: revise below condition to avoid limitation for big ow
const auto shape_for_brgemm_kernel
= (output_sz <= 8192 && jcp.oc < 512) || jcp.ow > 128;
const auto is_relo = jcp.is_relo() && jcp.relo_conv_weights;
jcp.req_brg_comp_pad = compensation_w_padding && jcp.exec_type != exec_trans
&& IMPLICATION(!is_relo, shape_for_brgemm_kernel);
&& IMPLICATION(!(jcp.is_relo() && jcp.relo_conv_weights),
output_sz <= 8192 && jcp.oc < 512);
jcp.req_cal_comp_pad = compensation_w_padding && !jcp.req_brg_comp_pad
&& IMPLICATION(jcp.exec_type == exec_vpad,
jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0
Expand Down

0 comments on commit ec3d689

Please sign in to comment.