Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon No.5] tril_indices OP #41639

Merged
merged 49 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
f505bcb
add tril_indices cpu kernal
xiaoguoguo626807 Mar 30, 2022
6b093dd
modify tril_indice cpu op
xiaoguoguo626807 Mar 31, 2022
ba550c2
Merge pull request #1 from xiaoguoguo626807/xgg-branch
xiaoguoguo626807 Mar 31, 2022
c5c9905
modify bug
xiaoguoguo626807 Mar 31, 2022
4b786ad
modify bug
xiaoguoguo626807 Apr 1, 2022
146eb8c
add tril_indices python api
xiaoguoguo626807 Apr 1, 2022
1925605
add tril_indices python api
xiaoguoguo626807 Apr 1, 2022
7319cb8
Merge pull request #2 from xiaoguoguo626807/xgg-branch
xiaoguoguo626807 Apr 1, 2022
c98e6a8
resolve conflict
xiaoguoguo626807 Apr 1, 2022
1a3215f
Merge pull request #3 from xiaoguoguo626807/xgg-branch
xiaoguoguo626807 Apr 1, 2022
e433a69
add tril_indices test
xiaoguoguo626807 Apr 1, 2022
1ff4931
Merge pull request #4 from xiaoguoguo626807/xgg-branch
xiaoguoguo626807 Apr 1, 2022
b8bf344
modify details
xiaoguoguo626807 Apr 2, 2022
f330fb7
add tril_indices.cu
xiaoguoguo626807 Apr 6, 2022
e393955
Merge pull request #5 from xiaoguoguo626807/xgg-branch
xiaoguoguo626807 Apr 6, 2022
284206c
pythonapi pass
xiaoguoguo626807 Apr 7, 2022
68c35c8
Merge branch 'PaddlePaddle:develop' into develop
xiaoguoguo626807 Apr 7, 2022
cf001b1
save tril_indices
xiaoguoguo626807 Apr 7, 2022
b2b3f0e
solve conflict
xiaoguoguo626807 Apr 7, 2022
563ed16
CPU tril_indices pass
xiaoguoguo626807 Apr 8, 2022
4c6e429
delete vlog
xiaoguoguo626807 Apr 8, 2022
d071526
Merge pull request #6 from xiaoguoguo626807/pythonapi_pass
xiaoguoguo626807 Apr 8, 2022
2b9a5c1
modify test_tril_indices_op.py
xiaoguoguo626807 Apr 11, 2022
65b67bc
solve conflict
xiaoguoguo626807 Apr 11, 2022
be49653
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Apr 11, 2022
8810da3
delete tril_indices_kernel.cc.swp
xiaoguoguo626807 Apr 11, 2022
b59c4de
delete tril_indice.cu
xiaoguoguo626807 Apr 11, 2022
fbca2b1
modify code style
xiaoguoguo626807 Apr 12, 2022
e7df5e5
add newline in creation.py
xiaoguoguo626807 Apr 13, 2022
3acde8e
modify creation.py linux newline
xiaoguoguo626807 Apr 13, 2022
9ae9091
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Apr 13, 2022
439aa23
delete annotation
xiaoguoguo626807 Apr 14, 2022
79a5cfd
check code style
xiaoguoguo626807 Apr 15, 2022
110d2b1
check .py style add final_state??
xiaoguoguo626807 Apr 15, 2022
b701aa8
modify code style
xiaoguoguo626807 Apr 19, 2022
726aedc
Merge branch 'PaddlePaddle:develop' into develop
xiaoguoguo626807 Apr 19, 2022
410ffef
Merge branch 'PaddlePaddle:develop' into develop
xiaoguoguo626807 Apr 25, 2022
bc8af08
add gpu_tril_indices
xiaoguoguo626807 Apr 25, 2022
bef2a98
Merge branch 'develop' of https://github.com/xiaoguoguo626807/Paddle …
xiaoguoguo626807 Apr 25, 2022
bbc167f
modify gpu_compiled_juage
xiaoguoguo626807 Apr 25, 2022
2a6addc
modify gpu judge
xiaoguoguo626807 Apr 25, 2022
79d98f9
code style
xiaoguoguo626807 Apr 26, 2022
4219763
add test example
xiaoguoguo626807 Apr 27, 2022
cf760e6
modify english document
xiaoguoguo626807 Apr 27, 2022
ad8fc4a
modify pram name
xiaoguoguo626807 May 9, 2022
1ee2d9e
modify pram name
xiaoguoguo626807 May 9, 2022
1670abc
modify pram
xiaoguoguo626807 May 10, 2022
e98f12f
fix conflict
xiaoguoguo626807 May 11, 2022
92e7de4
reduce test ex
xiaoguoguo626807 May 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions paddle/fluid/operators/tril_indices_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h"

