Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry-pick binary post optimization #159

Open
wants to merge 3 commits into
base: v2.7_for_ie_master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t(const brgemm_t &abrd)
= {broadcasting_strategy_t::scalar,
broadcasting_strategy_t::per_oc};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(vmm_b().getIdx()), r14, r15, preserve_gpr,
preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec),
GET_OFF(data_C_ptr_), dst_md_wrapper,
static_cast<size_t>(n_vlen_tail()), k_mask,
static_cast<size_t>(vmm_b().getIdx()), r14, r15, r13,
preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
dst_md_wrapper, static_cast<size_t>(n_vlen_tail()), k_mask,
use_exact_tail_scalar_bcast};
const binary_injector::static_params_t bsp {
this->param1, enabled_bcast_strategy, rhs_sp};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator {
broadcasting_strategy_t::no_broadcast};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Xbyak::Zmm(1).getIdx()), this->r14,
this->r15, preserve_gpr, preserve_vmm,
this->r15, this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
ld_tail_mask, use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
broadcasting_strategy_t::no_broadcast};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Vmm(1).getIdx()), this->r14, this->r15,
preserve_gpr, preserve_vmm,
this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
ld_tail_mask, use_exact_tail_scalar_bcast};
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/gemm_bf16_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -111,7 +111,7 @@ gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)
static constexpr size_t tail_size = 0;
static constexpr bool use_exact_tail_scalar_bcast = false;
const binary_injector::rhs_arg_static_params_t rhs_sp {
helper_vmm_idx, r13, r14, preserve_gpr,
helper_vmm_idx, r13, r14, r15, preserve_gpr,
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast};
Expand Down
1,009 changes: 615 additions & 394 deletions src/cpu/x64/injectors/jit_uni_binary_injector.cpp

Large diffs are not rendered by default.

