[Hackathon No.5] tril_indices OP #41639
Merged
Changes from 44 commits
49 commits
f505bcb add tril_indices cpu kernal (xiaoguoguo626807)
6b093dd modify tril_indice cpu op (xiaoguoguo626807)
ba550c2 Merge pull request #1 from xiaoguoguo626807/xgg-branch (xiaoguoguo626807)
c5c9905 modify bug (xiaoguoguo626807)
4b786ad modify bug (xiaoguoguo626807)
146eb8c add tril_indices python api (xiaoguoguo626807)
1925605 add tril_indices python api (xiaoguoguo626807)
7319cb8 Merge pull request #2 from xiaoguoguo626807/xgg-branch (xiaoguoguo626807)
c98e6a8 resolve conflict (xiaoguoguo626807)
1a3215f Merge pull request #3 from xiaoguoguo626807/xgg-branch (xiaoguoguo626807)
e433a69 add tril_indices test (xiaoguoguo626807)
1ff4931 Merge pull request #4 from xiaoguoguo626807/xgg-branch (xiaoguoguo626807)
b8bf344 modify details (xiaoguoguo626807)
f330fb7 add tril_indices.cu (xiaoguoguo626807)
e393955 Merge pull request #5 from xiaoguoguo626807/xgg-branch (xiaoguoguo626807)
284206c pythonapi pass (xiaoguoguo626807)
68c35c8 Merge branch 'PaddlePaddle:develop' into develop (xiaoguoguo626807)
cf001b1 save tril_indices (xiaoguoguo626807)
b2b3f0e solve conflict (xiaoguoguo626807)
563ed16 CPU tril_indices pass (xiaoguoguo626807)
4c6e429 delete vlog (xiaoguoguo626807)
d071526 Merge pull request #6 from xiaoguoguo626807/pythonapi_pass (xiaoguoguo626807)
2b9a5c1 modify test_tril_indices_op.py (xiaoguoguo626807)
65b67bc solve conflict (xiaoguoguo626807)
be49653 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807)
8810da3 delete tril_indices_kernel.cc.swp (xiaoguoguo626807)
b59c4de delete tril_indice.cu (xiaoguoguo626807)
fbca2b1 modify code style (xiaoguoguo626807)
e7df5e5 add newline in creation.py (xiaoguoguo626807)
3acde8e modify creation.py linux newline (xiaoguoguo626807)
9ae9091 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (xiaoguoguo626807)
439aa23 delete annotation (xiaoguoguo626807)
79a5cfd check code style (xiaoguoguo626807)
110d2b1 check .py style add final_state?? (xiaoguoguo626807)
b701aa8 modify code style (xiaoguoguo626807)
726aedc Merge branch 'PaddlePaddle:develop' into develop (xiaoguoguo626807)
410ffef Merge branch 'PaddlePaddle:develop' into develop (xiaoguoguo626807)
bc8af08 add gpu_tril_indices (xiaoguoguo626807)
bef2a98 Merge branch 'develop' of https://github.com/xiaoguoguo626807/Paddle … (xiaoguoguo626807)
bbc167f modify gpu_compiled_juage (xiaoguoguo626807)
2a6addc modify gpu judge (xiaoguoguo626807)
79d98f9 code style (xiaoguoguo626807)
4219763 add test example (xiaoguoguo626807)
cf760e6 modify english document (xiaoguoguo626807)
ad8fc4a modify pram name (xiaoguoguo626807)
1ee2d9e modify pram name (xiaoguoguo626807)
1670abc modify pram (xiaoguoguo626807)
e98f12f fix conflict (xiaoguoguo626807)
92e7de4 reduce test ex (xiaoguoguo626807)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <memory> | ||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/nullary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class TrilIndicesOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), | ||
ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class TrilIndicesOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddOutput("out", | ||
"Tensor, the output tensor, with the shape (2, x), x bounded by " | ||
"[0, rows * cols]"); | ||
AddAttr<int>("rows", | ||
"int number, the input of tril_indices op " | ||
"which describes the number of rows of the matrix") | ||
.SetDefault(0); | ||
AddAttr<int>("cols", | ||
"int number, the input of tril_indices op " | ||
"which describes the number of columns of the matrix") | ||
.SetDefault(0); | ||
AddAttr<int>( | ||
"offset", | ||
"int number, the input of tril_indices op bounded by [1 - rows, cols - 1], " | ||
"which describes the diagonal line index of the lower triangular part of " | ||
"the matrix") | ||
.SetDefault(0); | ||
AddAttr<int>("dtype", "data type, the input of tril_indices op") | ||
.SetDefault(framework::proto::VarType::INT64); | ||
|
||
AddComment(R"DOC( | ||
TrilIndices Operator. | ||
|
||
The tril_indices operator returns the indices of the lower triangular part of a matrix | ||
whose numbers of rows and columns are known. The result is a 2-by-x tensor, where the first row contains the row coordinates | ||
of all indices and the second row contains the column coordinates. Indices are ordered by | ||
row and then by column. The lower triangular part of the matrix is defined as the elements on | ||
and below the diagonal. | ||
|
||
The argument offset controls which diagonal to consider; the default value is 0. | ||
A positive value includes just as many diagonals above the main diagonal, | ||
and similarly a negative value excludes just as many diagonals below the main diagonal. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
DECLARE_INFER_SHAPE_FUNCTOR(tril_indices, TrilIndicesInferShapeFunctor, | ||
PD_INFER_META(phi::TrilIndicesInferMeta)); | ||
|
||
REGISTER_OPERATOR( | ||
tril_indices, ops::TrilIndicesOp, ops::TrilIndicesOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, | ||
TrilIndicesInferShapeFunctor); |
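The operator comment above describes the result as the coordinates on and below the chosen diagonal, listed row by row. As a hedged illustration (not the PR's kernel), the same definition can be written as a brute-force reference in plain C++, reading "on and below the diagonal shifted by offset" as `c <= r + offset`. The name `TrilIndicesReference` and the use of `std::vector` are assumptions made for this sketch only.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical reference helper (not part of the PR): collects the row and
// column coordinates of every element (r, c) with c <= r + offset, ordered
// by row and then by column.
std::pair<std::vector<int64_t>, std::vector<int64_t>> TrilIndicesReference(
    int rows, int cols, int offset) {
  assert(rows >= 0 && cols >= 0);
  std::vector<int64_t> row_idx;
  std::vector<int64_t> col_idx;
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      if (c <= r + offset) {  // on or below the shifted diagonal
        row_idx.push_back(r);
        col_idx.push_back(c);
      }
    }
  }
  return {row_idx, col_idx};
}
```

For `rows = cols = 3` and `offset = 0` this returns the rows `{0, 1, 1, 2, 2, 2}` and the columns `{0, 0, 1, 0, 1, 2}`; with `offset = -1` the main diagonal is excluded and only `(1, 0)`, `(2, 0)`, `(2, 1)` remain.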
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -115,4 +115,27 @@ void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape, | |
out->set_layout(DataLayout::NCHW); | ||
} | ||
|
||
void TrilIndicesInferMeta( | ||
int rows, int cols, int offset, DataType dtype, MetaTensor* out) { | ||
// number of elements in the first row of the tril, bounded by [0, cols]; | ||
// for offset <= 0 the comparison rows + offset > 0 yields either 0 or 1 | ||
auto m_first_row = | ||
offset > 0 ? std::min<int64_t>(cols, 1 + offset) : rows + offset > 0; | ||
// number of elements in the last row of the tril, bounded by [0, cols] | ||
auto m_last_row = | ||
||
std::max<int64_t>(0, std::min<int64_t>(cols, rows + offset)); | ||
// number of rows, bounded by [0, rows] | ||
auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(rows, rows + offset)); | ||
auto n_row_trapezoid = (m_last_row - m_first_row + 1); | ||
// calculate # of elements in the top trapezoid | ||
auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1; | ||
// calculate # of elements in the bottom rectangle if there is any | ||
auto diff_row = n_row_all - n_row_trapezoid; | ||
if (diff_row > 0) { | ||
tril_size += diff_row * cols; | ||
} | ||
std::vector<int64_t> tmp = {2, tril_size}; | ||
auto out_dims = phi::make_ddim(tmp); | ||
out->set_dims(out_dims); | ||
out->set_dtype(dtype); | ||
} | ||
} // namespace phi |
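The size computation in `TrilIndicesInferMeta` splits the lower triangle into a top trapezoid (an arithmetic series from `m_first_row` to `m_last_row`) and a bottom rectangle of full rows. A minimal standalone sketch of that arithmetic follows, with a brute-force counter to compare against. The names `TrilSize` and `TrilSizeBruteForce` are hypothetical, and the early return for empty matrices is an assumption of this sketch: the diff contains no such guard, and whether callers already reject `rows == 0` is not visible here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical standalone version of the size arithmetic (illustration only).
int64_t TrilSize(int64_t rows, int64_t cols, int64_t offset) {
  assert(rows >= 0 && cols >= 0);
  // Assumption of this sketch: handle empty matrices up front.
  if (rows == 0 || cols == 0) return 0;
  // Elements in the first non-empty row: 0 or 1 when offset <= 0.
  const int64_t m_first_row =
      offset > 0 ? std::min<int64_t>(cols, 1 + offset)
                 : static_cast<int64_t>(rows + offset > 0);
  // Elements in the last row, bounded by [0, cols].
  const int64_t m_last_row =
      std::max<int64_t>(0, std::min<int64_t>(cols, rows + offset));
  // Rows that contain at least one element, bounded by [0, rows].
  const int64_t n_row_all =
      std::max<int64_t>(0, std::min<int64_t>(rows, rows + offset));
  const int64_t n_row_trapezoid = m_last_row - m_first_row + 1;
  // Arithmetic series over the trapezoid rows.
  int64_t size = (m_first_row + m_last_row) * n_row_trapezoid / 2;
  // Full-width rows below the trapezoid, if any.
  const int64_t diff_row = n_row_all - n_row_trapezoid;
  if (diff_row > 0) size += diff_row * cols;
  return size;
}

// Direct count of all (r, c) with c <= r + offset, used as a cross-check.
int64_t TrilSizeBruteForce(int64_t rows, int64_t cols, int64_t offset) {
  int64_t n = 0;
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 0; c < cols; ++c)
      if (c <= r + offset) ++n;
  return n;
}
```

For example, `TrilSize(5, 2, 0)` is 9: the trapezoid contributes 1 + 2 = 3 elements and the three remaining full rows contribute 3 * 2 = 6.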
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/tril_indices_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
template <typename T, typename Context> | ||
void TrilIndicesKernel(const Context& dev_ctx, | ||
int rows, | ||
int cols, | ||
int offset, | ||
DataType dtype, | ||
DenseTensor* out) { | ||
T* out_data = dev_ctx.template Alloc<T>(out); | ||
auto out_dims = out->dims(); | ||
int64_t tril_size = out_dims[1]; | ||
int64_t i = 0; | ||
T r = std::max<int64_t>(0, -offset), c = 0; | ||
while (i < tril_size) { | ||
out_data[i] = r; | ||
out_data[tril_size + i++] = c; | ||
|
||
// move to the next column and check if (r, c) is still in bound | ||
c += 1; | ||
if (c > r + offset || c >= cols) { | ||
r += 1; | ||
c = 0; | ||
// NOTE: not necessary to check if r is less than rows here, because i | ||
// and tril_size provide the guarantee | ||
} | ||
} | ||
} | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL( | ||
tril_indices, CPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {} |
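The CPU kernel walks the triangle with a single cursor `(r, c)` and writes into a flat buffer laid out as `[2, tril_size]`: row coordinates first, column coordinates starting at offset `tril_size`. The sketch below reproduces that walk outside the framework. `TrilIndicesFlat` is a hypothetical name, and the element count is computed with a simple per-row loop here instead of being read from the inferred output shape as the kernel does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the PR): returns a flat [2, tril_size]
// buffer, row coordinates followed by column coordinates.
std::vector<int64_t> TrilIndicesFlat(int rows, int cols, int offset) {
  assert(rows >= 0 && cols >= 0);
  // Row r holds min(cols, r + offset + 1) elements, clamped at 0.
  int64_t tril_size = 0;
  for (int64_t r = 0; r < rows; ++r) {
    tril_size += std::max<int64_t>(0, std::min<int64_t>(cols, r + offset + 1));
  }
  std::vector<int64_t> out(2 * tril_size);
  int64_t i = 0;
  // Start at the first row that has elements.
  int64_t r = std::max<int64_t>(0, -offset);
  int64_t c = 0;
  while (i < tril_size) {
    out[i] = r;
    out[tril_size + i] = c;
    ++i;
    // Move to the next column; wrap to the next row when (r, c) leaves
    // the triangle or the matrix.
    c += 1;
    if (c > r + offset || c >= cols) {
      r += 1;
      c = 0;
    }
  }
  return out;
}
```

With `rows = 3`, `cols = 2`, `offset = -1` the buffer is `{1, 2, 2, 0, 0, 1}`, that is the coordinates `(1, 0)`, `(2, 0)`, `(2, 1)`.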
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/tril_indices_kernel.h" | ||
|
||
#include <algorithm> | ||
#include <tuple> | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T> | ||
__device__ inline int resolve_root_int(int b, int cX4, int x, int32_t sign) { | ||
int bXb_cX4 = b * b - cX4; | ||
double sr = ::sqrt(static_cast<double>(bXb_cX4)); | ||
T res = ::__double2ll_rd((-b + sign * sr) / 2); | ||
if (bXb_cX4 != static_cast<int>(sr * sr)) { | ||
int llsr = ::__double2ll_rd(sr); | ||
int diff = ::__double2ll_ru( | ||
::sqrt(::fabs(static_cast<double>(bXb_cX4 - llsr * llsr)))); | ||
auto l = res > diff ? res - diff : 0; | ||
auto r = res + diff + 1; | ||
x <<= 1; | ||
while (l + 1 < r) { | ||
auto m = (l + r) >> 1; | ||
if (sign * (b + m) * m > x) { | ||
r = m; | ||
} else { | ||
l = m; | ||
} | ||
} | ||
res = l; | ||
} | ||
return res; | ||
} | ||
|
||
template <typename T> | ||
__device__ inline void get_coordinate_in_tril_trapezoid(int f, | ||
int x, | ||
T* row, | ||
T* col) { | ||
f <<= 1; // all statements use 2f, so only calculate it once here. | ||
auto b = f - 1; | ||
auto cX4 = -(x << 3); // 4 * c = 4 * (-2x) = -8x; | ||
*row = resolve_root_int<T>(b, cX4, x, 1); | ||
*col = x - ((f + *row - 1) * *row >> 1); | ||
} | ||
|
||
template <typename T> | ||
__global__ void tril_indices_kernel(T* out_data, | ||
int row_offset, | ||
int m_first_row, | ||
int col, | ||
int trapezoid_size, | ||
int tril_size) { | ||
int linear_index = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
if (linear_index < tril_size) { | ||
T r, c; | ||
if (linear_index < trapezoid_size) { | ||
// the coordinate is within the top trapezoid | ||
get_coordinate_in_tril_trapezoid<T>(m_first_row, linear_index, &r, &c); | ||
} else { | ||
// the coordinate falls in the bottom rectangle | ||
auto surplus = linear_index - trapezoid_size; | ||
// add the height of trapezoid: m_last_row (col) - m_first_row + 1 | ||
r = surplus / col + col - m_first_row + 1; | ||
c = surplus % col; | ||
} | ||
r += row_offset; | ||
|
||
out_data[linear_index] = r; | ||
out_data[linear_index + tril_size] = c; | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void TrilIndicesKernel(const Context& dev_ctx, | ||
int rows, | ||
int cols, | ||
int offset, | ||
DataType dtype, | ||
DenseTensor* out) { | ||
T* out_data = dev_ctx.template Alloc<T>(out); | ||
auto out_dims = out->dims(); | ||
int tril_size = out_dims[1]; | ||
|
||
if (tril_size > 0) { | ||
auto m_first_row = offset > 0 | ||
? std::min<int>(cols, 1 + offset) | ||
: rows + offset > 0; // number of elements in the first row | ||
auto trapezoid_row_offset = | ||
std::max<int>(0, -offset); // index of the first row that has elements | ||
auto rectangle_row_offset = trapezoid_row_offset + cols - m_first_row + | ||
1; // index of the first row of the bottom rectangle | ||
int rectangle_size = 0; | ||
if (rectangle_row_offset < rows) { | ||
rectangle_size = (rows - rectangle_row_offset) * cols; | ||
} // number of elements in the rectangular part of the lower triangle | ||
|
||
auto GetBlockGridSize = [&dev_ctx](int size) { | ||
const int block_size = | ||
std::min(size, static_cast<int>(dev_ctx.GetMaxThreadsPerBlock())); | ||
int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); | ||
const int max_blocks = | ||
std::max(((max_threads - 1) / block_size + 1), static_cast<int>(1)); | ||
const int grid_size = | ||
std::min(max_blocks, (size + block_size - 1) / block_size); | ||
return std::tuple<int, int>{grid_size, block_size}; | ||
}; | ||
|
||
std::tuple<int, int> block_grid_size = GetBlockGridSize(tril_size); | ||
|
||
tril_indices_kernel<T><<<std::get<0>(block_grid_size), | ||
std::get<1>(block_grid_size), | ||
0, | ||
dev_ctx.stream()>>>(out_data, | ||
trapezoid_row_offset, | ||
m_first_row, | ||
cols, | ||
tril_size - rectangle_size, | ||
tril_size); | ||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL( | ||
tril_indices, GPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {} |
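In the GPU kernel each thread converts its linear index `x` into a coordinate without iterating. Inside the trapezoid, row `r` holds `f + r` elements (with `f = m_first_row`), so `r * f + r * (r - 1) / 2` elements precede row `r`; the wanted row is the largest `r` with `r^2 + (2f - 1) r <= 2x`, which is the floor of the positive root of that quadratic. A host-side sketch of the same mapping is given below. `TrapezoidCoordinate` is a hypothetical name, and the two correction loops are a simplification chosen for this sketch in place of the bounded binary search in `resolve_root_int`.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>

// Hypothetical host-side helper (not part of the PR): maps linear index x
// inside the top trapezoid to (row, col), where row 0 holds f elements and
// each following row holds one more.
std::pair<int64_t, int64_t> TrapezoidCoordinate(int64_t f, int64_t x) {
  assert(f >= 1 && x >= 0);
  const int64_t b = 2 * f - 1;
  // Positive root of r^2 + b * r - 2x = 0, rounded down.
  const double disc = static_cast<double>(b * b + 8 * x);
  int64_t r = static_cast<int64_t>(
      std::floor((-static_cast<double>(b) + std::sqrt(disc)) / 2.0));
  // Correct for floating-point rounding in either direction.
  while (r > 0 && r * r + b * r > 2 * x) --r;
  while ((r + 1) * (r + 1) + b * (r + 1) <= 2 * x) ++r;
  // Subtract the number of elements that precede row r.
  const int64_t c = x - (2 * f + r - 1) * r / 2;
  return {r, c};
}
```

For `f = 2` the rows hold 2, 3, 4, ... elements, so `x = 2` maps to `(1, 0)` and `x = 4` maps to `(1, 2)`. The kernel then adds `row_offset` to the row to account for leading empty rows.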
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/phi/core/dense_tensor.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void TrilIndicesKernel(const Context& dev_ctx, | ||
int rows, | ||
int cols, | ||
int offset, | ||
DataType dtype, | ||
DenseTensor* out); | ||
|
||
} // namespace phi |
Review comment on `m_first_row` in `TrilIndicesInferMeta`:
Would `n_first_row` be a better name? It would match the form of the variable names `n_row_all` and `n_row_trapezoid` below.
Reply:
done