Skip to content

Commit

Permalink
gemm_convolution: memory access fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alexey-varyzgin committed Jan 31, 2022
1 parent f941509 commit 180df0c
Showing 1 changed file with 86 additions and 9 deletions.
95 changes: 86 additions & 9 deletions src/cpu/x64/jit_gemm_convolution_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {

private:
void generate() override;
void copy_elems(const Xbyak::Reg64 &dst, const Xbyak::Reg64 &src, const Xbyak::Reg64 &size, const int elemSize);
void foreach (const Xbyak::Reg64 &idx, size_t step, const Xbyak::Reg64 &end, std::function<void(const Xbyak::Reg64 &)> && fn);

struct ker_args_t {
float *dst;
Expand Down Expand Up @@ -142,6 +144,47 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
Vmm vreg_bias(int idx) { return Vmm(idx_vreg_bias(idx)); };
};

/// Emits a counted loop into the generated code:
///
///     while (idx < end) { <code emitted by fn>; idx += step; }
///
/// @param idx  Register used as the loop counter. It is NOT initialized here;
///             the caller must set its starting value before calling.
/// @param step Immediate added to @p idx after every iteration.
/// @param end  Register holding the exclusive upper bound. The comparison is
///             signed (`jge`).
/// @param fn   Generator callback. It is invoked exactly once, at code
///             generation time, with @p idx, and whatever instructions it
///             emits become the loop body.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::foreach (const Xbyak::Reg64 &idx, size_t step,
        const Xbyak::Reg64 &end,
        std::function<void(const Xbyak::Reg64 &)> &&fn) {
    Xbyak::Label loop, exit;

    L(loop);
    cmp(idx, end);
    // `exit` is not bound yet and the callback may emit an arbitrary amount
    // of code in between, so request a near jump explicitly instead of
    // relying on the default (short) encoding for forward references.
    jge(exit, T_NEAR);

    fn(idx);

    add(idx, step);
    // Backward jump to an already bound label.
    jmp(loop);
    L(exit);
}

