Added 3D DW case support for JIT INT8 Convolutions
dmitry-gorokhov committed Dec 14, 2021
1 parent e946756 commit 10fcf5d
Showing 13 changed files with 265 additions and 16 deletions.
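Before this commit, both JIT int8 (x8s8s32x) forward convolution kernels rejected the 5D depthwise case with status::unimplemented, so 3D depthwise int8 convolutions could not use the vectorized path. The sketch below shows roughly how such a primitive is created through the oneDNN 2.x-style C++ API used in this repository; the engine setup, shapes, and data types are illustrative assumptions, not part of the commit.

    #include "oneapi/dnnl/dnnl.hpp"
    using namespace dnnl;

    int main() {
        engine eng(engine::kind::cpu, 0);

        // Assumed 3D depthwise shape: groups == IC == OC == 16, 3x3x3 filter,
        // stride 1 and padding 1 so the output keeps the input spatial size.
        const memory::dim N = 1, C = 16, D = 8, H = 14, W = 14, K = 3;
        memory::desc src_md({N, C, D, H, W}, memory::data_type::u8, memory::format_tag::any);
        memory::desc wei_md({C, 1, 1, K, K, K}, memory::data_type::s8, memory::format_tag::any); // {G, OC/G, IC/G, KD, KH, KW}
        memory::desc dst_md({N, C, D, H, W}, memory::data_type::u8, memory::format_tag::any);

        auto conv_d = convolution_forward::desc(prop_kind::forward_inference,
                algorithm::convolution_direct, src_md, wei_md, dst_md,
                /*strides*/ {1, 1, 1}, /*padding_l*/ {1, 1, 1}, /*padding_r*/ {1, 1, 1});
        auto conv_pd = convolution_forward::primitive_desc(conv_d, eng);

        // With this commit the JIT x8s8s32x implementation can serve this case
        // with a blocked depthwise weights layout (Goidhw16g/8g/4g) instead of
        // bailing out as unimplemented.
        auto conv = convolution_forward(conv_pd);
        (void)conv;
        return 0;
    }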
2 changes: 2 additions & 0 deletions include/mkldnn_dnnl_mangling.h
@@ -143,6 +143,7 @@
#define mkldnn_Abcde16a dnnl_Abcde16a
#define mkldnn_Abcde4a dnnl_Abcde4a
#define mkldnn_Abcde8a dnnl_Abcde8a
#define mkldnn_Abcdef4a dnnl_Abcdef4a
#define mkldnn_Abcdef8a dnnl_Abcdef8a
#define mkldnn_Abcdef16a dnnl_Abcdef16a
#define mkldnn_Acb16a dnnl_Acb16a
@@ -164,6 +165,7 @@
#define mkldnn_BAcde16b16a dnnl_BAcde16b16a
#define mkldnn_BAcde16a16b dnnl_BAcde16a16b
#define mkldnn_BAcde8a16b2a dnnl_BAcde8a16b2a
#define mkldnn_Goidhw4g dnnl_Goidhw4g
#define mkldnn_Goidhw8g dnnl_Goidhw8g
#define mkldnn_Goidhw16g dnnl_Goidhw16g
#define mkldnn_Goihw16g dnnl_Goihw16g
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
@@ -1691,6 +1691,7 @@ struct memory : public handle<dnnl_memory_t> {
BAcde16b16a = dnnl_BAcde16b16a,
BAcde16a16b = dnnl_BAcde16a16b,
aBdec32b = dnnl_aBdec32b,
Abcdef4a = dnnl_Abcdef4a,
Abcdef8a = dnnl_Abcdef8a,
Abcdef16a = dnnl_Abcdef16a,
Abcdef32a = dnnl_Abcdef32a,
@@ -1873,6 +1874,7 @@ struct memory : public handle<dnnl_memory_t> {
IOdhw16i16o = dnnl_IOdhw16i16o,
gIOhw16i16o = dnnl_gIOhw16i16o,
gOhwi32o = dnnl_gOhwi32o,
Goidhw4g = dnnl_Goidhw4g,
Goidhw8g = dnnl_Goidhw8g,
Goidhw16g = dnnl_Goidhw16g,
IOw16o16i = dnnl_IOw16o16i,
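These dnnl.hpp additions expose the new layout to C++ users as memory::format_tag::Goidhw4g, next to the existing Goidhw8g/Goidhw16g. Weights are normally left as format_tag::any so the implementation can pick the layout, but the tag can also be requested explicitly; a one-line sketch with assumed dimensions:

    #include "oneapi/dnnl/dnnl.hpp"

    // Assumed depthwise weights dims {G, OC/G, IC/G, KD, KH, KW} = {16, 1, 1, 3, 3, 3};
    // G is divisible by 4, so the Goidhw4g blocking applies without padding.
    const dnnl::memory::desc dw_wei_md({16, 1, 1, 3, 3, 3},
            dnnl::memory::data_type::s8, dnnl::memory::format_tag::Goidhw4g);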
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl_types.h
@@ -389,6 +389,7 @@ typedef enum {
dnnl_aCBdef16c16b,
dnnl_aBdefc4b,
dnnl_aBdefc8b,
dnnl_Abcdef4a,
dnnl_Abcdef8a,
dnnl_Abcdef16a,
dnnl_Abcdef32a,
@@ -1099,6 +1100,7 @@ typedef enum {
dnnl_gIOdhw8o16i2o = dnnl_aCBdef8b16c2b,
dnnl_gOIdhw8o8i = dnnl_aBCdef8b8c,
dnnl_gOIdhw8o4i = dnnl_aBCdef8b4c,
dnnl_Goidhw4g = dnnl_Abcdef4a,
dnnl_Goidhw8g = dnnl_Abcdef8a,
dnnl_Goidhw16g = dnnl_Abcdef16a,
dnnl_Goidhw32g = dnnl_Abcdef32a,
2 changes: 2 additions & 0 deletions src/common/c_types_map.hpp
@@ -508,6 +508,7 @@ const format_tag_t ABcd40a32b = dnnl_ABcd40a32b;
const format_tag_t ABcde40a32b = dnnl_ABcde40a32b;
const format_tag_t BAcde16b16a = dnnl_BAcde16b16a;
const format_tag_t aBdec32b = dnnl_aBdec32b;
const format_tag_t Abcdef4a = dnnl_Abcdef4a;
const format_tag_t Abcdef8a = dnnl_Abcdef8a;
const format_tag_t Abcdef16a = dnnl_Abcdef16a;
const format_tag_t Abcdef32a = dnnl_Abcdef32a;
@@ -758,6 +759,7 @@ const format_tag_t IOhw16i16o = dnnl_IOhw16i16o;
const format_tag_t Ohwi32o = dnnl_Ohwi32o;
const format_tag_t gIOhw16i16o = dnnl_gIOhw16i16o;
const format_tag_t gOhwi32o = dnnl_gOhwi32o;
const format_tag_t Goidhw4g = dnnl_Goidhw4g;
const format_tag_t Goidhw8g = dnnl_Goidhw8g;
const format_tag_t Goidhw16g = dnnl_Goidhw16g;
const format_tag_t IOw16o16i = dnnl_IOw16o16i;
2 changes: 2 additions & 0 deletions src/common/dnnl_debug_autogenerated.cpp
@@ -248,6 +248,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_aCBdef16c16b) return "aCBdef16c16b";
if (v == dnnl_aBdefc4b) return "aBdefc4b";
if (v == dnnl_aBdefc8b) return "aBdefc8b";
if (v == dnnl_Abcdef4a) return "Abcdef4a";
if (v == dnnl_Abcdef8a) return "Abcdef8a";
if (v == dnnl_Abcdef16a) return "Abcdef16a";
if (v == dnnl_Abcdef32a) return "Abcdef32a";
@@ -835,6 +836,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_gIOdhw8o16i2o) return "gIOdhw8o16i2o";
if (v == dnnl_gOIdhw8o8i) return "gOIdhw8o8i";
if (v == dnnl_gOIdhw8o4i) return "gOIdhw8o4i";
if (v == dnnl_Goidhw4g) return "Goidhw4g";
if (v == dnnl_Goidhw8g) return "Goidhw8g";
if (v == dnnl_Goidhw16g) return "Goidhw16g";
if (v == dnnl_Goidhw32g) return "Goidhw32g";
1 change: 1 addition & 0 deletions src/common/memory_desc_wrapper.cpp
@@ -484,6 +484,7 @@ status_t memory_desc_wrapper::compute_blocking(
C(aBdec32b, {0, 1, 3, 4, 2}, {32}, {1});
C(aCBdef16c16b, {0, 2, 1, 3, 4, 5}, {16, 16}, {2, 1});
C(aCBdef16b16c, {0, 2, 1, 3, 4, 5}, {16, 16}, {1, 2});
C(Abcdef4a, {0, 1, 2, 3, 4, 5}, {4}, {0});
C(Abcdef8a, {0, 1, 2, 3, 4, 5}, {8}, {0});
C(Abcdef16a, {0, 1, 2, 3, 4, 5}, {16}, {0});
C(Abcdef32a, {0, 1, 2, 3, 4, 5}, {32}, {0});
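The new C(Abcdef4a, {0, 1, 2, 3, 4, 5}, {4}, {0}) entry keeps the plain dimension order and blocks the outermost dimension (the group count of Goidhw4g, padded to a multiple of 4) by 4, with the intra-block index innermost. A minimal addressing sketch, assuming depthwise weights dims {G, 1, 1, KD, KH, KW}; the helper name is made up for illustration:

    #include <cstddef>

    // Hypothetical helper, not library code: linear element offset of weights
    // element (g, kd, kh, kw) in the Goidhw4g (alias Abcdef4a) layout, where
    // the unit oc/g and ic/g dimensions drop out.
    std::size_t goidhw4g_off(std::size_t g, std::size_t kd, std::size_t kh,
            std::size_t kw, std::size_t KD, std::size_t KH, std::size_t KW) {
        const std::size_t blk = 4;        // the "4a"/"4g" block size
        const std::size_t gb = g / blk;   // outer block index
        const std::size_t gi = g % blk;   // position inside the block, stride 1
        return (((gb * KD + kd) * KH + kh) * KW + kw) * blk + gi;
    }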
1 change: 1 addition & 0 deletions src/common/tag_traits.hpp
@@ -463,6 +463,7 @@ DECL_TRAITS(aBCde2b8c8b2c, _BC, _2b8c8b2c, 5);
DECL_TRAITS(aBdec32b, _B, _32b, 5);
DECL_TRAITS(aCBdef16c16b, _BC, _16c16b, 6);
DECL_TRAITS(aCBdef16b16c, _BC, _16b16c, 6);
DECL_TRAITS(Abcdef4a, _A, _4a, 6);
DECL_TRAITS(Abcdef8a, _A, _8a, 6);
DECL_TRAITS(Abcdef16a, _A, _16a, 6);
DECL_TRAITS(aCBd16c16b, _BC, _16c16b, 4);
11 changes: 4 additions & 7 deletions src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp
@@ -44,7 +44,7 @@ void pick_loop_order(jit_conv_conf_t &jcp, int nthr) {
jcp.loop_order = loop_cwgn;
if (jcp.ngroups > 1) {
jcp.loop_order = loop_ngcw;
if (jcp.mb < nthr)
if (jcp.mb < nthr && jcp.ndims != 5)
jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
} else if (jcp.mb >= nthr && jcp.ic_without_padding <= 16) {
jcp.loop_order = loop_ngcw;
@@ -415,7 +415,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(int ur_w,
};

auto kernel_offset = [=](int ci, int ki) {
return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
return jcp.typesize_in * ((ci * jcp.kd * jcp.kh * jcp.kw + ki) * jcp.ch_block);
};

auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
@@ -1384,10 +1384,6 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
jcp.need_saturation = utils::one_of(dst_d.data_type(), u8, s8, s32);
jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);

if (jcp.is_depthwise && is_3d)
// NOTE: 3D depthwise is not currently supported here.
return status::unimplemented;

jcp.with_input_zp = !attr.input_zero_points_.has_default_values();
jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values();

@@ -1470,7 +1466,8 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
format_tag_t wei_tag;
if (jcp.ic_block == 16 || jcp.ch_block == 16) {
if (is_3d) {
wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i;
wei_tag = with_groups ? jcp.is_depthwise ? Goidhw16g : gOIdhw4i16o4i
: OIdhw4i16o4i;
} else if (is_1d) {
wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
: OIw4i16o4i;
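Two of the hunks above belong together: once 3D depthwise weights are stored as Goidhw16g, each channel block holds kd*kh*kw filter taps rather than kh*kw, so kernel_offset must multiply by jcp.kd as well. A quick sanity check of the formula under assumed values (kd = kh = kw = 3, ch_block = 16, typesize_in = 1 for s8 weights), not taken from the commit itself:

    // Assumed example values for illustration only.
    constexpr int kd = 3, kh = 3, kw = 3, ch_block = 16, typesize_in = 1;
    constexpr int ci = 1, ki = 0;  // first tap of channel block 1
    // Old 2D formula: only 9 taps in, still inside channel block 0's depth taps.
    constexpr int off_2d = typesize_in * ((ci * kh * kw + ki) * ch_block);       // 144
    // New formula with kd: skips all 27 taps of block 0, as Goidhw16g requires.
    constexpr int off_3d = typesize_in * ((ci * kd * kh * kw + ki) * ch_block);  // 432
    static_assert(off_2d == 144 && off_3d == 432, "offsets as annotated");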
115 changes: 115 additions & 0 deletions src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp
@@ -693,6 +693,121 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
return status::success;
}

status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d_dw(const exec_ctx_t &ctx) const {
auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);

const memory_desc_wrapper src_d(pd()->src_md());
const memory_desc_wrapper dst_d(pd()->dst_md());
const memory_desc_wrapper weights_d(pd()->weights_md(0));
const memory_desc_wrapper bias_d(pd()->weights_md(1));

const size_t bia_dt_size
= pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

const auto &jcp = pd()->jcp_;
assert(jcp.ic_block == 1);
assert(jcp.oc_block == 1);
assert(jcp.nb_ic == 1);
assert(jcp.nb_oc == 1);
assert(jcp.nb_oc_blocking == 1);
assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

const float *oscales = pd()->attr()->output_scales_.scales_;
if (jcp.signed_input && jcp.ver != ver_vnni) {
auto local_scales = ctx.get_scratchpad_grantor().template get<float>(
key_conv_adjusted_scales);
size_t count = pd()->attr()->output_scales_.count_;
float factor = 1.f / pd()->jcp_.wei_adj_scale;
if (count == 1) {
utils::array_set(local_scales, oscales[0] * factor, 16);
} else {
for (size_t c = 0; c < count; c++)
local_scales[c] = oscales[c] * factor;
}
oscales = local_scales;
}

size_t offset = weights_d.size() - weights_d.additional_buffer_size();
auto w = const_cast<char *>(weights);
int32_t* compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) :
(jcp.with_input_zp) ? pd()->attr()->output_compensations_.shifts_ : 0;
const uint8_t* input_zp = pd()->attr()->input_zero_points_.shifts_;
int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
int group_block = jcp.ch_block;

parallel_nd(jcp.mb, jcp.od, jcp.oh, jcp.nb_ow, nb_groups, [&](int n, int od_s, int oh_s, int owb, int gg) {
auto p = jit_conv_call_s();

size_t src_d_stride = src_d.blk_off(0, 0, 1);
size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

int gb = gg * jcp.nb_ch_blocking;
int g = gb * group_block;

int id_s = -jcp.f_pad + od_s * jcp.stride_d;

int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
int ow_s = owb * jcp.ow_block;
int iw_s = ow_s * jcp.stride_w;

auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0;
int32_t *compensation_w = (jcp.signed_input || jcp.with_input_zp) ? compensation + g : 0;

auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g, od_s, oh_s, ow_s);
auto src_w = src + src_d.blk_off(n, g, id_s, ih_s, iw_s);
auto wht_w = weights + wht_blk_off(weights_d, gb, 0);

auto scales = &oscales[jcp.is_oc_scale * g];

int dilate_d = jcp.dilate_d + 1;
int i_f_overflow = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
int i_back_overflow = nstl::min(jcp.kd,
div_up(max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
dilate_d));
int kd_padding = nstl::max(0, jcp.kd - i_f_overflow - i_back_overflow);

size_t wei_d_stride = (jcp.signed_input || jcp.with_input_zp) ? 0 : i_f_overflow * wht_d_stride;

int dilate_h = jcp.dilate_h + 1;
int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
int i_b_overflow = nstl::min(jcp.kh,
div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
dilate_h));
int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);

size_t wei_h_stride = (jcp.signed_input || jcp.with_input_zp) ? 0 : i_t_overflow * wht_h_stride;
p.src = src_w + i_t_overflow * dilate_h * src_h_stride
+ i_f_overflow * dilate_d * src_d_stride;
p.dst = dst_w;
p.filt = wht_w + wei_d_stride + wei_h_stride;
p.bias = bias_w;
p.compensation = compensation_w;
p.oc_blocks = gb;
p.kd_padding = kd_padding;
p.kh_padding = kh_padding;
p.scales = scales;
p.f_overflow = i_f_overflow;
p.back_overflow = i_back_overflow;
p.t_overflow = i_t_overflow;
p.b_overflow = i_b_overflow;
p.owb = owb;

p.oc_off = g * sizeof(float);
if (jcp.with_input_zp)
p.input_zp = input_zp + g;

(*kernel_)(&p);
});
return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
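The new execute_forward_3d_dw dispatcher handles the depth axis the same way the existing 2D depthwise path handles height: for each output depth position it counts the filter taps falling into the front (f_overflow) and back (back_overflow) padding, passes only the remaining kd_padding taps to the kernel, and advances the source pointer (and, when zero-point/signed-input compensation is not in play, the weights pointer) past the clipped taps. A standalone numeric sketch of that clipping under assumed parameters (kd = 3, f_pad = 1, stride_d = 1, no dilation, id = od = 8):

    #include <algorithm>
    #include <cstdio>

    static int div_up(int a, int b) { return (a + b - 1) / b; }

    int main() {
        // Assumed shapes for illustration only.
        const int kd = 3, f_pad = 1, stride_d = 1, id = 8, od = 8;
        const int dilate_d = 1;  // jcp.dilate_d + 1, i.e. no dilation
        for (int od_s = 0; od_s < od; ++od_s) {
            const int id_s = -f_pad + od_s * stride_d;  // first input depth touched
            const int f_overflow = std::min(kd, div_up(std::max(0, -id_s), dilate_d));
            const int back_overflow = std::min(kd,
                    div_up(std::max(0, id_s - id + (kd - 1) * dilate_d + 1), dilate_d));
            const int kd_padding = std::max(0, kd - f_overflow - back_overflow);
            // Border positions od_s == 0 and od_s == 7 get kd_padding == 2
            // (one tap clipped away); interior positions get the full 3.
            std::printf("od_s=%d kd_padding=%d\n", od_s, kd_padding);
        }
        return 0;
    }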
9 changes: 7 additions & 2 deletions src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp
@@ -114,8 +114,12 @@ struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t {
return execute_forward_2d_dw(ctx);
else
return execute_forward_2d(ctx);
else if (_pd->ndims() == 5)
return execute_forward_3d(ctx);
else if (_pd->ndims() == 5) {
if (_pd->jcp_.is_depthwise)
return execute_forward_3d_dw(ctx);
else
return execute_forward_3d(ctx);
}
return status::unimplemented;
}

@@ -124,6 +128,7 @@ struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t {
status_t execute_forward_2d(const exec_ctx_t &ctx) const;
status_t execute_forward_2d_dw(const exec_ctx_t &ctx) const;
status_t execute_forward_3d(const exec_ctx_t &ctx) const;
status_t execute_forward_3d_dw(const exec_ctx_t &ctx) const;
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

std::unique_ptr<jit_avx512_core_x8s8s32x_fwd_kernel> kernel_;
13 changes: 7 additions & 6 deletions src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp
@@ -44,7 +44,7 @@ void pick_loop_order(jit_conv_conf_t &jcp) {
jcp.loop_order = loop_cwgn;
if (jcp.ngroups > 1) {
jcp.loop_order = loop_ngcw;
if (jcp.mb < jcp.nthr)
if (jcp.mb < jcp.nthr && jcp.ndims != 5)
jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
} else if (jcp.mb >= jcp.nthr && jcp.ic_without_padding <= 8) {
jcp.loop_order = loop_ngcw;
@@ -413,7 +413,7 @@ void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::compute_ker_dw(int ur_w, int pad_l,
};

auto kernel_offset = [=](int ci, int ki) {
return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
return jcp.typesize_in * ((ci * jcp.kd * jcp.kh * jcp.kw + ki) * jcp.ch_block);
};

auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
@@ -1339,8 +1339,6 @@ status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.is_fused_conv)
return status::unimplemented;

if (is_3d && jcp.is_depthwise) return status::unimplemented;

jcp.with_input_zp = !attr.input_zero_points_.has_default_values();
jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values();

@@ -1410,7 +1408,8 @@ status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i
: OIhw2i8o4i;
} else {
wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i;
wei_tag = with_groups ? jcp.is_depthwise ? Goidhw8g : gOIdhw2i8o4i
: OIdhw2i8o4i;
}
} else {
if (is_avx2) {
@@ -1425,7 +1424,9 @@ status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
? jcp.is_depthwise ? Goihw4g : gOIhw4o4i
: OIhw4o4i;
} else {
wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i;
wei_tag = with_groups
? jcp.is_depthwise ? Goidhw4g : gOIdhw4o4i
: OIdhw4o4i;
}
}
}
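The jit_uni kernel gets the same treatment with smaller channel blocks, which is why the new Goidhw4g tag exists at all: the configuration with ch_block = 8 stores 3D depthwise weights as Goidhw8g, while the ch_block = 4 configuration needs the newly added Goidhw4g. A rough sketch of the resulting depthwise weights footprint; the helper and the example numbers are illustrative, not library code:

    #include <cstddef>

    // Hypothetical helper: number of s8 elements in a blocked 3D depthwise
    // weights buffer, with the group count padded up to the channel block.
    std::size_t dw_wei_elems(std::size_t G, std::size_t KD, std::size_t KH,
            std::size_t KW, std::size_t ch_block) {
        const std::size_t G_padded = (G + ch_block - 1) / ch_block * ch_block;
        return G_padded * KD * KH * KW;  // one tap per (group, kd, kh, kw)
    }

    // e.g. G = 18 groups with a 3x3x3 filter:
    //   ch_block = 16 (Goidhw16g) -> 32 * 27 = 864 elements
    //   ch_block =  8 (Goidhw8g)  -> 24 * 27 = 648 elements
    //   ch_block =  4 (Goidhw4g)  -> 20 * 27 = 540 elements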
(The remaining changed files were not loaded in this view.)