namespace paddle {
namespace operators {

// Operator class for tril_indices. The op has no tensor inputs, so the
// kernel is chosen from the "dtype" attribute instead of an input's type.
class TrilIndicesOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Builds the kernel key from the requested output data type ("dtype"
  // attribute) and the place the op is executed on.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto kernel_data_type =
        framework::proto::VarType::Type(ctx.Attr<int>("dtype"));
    return framework::OpKernelType(kernel_data_type, ctx.GetPlace());
  }
};

// Declares the output, attributes and user-facing documentation of the
// tril_indices operator. The op takes no tensor inputs; the matrix size,
// diagonal offset and output data type are all passed as attributes.
class TrilIndicesOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // NOTE: adjacent string literals are concatenated verbatim, so every
    // fragment that continues a sentence must end with a space.
    AddOutput("out",
              "Tensor, the output tensor, with the shape (2, x), x bounded by "
              "[0, rows*cols]");
    AddAttr<int>("rows",
                 "int number, the input of tril_indices op "
                 "which describes the number of rows of the matrix")
        .SetDefault(0);
    AddAttr<int>("cols",
                 "int number, the input of tril_indices op "
                 "which describes the number of cols of the matrix")
        .SetDefault(0);
    AddAttr<int>(
        "offset",
        "int number, the input of tril_indices op bounded by "
        "[1-rows, cols-1] "
        "which describes the diagonal line index of the lower triangular "
        "part of the matrix")
        .SetDefault(0);
    AddAttr<int>("dtype", "data type, the input of tril_indices op")
        .SetDefault(framework::proto::VarType::INT64);

    AddComment(R"DOC(
TrilIndices Operator.

The tril_indices operator returns the indices of the lower triangular part of the matrix
whose rows and cols are known. It is a 2-by-x tensor, where the first row contains row coordinates
of all indices and the second row contains column coordinates. Indices are ordered based on
rows and then columns. The lower triangular part of the matrix is defined as the elements on
and below the diagonal.

The argument offset controls which diagonal to consider, default value is 0.
A positive value includes just as many diagonals above the main diagonal,
and similarly a negative value excludes just as many diagonals below the main diagonal.
)DOC");
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
// Declare TrilIndicesInferShapeFunctor, the shape-inference functor of the
// tril_indices op, built from phi::TrilIndicesInferMeta.
DECLARE_INFER_SHAPE_FUNCTOR(tril_indices, TrilIndicesInferShapeFunctor,
PD_INFER_META(phi::TrilIndicesInferMeta));

// Register the tril_indices operator with its op class, its proto maker and
// the shape-inference functor declared above. EmptyGradOpMaker is supplied
// for both the static-graph (OpDesc) and imperative (OpBase) modes;
// presumably this means the op has no gradient -- confirm against
// EmptyGradOpMaker's definition.
REGISTER_OPERATOR(
tril_indices, ops::TrilIndicesOp, ops::TrilIndicesOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
TrilIndicesInferShapeFunctor);
23 changes: 23 additions & 0 deletions paddle/phi/infermeta/nullary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,27 @@ void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
out->set_layout(DataLayout::NCHW);
}