209 changes: 94 additions & 115 deletions src/cpu/x64/injectors/jit_uni_binary_injector.hpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32(
const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;

rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14,
preserve_gpr, preserve_vmm,
r15, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx2_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jit_avx2_conv_fwd_kernel_f32::jit_avx2_conv_fwd_kernel_f32(
const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;

rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14,
preserve_gpr, preserve_vmm,
r15, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel(
static constexpr bool use_exact_tail_scalar_bcast = true;

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
r14, r15, preserve_gpr, preserve_vmm,
r14, r15, r12, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, k_load_dim_mask,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_common_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ _jit_avx512_common_conv_fwd_kernel<Vmm>::_jit_avx512_common_conv_fwd_kernel(
static constexpr bool use_exact_tail_scalar_bcast = false;

const binary_injector::rhs_arg_static_params_t rhs_args_static_params {
helper_vmm_idx, reg_tmp, r15, preserve_gpr, preserve_vmm,
helper_vmm_idx, reg_tmp, r15, r14, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, postops_mask,
use_exact_tail_scalar_bcast};
Expand Down
6 changes: 4 additions & 2 deletions src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t(
using namespace binary_injector;
const auto &rhs_addr_reg = bin_injector_helper_reg_1;
const auto &rhs_helper_reg = bin_injector_helper_reg_2;
const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3;
static constexpr bool preserve_gpr = false;
static constexpr bool preserve_vmm = false;
const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
static constexpr bool use_exact_tail_scalar_bcast = true;

const rhs_arg_static_params_t rhs_arg_static_params {31, rhs_addr_reg,
rhs_helper_reg, preserve_gpr, preserve_vmm,
rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, ktail_mask,
use_exact_tail_scalar_bcast};
Expand Down Expand Up @@ -146,7 +147,8 @@ void jit_avx512_core_amx_1x1_fwd_kernel_t::interleave_store() {
const injector_utils::conditional_register_preserve_guard_t
cond_register_guard(jcp.with_binary, this,
{bin_injector_helper_reg_1,
bin_injector_helper_reg_2});
bin_injector_helper_reg_2,
bin_injector_helper_reg_3});
const int wsp_row_offset = jcp.typesize_acc
* (osb * jcp.nb_oc_blocking * jcp.max_width * jcp.oc_block
+ ocb * jcp.max_width * jcp.oc_block
Expand Down
1 change: 1 addition & 0 deletions src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator {

const Xbyak::Reg64 bin_injector_helper_reg_1 = r14;
const Xbyak::Reg64 bin_injector_helper_reg_2 = r15;
const Xbyak::Reg64 bin_injector_helper_reg_3 = r11;

const Xbyak::Opmask ktail_mask = k2;

Expand Down
7 changes: 5 additions & 2 deletions src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,13 +1057,15 @@ jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t(
using namespace binary_injector;
const auto &rhs_addr_reg = bin_injector_helper_reg_1;
const auto &rhs_helper_reg = bin_injector_helper_reg_2;
const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3;
static constexpr bool preserve_gpr = false;
static constexpr bool preserve_vmm = false;
const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
static constexpr bool use_exact_tail_scalar_bcast = true;

const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
31, rhs_addr_reg, rhs_helper_reg, preserve_gpr, preserve_vmm,
31, rhs_addr_reg, rhs_helper_reg, rhs_addr_cache_reg,
preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, ktail_mask,
use_exact_tail_scalar_bcast};
Expand Down Expand Up @@ -1603,7 +1605,8 @@ void jit_avx512_core_amx_fwd_kernel_t::store_output(int width, int tail,
const injector_utils::conditional_register_preserve_guard_t
cond_register_guard(jcp.with_binary, this,
{bin_injector_helper_reg_1,
bin_injector_helper_reg_2});
bin_injector_helper_reg_2,
bin_injector_helper_reg_3});

for (int tw = 0; tw < width && do_store; tw++) {
// height
Expand Down
1 change: 1 addition & 0 deletions src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ struct jit_avx512_core_amx_fwd_kernel_t : public jit_generator {

const Xbyak::Reg64 bin_injector_helper_reg_1 = r14;
const Xbyak::Reg64 bin_injector_helper_reg_2 = r15;
const Xbyak::Reg64 bin_injector_helper_reg_3 = r11;

const Xbyak::Reg64 reg_d_weights = reg_zp_compensation;
const Xbyak::Reg64 reg_d_bias = reg_src_zero_point;
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jit_avx512_core_bf16_1x1_conv_kernel::jit_avx512_core_bf16_1x1_conv_kernel(
static constexpr bool use_exact_tail_scalar_bcast = true;

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
r14, r15, preserve_gpr, preserve_vmm,
r14, r15, r12, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, k_load_dim_tail_mask,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ _jit_avx512_core_bf16_fwd_kernel<Vmm>::_jit_avx512_core_bf16_fwd_kernel(
static constexpr bool use_exact_tail_scalar_bcast = true;

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
r14, r15, preserve_gpr, preserve_vmm,
r14, r15, r12, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, postops_mask,
use_exact_tail_scalar_bcast};
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,7 +46,7 @@ jit_avx512_dw_conv_fwd_kernel_bf16::jit_avx512_dw_conv_fwd_kernel_bf16(
% (cpu_isa_traits<avx512_core>::vlen / sizeof(float));

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
r14, r15, preserve_gpr, preserve_vmm,
r14, r15, r12, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_core_fork_bf16_dw_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ void jit_avx512_fork_dw_conv_fwd_kernel_bf16::generate() {
% (cpu_isa_traits<avx512_core>::vlen / sizeof(float));
static constexpr bool use_exact_tail_scalar_bcast = false;
const binary_injector::rhs_arg_static_params_t rhs_sp {
helper_vmm_idx, r10, r11, preserve_gpr,
helper_vmm_idx, r10, r11, r12, preserve_gpr,
preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec),
GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_),
tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::
static constexpr bool use_exact_tail_scalar_bcast = true;

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
r14, r15, preserve_gpr, preserve_vmm,
r14, r15, r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, postops_mask,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::_jit_avx512_core_x8s8s32x_fwd_kernel(
static constexpr bool use_exact_tail_scalar_bcast = false;

const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
r14, r15, preserve_gpr, preserve_vmm,
r14, r15, r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, postops_mask,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::

const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Xbyak::Xmm(31).getIdx()), this->r14,
this->r15, preserve_gpr, preserve_vmm,
this->r15, this->r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, ktail_mask,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_brgemm_post_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {

const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Xbyak::Zmm(28).getIdx()), this->r14,
this->r15, preserve_gpr, preserve_vmm,
this->r15, this->r13, preserve_gpr, preserve_vmm,
GET_OFF(ptr_binary_post_ops_rhs), GET_OFF(dst_orig),
memory_desc_wrapper(brg.dst_md),
static_cast<size_t>(brg.load_dim % brg.ld_block),
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_gemm_convolution_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
static constexpr size_t tail_size = 0;
static constexpr bool use_exact_tail_scalar_bcast = false;
const binary_injector::rhs_arg_static_params_t rhs_sp {
helper_vmm_idx, r13, r14, preserve_gpr,
helper_vmm_idx, r13, r14, r15, preserve_gpr,
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_gemm_inner_product_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ jit_pp_kernel_t<isa>::jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
// for the OC
tail_size = !!tail_size ? tail_size : 1;
const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
helper_vmm_idx, eltwise_reserved_gpr_, r14, preserve_gpr,
helper_vmm_idx, eltwise_reserved_gpr_, r14, r15, preserve_gpr,
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
PARAM_OFF(dst_orig), dst_md_wrapper, tail_size, opmask_binary,
reg_tmp, use_exact_tail_scalar_bcast, prelu_helper_vmm_idx};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_gemm_x8s8s32x_convolution_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct jit_pp_ker_t : pp_ker_t, public jit_generator {
static constexpr size_t tail_size = 0;
static constexpr bool use_exact_tail_scalar_bcast = false;
const binary_injector::rhs_arg_static_params_t rhs_sp {
helper_vmm_idx, r13, r14, preserve_gpr,
helper_vmm_idx, r13, r14, r15, preserve_gpr,
preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
tail_size, kreg_rem_mask_short, use_exact_tail_scalar_bcast};
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -53,7 +53,7 @@ jit_sse41_1x1_conv_kernel_f32::jit_sse41_1x1_conv_kernel_f32(
static constexpr bool use_exact_tail_scalar_bcast = false;

const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
helper_vmm_idx, r13, r14, preserve_gpr, preserve_vmm,
helper_vmm_idx, r13, r14, r15, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size,
use_exact_tail_scalar_bcast};
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_sse41_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,7 +54,7 @@ jit_sse41_conv_fwd_kernel_f32::jit_sse41_conv_fwd_kernel_f32(
static constexpr bool use_exact_tail_scalar_bcast = false;

const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
helper_vmm_idx, r14, r15, preserve_gpr, preserve_vmm,
helper_vmm_idx, r14, r15, r12, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size,
use_exact_tail_scalar_bcast};
Expand Down
7 changes: 4 additions & 3 deletions src/cpu/x64/jit_uni_binary_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ void jit_uni_binary_kernel_t<isa, Vmm>::init_post_ops_injector() {
reg_elt_inj_table_, elt_inj_opmask_, true /*is_fwd*/,
false /*use_dst*/);
const binary_injector::rhs_arg_static_params_t rhs_arg_bsp {10, reg_tmp_,
reg_elt_inj_table_, true /*preserve gpr*/, true /*preserve vmm*/,
PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig), dst_d,
tail_size_, tail_opmask_, false /*use_exact_tail_scalar_bcast*/};
reg_elt_inj_table_, r13, true /*preserve gpr*/,
true /*preserve vmm*/, PARAM_OFF(post_ops_binary_rhs_arg_vec),
PARAM_OFF(dst_orig), dst_d, tail_size_, tail_opmask_,
false /*use_exact_tail_scalar_bcast*/};
const binary_injector::static_params_t bsp(this->param1,
get_supported_postops_bcast_strategies(), rhs_arg_bsp);

Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_uni_dw_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jit_uni_dw_conv_fwd_kernel_f32<isa>::jit_uni_dw_conv_fwd_kernel_f32(
const size_t tail_size = jcp.oc_without_padding
% (cpu_isa_traits<isa>::vlen / sizeof(float));
rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r14, r15,
preserve_gpr, preserve_vmm,
r12, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask,
use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_uni_fork_dw_conv_kernel_f32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ void jit_uni_fork_dw_conv_fwd_kernel_f32<isa>::generate() {
% (cpu_isa_traits<isa>::vlen / sizeof(float));
static constexpr bool use_exact_tail_scalar_bcast = false;
const binary_injector::rhs_arg_static_params_t rhs_sp {
helper_vmm_idx, r10, r11, preserve_gpr,
helper_vmm_idx, r10, r11, r12, preserve_gpr,
preserve_vmm, GET_OFF(post_ops_binary_rhs_arg_vec),
GET_OFF(dst_orig), memory_desc_wrapper(&dst_md_),
tail_size, k_oc_tail_mask, use_exact_tail_scalar_bcast};
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_uni_i8i8_pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator {
static constexpr std::size_t tmp_vmm_injector = 0u;

const binary_injector::rhs_arg_static_params_t rhs_sp {
tmp_vmm_injector, r14, r15, preserve_gpr, preserve_vmm,
tmp_vmm_injector, r14, r15, r13, preserve_gpr, preserve_vmm,
GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
memory_desc_wrapper(*dst_md), c_tail_elems,
mask(post_op_tail_opmask_idx_),
Expand Down
Loading