support nhwc format for kunlun conv/batch_norm (#42195)

* support nhwc format for kunlun conv/batch_norm
*test=kunlun

* minor
*test=kunlun
QingshuChen authored Apr 26, 2022
1 parent 5be9b82 commit 88d68c0
Showing 3 changed files with 43 additions and 26 deletions.
cmake/external/xpu.cmake (2 changes: 1 addition & 1 deletion)

@@ -17,7 +17,7 @@ endif()
 # ubuntu and centos: use output by XDNN API team
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220412")
+  SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220425")
 else()
   SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
paddle/fluid/operators/batch_norm_op_xpu.cc (18 changes: 10 additions & 8 deletions)

@@ -53,8 +53,12 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
"But received: the size of input's dimensions is [%d]",
x_dims.size()));

int N, C, H, W, D;
int N = -1, C = -1, H = -1, W = -1, D = -1;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
N = (N == 0) ? 1 : N;
C = (C == 0) ? 1 : C;
H = (H == 0) ? 1 : H;
W = (W == 0) ? 1 : W;

const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
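
For context: ExtractNCWHD fills N/C/H/W/D from the input shape according to data_layout, and the new initialization plus clamping keeps a zero extent from reaching the XPU kernel. Below is a minimal sketch of the 4-D mapping, assuming the usual NCHW/NHWC semantics; extract_ncwhd_4d is a hypothetical stand-in, not Paddle's actual helper.

#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for ExtractNCWHD, 4-D case only.
// NCHW stores {N, C, H, W}; NHWC stores {N, H, W, C}; D defaults to 1.
void extract_ncwhd_4d(const std::vector<int> &dims, const std::string &layout,
                      int *N, int *C, int *H, int *W, int *D) {
  assert(dims.size() == 4);
  *N = dims[0];
  *D = 1;
  if (layout == "NCHW") {
    *C = dims[1]; *H = dims[2]; *W = dims[3];
  } else {  // NHWC
    *H = dims[1]; *W = dims[2]; *C = dims[3];
  }
  // The patch then clamps any zero extent to 1 before calling the XPU kernel.
}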
@@ -103,12 +107,6 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
"The batch_norm XPU API return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
} else {
PADDLE_ENFORCE_EQ(
data_layout_str == "NCHW", true,
platform::errors::InvalidArgument(
"The batch_norm_infer 'data_layout' attribute must be NCHW. "
"But recevived 'data_layout' is [%s].",
data_layout_str));
const auto *mean = ctx.Input<Tensor>("Mean");
const auto *variance = ctx.Input<Tensor>("Variance");
const auto *mean_data = mean->data<float>();
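
With the NCHW-only guard removed, the inference branch no longer rejects NHWC inputs. This hunk does not show how that branch now selects the layout, but presumably it derives a flag from the attribute string the way the conv kernels below do; a sketch under that assumption:

// Assumed pattern (not visible in this hunk): derive the layout flag
// instead of rejecting anything that is not NCHW.
bool is_nchw = (data_layout_str == "NCHW");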
@@ -222,8 +220,12 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
"But received: the size of input's dimensions is [%d]",
x_dims.size()));

int N, C, H, W, D;
int N = -1, C = -1, H = -1, W = -1, D = -1;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
N = (N == 0) ? 1 : N;
C = (C == 0) ? 1 : C;
H = (H == 0) ? 1 : H;
W = (W == 0) ? 1 : W;

const auto *x_data = x->data<T>();
const auto *d_y_data = d_y->data<T>();
paddle/fluid/operators/conv_op_xpu.cc (49 changes: 32 additions & 17 deletions)

@@ -38,9 +38,10 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     const std::string padding_algorithm =
         context.Attr<std::string>("padding_algorithm");
 
-    PADDLE_ENFORCE_EQ(data_format == "NHWC" || data_format == "NDHWC", false,
-                      platform::errors::InvalidArgument(
-                          ("XPU do support data_format is NCHW in conv op.")));
+    PADDLE_ENFORCE_EQ(
+        data_format == "NDHWC", false,
+        platform::errors::InvalidArgument(
+            ("XPU does not support data_format is NDHWC in conv op.")));
 
     framework::DDim in_data_dims =
         phi::slice_ddim(input->dims(), 2, input->dims().size());
@@ -50,11 +51,18 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                              in_data_dims, strides, ksize);
 
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_c = static_cast<int>(input->dims()[1]);
-    const int img_h = static_cast<int>(input->dims()[2]);
-    const int img_w = static_cast<int>(input->dims()[3]);
-    const int f = static_cast<int>(filter.dims()[0]);
+    int batch_size = static_cast<int>(input->dims()[0]);
+    int img_c = static_cast<int>(input->dims()[1]);
+    int img_h = static_cast<int>(input->dims()[2]);
+    int img_w = static_cast<int>(input->dims()[3]);
+    int f = static_cast<int>(filter.dims()[0]);
+    bool is_nchw = true;
+    if (data_format == "NHWC") {
+      img_c = static_cast<int>(input->dims()[3]);
+      img_h = static_cast<int>(input->dims()[1]);
+      img_w = static_cast<int>(input->dims()[2]);
+      is_nchw = false;
+    }
 
     const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
     const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
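
The five dimension reads plus the NHWC override above are repeated verbatim in the grad kernel below. A hypothetical refactor sketch that factors the layout resolution into one helper; ConvDims and ResolveConvDims are illustrative names, not part of the patch.

#include <array>
#include <string>

struct ConvDims {
  int n, c, h, w;
  bool is_nchw;
};

// dims holds the 4-D input shape in memory order (NCHW or NHWC).
ConvDims ResolveConvDims(const std::array<int, 4> &dims,
                         const std::string &data_format) {
  if (data_format == "NHWC") {
    return {dims[0], dims[3], dims[1], dims[2], /*is_nchw=*/false};
  }
  // Default: NCHW, the kernels' original assumption.
  return {dims[0], dims[1], dims[2], dims[3], /*is_nchw=*/true};
}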
@@ -64,7 +72,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
         dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
         img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
-        nullptr, nullptr, nullptr, true);
+        nullptr, nullptr, nullptr, is_nchw);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
@@ -99,9 +107,9 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("padding_algorithm");
 
     PADDLE_ENFORCE_EQ(
-        data_format == "NHWC" || data_format == "NDHWC", false,
+        data_format == "NDHWC", false,
         platform::errors::InvalidArgument(
-            ("XPU do support data_format is NCHW in conv grad op.")));
+            ("XPU doesn't support data_format is NDHWC in conv grad op.")));
 
     framework::DDim in_data_dims =
         phi::slice_ddim(input->dims(), 2, input->dims().size());
@@ -111,11 +119,18 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                              in_data_dims, strides, ksize);
 
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int img_c = static_cast<int>(input->dims()[1]);
-    const int img_h = static_cast<int>(input->dims()[2]);
-    const int img_w = static_cast<int>(input->dims()[3]);
-    const int f = static_cast<int>(filter.dims()[0]);
+    int batch_size = static_cast<int>(input->dims()[0]);
+    int img_c = static_cast<int>(input->dims()[1]);
+    int img_h = static_cast<int>(input->dims()[2]);
+    int img_w = static_cast<int>(input->dims()[3]);
+    int f = static_cast<int>(filter.dims()[0]);
+    bool is_nchw = true;
+    if (data_format == "NHWC") {
+      img_c = static_cast<int>(input->dims()[3]);
+      img_h = static_cast<int>(input->dims()[1]);
+      img_w = static_cast<int>(input->dims()[2]);
+      is_nchw = false;
+    }
 
     const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
     const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
@@ -136,7 +151,7 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
         dev_ctx.x_context(), input_data, filter_data, output_grad_data,
         input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
         ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
-        nullptr, nullptr, true);
+        nullptr, nullptr, is_nchw);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU conv kernel return wrong value[%d %s]",
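
Since both kernels now accept either memory order, a test comparing the NCHW and NHWC paths needs an explicit repacking step. A minimal sketch, not from the patch, of an NCHW-to-NHWC conversion:

#include <vector>

// Repacks src from NCHW order into NHWC order; n, c, h, w describe src.
std::vector<float> NchwToNhwc(const std::vector<float> &src, int n, int c,
                              int h, int w) {
  std::vector<float> dst(src.size());
  for (int in = 0; in < n; ++in)
    for (int ic = 0; ic < c; ++ic)
      for (int ih = 0; ih < h; ++ih)
        for (int iw = 0; iw < w; ++iw)
          dst[((in * h + ih) * w + iw) * c + ic] =
              src[((in * c + ic) * h + ih) * w + iw];
  return dst;
}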