// Infers the meta information of the tril_indices output.
//
// The output has shape {2, tril_size} and data type `dtype`, where tril_size
// is the number of elements of a rows x cols matrix that lie on or below the
// diagonal selected by `offset`.
//
// The lower triangle is counted as a top trapezoid (rows whose length grows
// by one per row) followed by an optional bottom rectangle (rows that span
// all `cols` columns).
//
// rows, cols: matrix size; a non-positive value yields an empty {2, 0} output.
// offset:     diagonal offset (0 = main diagonal, > 0 above, < 0 below).
// dtype:      data type set on the output.
// out:        output meta tensor, must not be null.
void TrilIndicesInferMeta(
    int rows, int cols, int offset, DataType dtype, MetaTensor* out) {
  int64_t tril_size = 0;

  // An empty matrix has no lower-triangular elements. The closed-form below
  // does not hold in that case: e.g. rows = 0, cols = 3, offset = 5 would
  // otherwise produce tril_size = 3.
  if (rows > 0 && cols > 0) {
    // Do all arithmetic in 64 bits so that rows + offset, 1 + offset and the
    // element counts cannot overflow int.
    const int64_t n_rows = rows;
    const int64_t n_cols = cols;
    const int64_t diag = offset;

    // number of elements in the first row of the tril, bounded by [0, cols];
    // for diag <= 0 it is 1 if any row reaches the diagonal, otherwise 0
    const int64_t n_first_row =
        diag > 0 ? std::min<int64_t>(n_cols, 1 + diag)
                 : static_cast<int64_t>(n_rows + diag > 0);
    // number of elements in the last row of the tril, bounded by [0, cols]
    const int64_t n_last_row =
        std::max<int64_t>(0, std::min<int64_t>(n_cols, n_rows + diag));
    // number of non-empty rows, bounded by [0, rows]
    const int64_t n_row_all =
        std::max<int64_t>(0, std::min<int64_t>(n_rows, n_rows + diag));
    const int64_t n_row_trapezoid = n_last_row - n_first_row + 1;

    // number of elements in the top trapezoid (arithmetic series sum)
    tril_size = ((n_first_row + n_last_row) * n_row_trapezoid) >> 1;

    // number of elements in the bottom rectangle, if there is any
    const int64_t diff_row = n_row_all - n_row_trapezoid;
    if (diff_row > 0) {
      tril_size += diff_row * n_cols;
    }
  }

  std::vector<int64_t> tmp = {2, tril_size};
  auto out_dims = phi::make_ddim(tmp);
  out->set_dims(out_dims);
  out->set_dtype(dtype);
}
} // namespace phi
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/nullary.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,6 @@ void UniformRandomInferMeta(const IntArray& shape,
int seed,
MetaTensor* out);

// Infers the output meta (dims and dtype) of the tril_indices op from the
// matrix size (rows x cols), the diagonal offset and the requested dtype.
void TrilIndicesInferMeta(
int rows, int cols, int offset, DataType dtype, MetaTensor* out);
} // namespace phi
50 changes: 50 additions & 0 deletions paddle/phi/kernels/cpu/tril_indices_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/tril_indices_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
// CPU kernel of tril_indices.
//
// Fills `out`, whose dims()[1] gives the number of index pairs, with the
// coordinates of the lower-triangular elements in row-major order: the first
// tril_size values are row indices, the next tril_size values are the
// matching column indices. `rows` and `dtype` are not read here; the number
// of pairs comes from the shape already set on `out`.
template <typename T, typename Context>
void TrilIndicesKernel(const Context& dev_ctx,
                       int rows,
                       int cols,
                       int offset,
                       DataType dtype,
                       DenseTensor* out) {
  T* out_data = dev_ctx.template Alloc<T>(out);
  const int64_t tril_size = out->dims()[1];

  T* row_out = out_data;
  T* col_out = out_data + tril_size;

  // The first row that owns an element is -offset for negative offsets,
  // otherwise row 0.
  T cur_row = std::max<int64_t>(0, -offset);
  T cur_col = 0;
  for (int64_t idx = 0; idx < tril_size; ++idx) {
    row_out[idx] = cur_row;
    col_out[idx] = cur_col;

    // Advance one column; wrap to the start of the next row once the column
    // passes the diagonal or the right edge of the matrix. No bound check on
    // the row is needed because idx < tril_size ends the loop.
    ++cur_col;
    if (cur_col > cur_row + offset || cur_col >= cols) {
      ++cur_row;
      cur_col = 0;
    }
  }
}
} // namespace phi