/// Emits an element-by-element copy of `size` items from `src` to `dst`.
///
/// @param dst      Register holding the destination base address.
/// @param src      Register holding the source base address.
/// @param size     Register holding the number of elements to copy.
/// @param elemSize Element width in bytes. Only 1 and 4 produce a copy loop;
///                 for any other value just the register save/restore
///                 sequence is emitted.
///
/// rsi (loop index) and r13 (transfer register) are used as scratch and are
/// saved/restored around the loop via the stack.
/// NOTE(review): this presumably requires that none of dst/src/size is rsi
/// or r13 -- confirm against callers.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::copy_elems(const Xbyak::Reg64 &dst,
        const Xbyak::Reg64 &src, const Xbyak::Reg64 &size,
        const int elemSize) {
    const auto &reg_counter = rsi;
    const auto &reg_scratch = r13;

    // Preserve the scratch registers.
    push(reg_counter);
    push(reg_scratch);

    // Start copying from element 0.
    xor_(reg_counter, reg_counter);

    switch (elemSize) {
        case 1:
            // Byte-wide elements: move through the low 8 bits of r13.
            foreach(reg_counter, 1, size,
                    [this, &dst, &src, elemSize](const Xbyak::Reg64 &pos) {
                        mov(r13b, byte[src + pos * elemSize]);
                        mov(byte[dst + pos * elemSize], r13b);
                    });
            break;
        case 4:
            // Dword-wide elements: move through the low 32 bits of r13.
            foreach(reg_counter, 1, size,
                    [this, &dst, &src, elemSize](const Xbyak::Reg64 &pos) {
                        mov(r13d, dword[src + pos * elemSize]);
                        mov(dword[dst + pos * elemSize], r13d);
                    });
            break;
        default:
            // Other element sizes emit no copy instructions.
            break;
    }

    // Restore in reverse order of the pushes.
    pop(reg_scratch);
    pop(reg_counter);
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::generate() {
using namespace Xbyak;
Expand All @@ -161,7 +204,18 @@ void jit_pp_kernel_t<isa>::generate() {
mov(reg_table, l_table);
}

auto apply_post_ops = [&]() {
auto store_to_stack = [&](const Reg64 &from, const Reg64 &size) {
sub(rsp, vlen * sizeof(float));
mov(r8, rsp);
copy_elems(r8, from, size, sizeof(float));
};

auto load_from_stack = [&](const Vmm &to) {
uni_vmovups(to, ptr[rsp]);
add(rsp, vlen * sizeof(float));
};

auto apply_post_ops = [&](bool apply_mask) {
int eltwise_inj_idx = 0;
int depthwise_inj_idx = 0;
auto vreg_dst_ = vreg_dst(0);
Expand All @@ -176,8 +230,20 @@ void jit_pp_kernel_t<isa>::generate() {
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
lea(reg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]);
lea(reg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float)]);
if (apply_mask) {
store_to_stack(reg_d_weights, reg_tmp);
mov(reg_d_weights, rsp);

if (post_op.depthwise.alg == dnnl_depthwise_scale_shift) {
store_to_stack(reg_d_bias, reg_tmp);
mov(reg_d_bias, rsp);
}
}
jit_depthwise_injectors_[depthwise_inj_idx]->compute_vector_range(vreg_dst_.getIdx(), vreg_dst_.getIdx() + 1,
reg_d_weights, reg_d_bias, true);
if (apply_mask) {
add(rsp, (post_op.depthwise.alg == dnnl_depthwise_scale_shift ? 2 : 1) * vlen * sizeof(float));
}
depthwise_inj_idx++;
} else if (post_op.is_quantization()) {
bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize;
Expand Down Expand Up @@ -243,6 +309,10 @@ void jit_pp_kernel_t<isa>::generate() {
// Load accumulated value, convert to float, apply bias (if any), scaling,
// and eltwise (if any); then convert to destination type and store
auto compute = [&](bool apply_mask) {
if (apply_mask) {
push(r8);
}

auto dst_addr = ptr[reg_dst];
auto vreg_dst_ = vreg_dst(0);
if (isa == avx512_common) {
Expand All @@ -251,11 +321,8 @@ void jit_pp_kernel_t<isa>::generate() {
uni_vmovups(vreg_dst_, dst_addr);
} else {
if (apply_mask) {
if (isa != sse41) {
uni_vblendvps(vreg_dst_, vreg_zero, dst_addr, vreg_mask);
} else {
uni_vmovups(vreg_dst_, dst_addr);
}
store_to_stack(reg_dst, reg_tmp);
load_from_stack(vreg_dst_);
} else {
uni_vmovups(vreg_dst_, dst_addr);
}
Expand All @@ -270,7 +337,7 @@ void jit_pp_kernel_t<isa>::generate() {
uni_vaddps(vreg_dst_, vreg_dst_, vreg_bias_);
}

apply_post_ops();
apply_post_ops(apply_mask);

if (isa == avx512_common) {
uni_vmovups(dst_addr, vreg_dst_);
Expand All @@ -279,13 +346,20 @@ void jit_pp_kernel_t<isa>::generate() {
if (isa != sse41) {
vmaskmovps(dst_addr, vreg_mask, vreg_dst_);
} else {
lea(reg_ptr_maskmovdqu_dst, dst_addr);
maskmovdqu(vreg_dst_, vreg_mask);
sub(rsp, vlen * sizeof(float));
mov(r8, rsp);
uni_vmovups(ptr[r8], vreg_dst_);
copy_elems(reg_dst, r8, reg_tmp, sizeof(float));
add(rsp, vlen * sizeof(float));
}
} else {
uni_vmovups(dst_addr, vreg_dst_);
}
}

if (apply_mask) {
pop(r8);
}
};

Label loop_end;
Expand All @@ -303,6 +377,9 @@ void jit_pp_kernel_t<isa>::generate() {
cmp(reg_len, vlen);
jge(loop, T_NEAR);
}

cmp(reg_tmp, 0);
je(loop_end, T_NEAR);

L(loop_tail);
mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
Expand Down

0 comments on commit 180df0c

Please sign in to comment.