[PaddlePaddle Hackathon 2] No.11: Add new API index_fill #42454

Closed
wants to merge 9 commits into from
Changes from 1 commit
115 changes: 115 additions & 0 deletions paddle/fluid/operators/index_fill_op.cc
@@ -0,0 +1,115 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"

namespace paddle {
namespace operators {

class IndexFillOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

class IndexFillOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default input Tensor<float>), "
"the input feature data of IndexFillOp, dtype should be"
"int32, int64, float16, float32, float64.");
AddInput("Index",
"(Tensor, default 1-d Tensor<int>), "
"the 1-D tensor containing the indices to index, "
"dtype should be int32, int64");
AddAttr<int>("axis",
"(int, default 0), "
"the dimension in which we index.")
.SetDefault(0);
AddAttr<float>("fill_value",
"(float, default 0.0f) The value to be filled.")
.SetDefault(0.0f);
AddOutput("Out",
"(Tensor, default Tensor<float>),"
" the output of IndexFillOp, whose dtype is the same as X.");
AddComment(R"DOC(
IndexFill operator.
Fill the elements of the input tensor X with fill_value at the positions
selected by Index along the dimension given by axis.

This operator also supports inplace modification.
)DOC");
}
};

template <typename T>
class IndexFillGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> op) const override {
op->SetType("index_fill_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};

class IndexFillGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};

DECLARE_INPLACE_OP_INFERER(IndexFillInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(IndexFillGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexFillGradNoNeedBufferVarsInferer, "X");

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(index_fill, IndexFillInferShapeFunctor,
PD_INFER_META(phi::IndexFillInferMeta));

REGISTER_OPERATOR(index_fill, ops::IndexFillOp, ops::IndexFillOpMaker,
ops::IndexFillGradMaker<paddle::framework::OpDesc>,
ops::IndexFillGradMaker<paddle::imperative::OpBase>,
ops::IndexFillInplaceInferer, IndexFillInferShapeFunctor);

DECLARE_INFER_SHAPE_FUNCTOR(index_fill_grad, IndexFillGradInferShapeFunctor,
PD_INFER_META(phi::IndexFillGradInferMeta));

REGISTER_OPERATOR(index_fill_grad, ops::IndexFillGradOp,
ops::IndexFillGradInplaceInferer,
ops::IndexFillGradNoNeedBufferVarsInferer,
IndexFillGradInferShapeFunctor);
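For intuition, the operator registered above behaves as in the following minimal standalone C++ sketch (not part of this PR; a 2-D row-major buffer with axis = 0 is assumed): every position selected by Index along the chosen dimension is overwritten with fill_value, and all other positions are left untouched.

#include <cstdio>
#include <vector>

// Toy illustration of index_fill semantics for a 2-D input with axis = 0:
// the rows listed in `index` are overwritten with `fill_value`.
int main() {
  const int rows = 3, cols = 4;
  std::vector<float> x(rows * cols, 1.0f);  // X is all ones
  const std::vector<int> index = {0, 2};    // fill rows 0 and 2
  const float fill_value = 9.0f;

  for (int idx : index) {
    for (int j = 0; j < cols; ++j) {
      x[idx * cols + j] = fill_value;
    }
  }

  // Prints:
  // 9 9 9 9
  // 1 1 1 1
  // 9 9 9 9
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      std::printf("%g ", x[i * cols + j]);
    }
    std::printf("\n");
  }
  return 0;
}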
51 changes: 51 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -1181,6 +1181,57 @@ void HuberLossInferMeta(const MetaTensor& input,
out->share_lod(input);
}

void IndexFillInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
float fill_value,
MetaTensor* output) {
auto input_dim = x.dims();
auto index_dim = index.dims();

PADDLE_ENFORCE_EQ(
axis < input_dim.size() && axis >= (0 - input_dim.size()),
true,
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
iclementine marked this conversation as resolved.
Show resolved Hide resolved
input_dim.size(),
input_dim.size() - 1,
axis));

PADDLE_ENFORCE_EQ(
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
true,
phi::errors::InvalidArgument(
"The 'shape' of Input(Index) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
index_dim,
index_dim.size()));

PADDLE_ENFORCE_EQ(
index_dim[0] != 0,
true,
phi::errors::InvalidArgument("The length of Input(Index) can't be 0."));

output->set_dims(x.dims());
output->set_dtype(x.dtype());
output->set_layout(x.layout());
output->share_lod(x);
}

void IndexFillGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& index,
int axis,
float fill_value,
MetaTensor* x_grad) {
auto do_dims = out_grad.dims();
x_grad->set_dims(do_dims);
x_grad->set_dtype(out_grad.dtype());
x_grad->set_layout(out_grad.layout());
x_grad->share_lod(out_grad);
}

void IndexSampleInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
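The shape checks added in IndexFillInferMeta reduce to the standalone sketch below (plain vectors of dimension sizes stand in for MetaTensor/DDim; the helper name IndexFillInferShape is hypothetical): axis must lie in [-rank, rank - 1], Index must be a non-empty 1-D tensor (an [n, 1] column is tolerated), and the output simply inherits X's shape.

#include <cstdint>
#include <stdexcept>
#include <vector>

// Standalone sketch of the validation in IndexFillInferMeta, using plain
// vectors of dimension sizes instead of MetaTensor/DDim.
std::vector<int64_t> IndexFillInferShape(const std::vector<int64_t>& x_dims,
                                         const std::vector<int64_t>& index_dims,
                                         int axis) {
  const int rank = static_cast<int>(x_dims.size());
  if (axis < -rank || axis >= rank) {
    throw std::out_of_range("axis must be in [-rank, rank - 1]");
  }
  const bool index_is_1d =
      index_dims.size() == 1 ||
      (index_dims.size() == 2 && index_dims[1] == 1);
  if (!index_is_1d || index_dims[0] == 0) {
    throw std::invalid_argument("Index must be a non-empty 1-D tensor");
  }
  return x_dims;  // the output has exactly the same shape as X
}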
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -184,6 +184,18 @@ void HuberLossInferMeta(const MetaTensor& input_meta,
MetaTensor* residual,
MetaConfig config = MetaConfig());

void IndexFillInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
float fill_value,
MetaTensor* output);

void IndexFillGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& index,
int axis,
float fill_value,
MetaTensor* x_grad);

void IndexSampleInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
65 changes: 65 additions & 0 deletions paddle/phi/kernels/cpu/index_fill_grad_kernel.cc
@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/index_fill_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/cpu/index_fill_impl.h"

namespace phi {

template <typename T, typename Context>
void IndexFillGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& index,
int axis,
float fill_value,
DenseTensor* x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (axis < 0) {
axis += out_grad.dims().size();
}
const auto& index_type = index.dtype();

bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));

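// Positions selected by Index were overwritten with a constant in the
// forward pass, so their gradient is zero: x_grad starts as a copy of
// out_grad and the indexed slices along `axis` are filled with 0 below.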
auto fill_val = static_cast<T>(0);
if (index_type == phi::DataType::INT32) {
IndexFillInner<Context, T, int>(dev_ctx, index, x_grad, axis, fill_val);
} else if (index_type == phi::DataType::INT64) {
IndexFillInner<Context, T, int64_t>(dev_ctx, index, x_grad, axis, fill_val);
}
}

} // namespace phi

PD_REGISTER_KERNEL(index_fill_grad,
CPU,
ALL_LAYOUT,
phi::IndexFillGradKernel,
float,
phi::dtype::float16,
double,
int,
int64_t) {}
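Put differently, the backward pass is just a masked zero-fill, as in this minimal sketch (not part of the PR; a 2-D row-major buffer with axis = 0 is assumed, and IndexFillGradSketch is a hypothetical name):

#include <cassert>
#include <vector>

// Sketch of the index_fill backward rule for a 2-D tensor with axis = 0:
// the gradient is out_grad with the indexed rows zeroed, because those
// rows were replaced by a constant in the forward pass.
std::vector<float> IndexFillGradSketch(const std::vector<float>& out_grad,
                                       const std::vector<int>& index,
                                       int rows, int cols) {
  assert(static_cast<int>(out_grad.size()) == rows * cols);
  std::vector<float> x_grad = out_grad;  // pass-through for untouched rows
  for (int idx : index) {
    for (int j = 0; j < cols; ++j) {
      x_grad[idx * cols + j] = 0.0f;     // filled rows get zero gradient
    }
  }
  return x_grad;
}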
85 changes: 85 additions & 0 deletions paddle/phi/kernels/cpu/index_fill_impl.h
@@ -0,0 +1,85 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {
template <typename Context, typename T, typename IndexT = int>
void IndexFillInner(const Context& ctx,
const DenseTensor& index,
DenseTensor* output,
int axis,
T fill_val) {
auto output_dim = output->dims();
auto output_dim_size = output_dim.size();
auto index_size = index.dims()[0];

DenseTensor index_cpu_copy;
if (!paddle::platform::is_cpu_place(index.place())) {
phi::Copy(ctx, index, phi::CPUPlace(), true, &index_cpu_copy);
}
const IndexT* index_data = paddle::platform::is_cpu_place(index.place())
? index.data<IndexT>()
: index_cpu_copy.data<IndexT>();

auto slice_size = 1;
for (auto i = axis + 1; i < output_dim_size; i++) {
slice_size *= output_dim[i];
}

auto outer_nums = 1;
for (auto i = 0; i < axis; i++) {
outer_nums *= output_dim[i];
}

for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE(
index_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
output_dim[axis],
index_data[i]));
PADDLE_ENFORCE_LT(
index_data[i],
output_dim[axis],
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
output_dim[axis],
index_data[i]));
}

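// View the output as a 3-D tensor [outer_nums, output_dim[axis], slice_size]
// so that each index value selects one whole chip along dimension 1, which
// is then filled with fill_val.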
output->Resize(phi::make_ddim({outer_nums, output_dim[axis], slice_size}));

auto output_tensor = EigenTensor<T, 3>::From(*output);
auto& place = *ctx.eigen_device();
for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_data[j];
auto output_t = output_tensor.chip(index_value, 1);
output_t.device(place) = output_t.constant(fill_val);
}
output->Resize(output_dim);
}

} // namespace phi
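The reshape-and-chip approach in IndexFillInner is equivalent to the plain triple loop below (a sketch assuming a contiguous row-major buffer; IndexFillLoops is a hypothetical name, not part of the PR):

#include <cstdint>
#include <vector>

// Plain-loop equivalent of IndexFillInner's reshape-to-3D trick, assuming a
// contiguous row-major buffer: the tensor is viewed as
// [outer_nums, dim_at_axis, slice_size], and for every outer block the slice
// picked by each index value is overwritten with fill_val.
template <typename T, typename IndexT>
void IndexFillLoops(T* data, const std::vector<IndexT>& index,
                    int64_t outer_nums, int64_t dim_at_axis,
                    int64_t slice_size, T fill_val) {
  for (int64_t outer = 0; outer < outer_nums; ++outer) {
    for (IndexT idx : index) {
      T* slice = data + (outer * dim_at_axis + idx) * slice_size;
      for (int64_t k = 0; k < slice_size; ++k) {
        slice[k] = fill_val;
      }
    }
  }
}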