// Register phi::TrilIndicesKernel as the CPU kernel of tril_indices for the
// int and int64_t output types.
PD_REGISTER_KERNEL(
tril_indices, CPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {}
142 changes: 142 additions & 0 deletions paddle/phi/kernels/gpu/tril_indices_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/tril_indices_kernel.h"

#include <algorithm>
#include <tuple>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// Returns the integer root floor((-b + sign * sqrt(b^2 - 4c)) / 2) of the
// quadratic used to map a linear index to a row index.
//
// b:    linear coefficient of the quadratic.
// cX4:  four times the constant coefficient (4 * c).
// x:    the linear index being resolved, used by the fallback search.
// sign: +1 or -1, selects which root is taken.
//
// The root is first computed in double precision. If squaring the computed
// square root does not reproduce the discriminant exactly, the result is
// refined with a binary search over a small window around it.
//
// All intermediates are 64-bit: with 32-bit arithmetic b * b - cX4 (where
// cX4 is -8x for tril) overflows once x exceeds about 2^28.
template <typename T>
__device__ inline int resolve_root_int(int64_t b,
                                       int64_t cX4,
                                       int64_t x,
                                       int32_t sign) {
  const int64_t bXb_cX4 = b * b - cX4;
  const double sr = ::sqrt(static_cast<double>(bXb_cX4));
  int64_t res = ::__double2ll_rd((-b + sign * sr) / 2);

  // Compare as integers; comparing as doubles would hide the rounding error.
  if (bXb_cX4 != static_cast<int64_t>(sr * sr)) {
    const int64_t llsr = ::__double2ll_rd(sr);
    // With d = |bXb_cX4 - llsr^2| the exact square root lies within
    // [llsr - sqrt(d), llsr + sqrt(d)], so the row lies within
    // [res - sqrt(d), res + sqrt(d) + 1).
    const int64_t diff = ::__double2ll_ru(
        ::sqrt(::fabs(static_cast<double>(bXb_cX4 - llsr * llsr))));
    // l never exceeds the target row, r is always larger than it
    int64_t l = res > diff ? res - diff : 0;
    int64_t r = res + diff + 1;

    x <<= 1;  // the loop always compares against 2x
    while (l + 1 < r) {
      const int64_t m = (l + r) >> 1;
      // for tril: b = 2f - 1 and sign = 1, i.e. (2f + m - 1) * m vs 2x
      if (sign * (b + m) * m > x) {
        r = m;
      } else {
        l = m;
      }
    }
    res = l;
  }
  return static_cast<int>(res);
}

// Converts linear index x inside the top trapezoid of a lower triangle into
// (row, col) coordinates relative to the trapezoid's first row.
//
// f:   number of elements in the first row of the trapezoid.
// x:   zero-based linear index inside the trapezoid.
// row: receives the row index within the trapezoid.
// col: receives the column index.
//
// Row r starts at linear index (2f + r - 1) * r / 2, so the row of x is the
// largest r with r^2 + (2f - 1) r - 2x <= 0, i.e. the floor of the positive
// root of that quadratic (b = 2f - 1, c = -2x).
template <typename T>
__device__ inline void get_coordinate_in_tril_trapezoid(int f,
                                                        int x,
                                                        T* row,
                                                        T* col) {
  // All statements use 2f, so compute it once; widen first to avoid int
  // overflow in the products below.
  const int64_t f2 = static_cast<int64_t>(f) << 1;
  const int64_t b = f2 - 1;
  const int64_t cX4 = -(static_cast<int64_t>(x) << 3);  // 4 * c = 4 * (-2x)
  const int64_t r = resolve_root_int<T>(b, cX4, x, 1);
  *row = static_cast<T>(r);
  // Subtract the number of elements that precede row r.
  *col = static_cast<T>(x - (((f2 + r - 1) * r) >> 1));
}

// CUDA kernel that writes the coordinates of all lower-triangular elements.
//
// out_data:       output buffer of 2 * tril_size values; row indices are
//                 written to [0, tril_size), column indices to
//                 [tril_size, 2 * tril_size).
// row_offset:     index of the first matrix row that owns an element; added
//                 to every computed row.
// m_first_row:    number of elements in the first row of the trapezoid.
// col:            number of columns of the matrix.
// trapezoid_size: number of elements in the top trapezoid; linear indices
//                 at or beyond it fall in the bottom rectangle.
// tril_size:      total number of elements to write.
//
// Each thread walks the linear indices with a stride of the whole grid, so
// every element is written even when the launch provides fewer threads than
// tril_size. Indices are 64-bit so that linear_index + tril_size cannot
// overflow.
template <typename T>
__global__ void tril_indices_kernel(T* out_data,
                                    int row_offset,
                                    int m_first_row,
                                    int col,
                                    int trapezoid_size,
                                    int tril_size) {
  const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t first_index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (int64_t linear_index = first_index; linear_index < tril_size;
       linear_index += grid_stride) {
    T r, c;
    if (linear_index < trapezoid_size) {
      // the coordinate is within the top trapezoid; the index is below
      // trapezoid_size here, so it fits in int
      get_coordinate_in_tril_trapezoid<T>(
          m_first_row, static_cast<int>(linear_index), &r, &c);
    } else {
      // the coordinate falls in the bottom rectangle
      const int64_t surplus = linear_index - trapezoid_size;
      // add the height of trapezoid: m_last_row (col) - m_first_row + 1
      r = static_cast<T>(surplus / col + col - m_first_row + 1);
      c = static_cast<T>(surplus % col);
    }
    r += row_offset;

    out_data[linear_index] = r;
    out_data[linear_index + tril_size] = c;
  }
}

// GPU kernel entry of tril_indices.
//
// Allocates `out`, reads the number of index pairs from out->dims()[1],
// splits the lower triangle into a top trapezoid and a bottom rectangle, and
// launches tril_indices_kernel on dev_ctx's stream. Nothing is launched when
// there are no elements. `dtype` is not read here.
template <typename T, typename Context>
void TrilIndicesKernel(const Context& dev_ctx,
                       int rows,
                       int cols,
                       int offset,
                       DataType dtype,
                       DenseTensor* out) {
  T* out_data = dev_ctx.template Alloc<T>(out);
  const int tril_size = out->dims()[1];
  if (tril_size <= 0) {
    return;
  }

  // number of elements in the first non-empty row of the lower triangle
  int first_row_len = 0;
  if (offset > 0) {
    first_row_len = std::min<int>(cols, 1 + offset);
  } else if (rows + offset > 0) {
    first_row_len = 1;
  }

  // index of the first matrix row that owns an element
  const int trapezoid_first_row = std::max<int>(0, -offset);
  // index of the first matrix row that spans all columns (bottom rectangle)
  const int rectangle_first_row =
      trapezoid_first_row + cols - first_row_len + 1;
  // number of elements in the bottom rectangle, 0 if it does not exist
  const int rectangle_size = rectangle_first_row < rows
                                 ? (rows - rectangle_first_row) * cols
                                 : 0;

  // Launch configuration: at most one thread per element, with the grid
  // capped at the number of blocks needed to fill the device's physical
  // threads.
  // NOTE(review): because of this cap the launch can provide fewer threads
  // than tril_size; confirm that tril_indices_kernel covers the remainder.
  const int threads_per_block = std::min(
      tril_size, static_cast<int>(dev_ctx.GetMaxThreadsPerBlock()));
  const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  const int max_blocks =
      std::max((max_threads - 1) / threads_per_block + 1, static_cast<int>(1));
  const int blocks_per_grid = std::min(
      max_blocks, (tril_size + threads_per_block - 1) / threads_per_block);

  tril_indices_kernel<T><<<blocks_per_grid,
                           threads_per_block,
                           0,
                           dev_ctx.stream()>>>(out_data,
                                               trapezoid_first_row,
                                               first_row_len,
                                               cols,
                                               tril_size - rectangle_size,
                                               tril_size);
}

} // namespace phi

