Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
OccupyMars2025 authored May 4, 2022
1 parent fb69c58 commit 34aea04
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
51 changes: 51 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,57 @@ void HuberLossInferMeta(const MetaTensor& input,
out->share_lod(input);
}

void IndexFillInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
float fill_value,
MetaTensor* output) {
auto input_dim = x.dims();
auto index_dim = index.dims();

PADDLE_ENFORCE_EQ(
axis < input_dim.size() && axis >= (0 - input_dim.size()),
true,
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(),
input_dim.size() - 1,
axis));

PADDLE_ENFORCE_EQ(
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
true,
phi::errors::InvalidArgument(
"The 'shape' of Input(Index) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
index_dim,
index_dim.size()));

PADDLE_ENFORCE_EQ(
index_dim[0] != 0,
true,
phi::errors::InvalidArgument("The length of Input(Index) can't be 0."));

output->set_dims(x.dims());
output->set_dtype(x.dtype());
output->set_layout(x.layout());
output->share_lod(x);
}

void IndexFillGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& index,
int axis,
float fill_value,
MetaTensor* x_grad) {
auto do_dims = out_grad.dims();
x_grad->set_dims(do_dims);
x_grad->set_dtype(out_grad.dtype());
x_grad->set_layout(out_grad.layout());
x_grad->share_lod(out_grad);
}

void IndexSampleInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ void HuberLossInferMeta(const MetaTensor& input_meta,
MetaTensor* residual,
MetaConfig config = MetaConfig());

void IndexFillInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
float fill_value,
MetaTensor* output);

void IndexFillGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& index,
int axis,
float fill_value,
MetaTensor* x_grad);

void IndexSampleInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
Expand Down

0 comments on commit 34aea04

Please sign in to comment.