// Register phi::TrilIndicesKernel as the GPU kernel of tril_indices for the
// int and int64_t output types.
PD_REGISTER_KERNEL(
tril_indices, GPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {}
29 changes: 29 additions & 0 deletions paddle/phi/kernels/tril_indices_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Declaration of the tril_indices kernel; the CPU and GPU backends each
// provide a definition.
//
// T:       element type of the output indices.
// dev_ctx: device context used to allocate the output.
// rows:    number of rows of the matrix.
// cols:    number of columns of the matrix.
// offset:  diagonal offset selecting the lower-triangular region.
// dtype:   requested output data type.
// out:     output tensor receiving the row indices followed by the column
//          indices.
template <typename T, typename Context>
void TrilIndicesKernel(const Context& dev_ctx,
int rows,
int cols,
int offset,
DataType dtype,
DenseTensor* out);

} // namespace phi
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from .tensor.creation import assign # noqa: F401
from .tensor.creation import complex # noqa: F401
from .tensor.creation import clone # noqa: F401
from .tensor.creation import tril_indices #noqa: F401
from .tensor.linalg import matmul # noqa: F401
from .tensor.linalg import dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
Expand Down Expand Up @@ -637,4 +638,5 @@
'take_along_axis',
'put_along_axis',
'heaviside',
'tril_indices',
]
Loading