-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Hackathon No.5] tril_indices OP (#41639)
* add tril_indices cpu kernal * modify tril_indice cpu op * modify bug * modify bug * add tril_indices python api * add tril_indices python api * resolve conflict * add tril_indices test * modify details * add tril_indices.cu * pythonapi pass * save tril_indices * CPU tril_indices pass * delete vlog * modify test_tril_indices_op.py * delete tril_indices_kernel.cc.swp * delete tril_indice.cu * modify code style * add newline in creation.py * modify creation.py linux newline * delete annotation * check code style * check .py style add final_state?? * modify code style * add gpu_tril_indices * modify gpu_compiled_juage * modify gpu judge * code style * add test example * modify english document modify english document modify english document modify document modify document * modify pram name * modify pram name * modify pram * reduce test ex
- Loading branch information
1 parent
1f76eab
commit 75db5b8
Showing
10 changed files
with
561 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <memory> | ||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/nullary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class TrilIndicesOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), | ||
ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class TrilIndicesOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddOutput("out", | ||
"Tensor, the output tensor, with the shape (2,x),x bounded by " | ||
"[0,rows*cols])"); | ||
AddAttr<int>("rows", | ||
"int number, the input of tril_indices op" | ||
"which describes the number of row of the matrix") | ||
.SetDefault(0); | ||
AddAttr<int>("cols", | ||
"int number, the input of tril_indices op" | ||
"which describes the number of col of the matrix") | ||
.SetDefault(0); | ||
AddAttr<int>( | ||
"offset", | ||
"int number, the input of tril_indices op bounded by [1-rows,cols-1" | ||
"which describes the dignalline index of the lower triangular part of " | ||
"the matrix") | ||
.SetDefault(0); | ||
AddAttr<int>("dtype", "data type ,the input of tril_indices op") | ||
.SetDefault(framework::proto::VarType::INT64); | ||
|
||
AddComment(R"DOC( | ||
TrilIndices Operator. | ||
The tril_indices operator returns the indices of the lower triangular part of the matrix | ||
whose rows and cols is knowed. It is a 2-by-x tensor,where the first row contains row coordinates | ||
of all indices and the second row contains column coordinates. Indices are ordered based on | ||
rows and then columns. The lower triangular part of the matrix is defined as the elements on | ||
and below the diagonal. | ||
The argument offset controls which diagonal to consider, default value is 0. | ||
A positive valueincludes just as many diagonals above the main diagonal, | ||
and similarly a negative value excludes just as many diagonals below the main diagonal | ||
)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
DECLARE_INFER_SHAPE_FUNCTOR(tril_indices, TrilIndicesInferShapeFunctor, | ||
PD_INFER_META(phi::TrilIndicesInferMeta)); | ||
|
||
REGISTER_OPERATOR( | ||
tril_indices, ops::TrilIndicesOp, ops::TrilIndicesOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, | ||
TrilIndicesInferShapeFunctor); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/tril_indices_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
template <typename T, typename Context> | ||
void TrilIndicesKernel(const Context& dev_ctx, | ||
int rows, | ||
int cols, | ||
int offset, | ||
DataType dtype, | ||
DenseTensor* out) { | ||
T* out_data = dev_ctx.template Alloc<T>(out); | ||
auto out_dims = out->dims(); | ||
int64_t tril_size = out_dims[1]; | ||
int64_t i = 0; | ||
T r = std::max<int64_t>(0, -offset), c = 0; | ||
while (i < tril_size) { | ||
out_data[i] = r; | ||
out_data[tril_size + i++] = c; | ||
|
||
// move to the next column and check if (r, c) is still in bound | ||
c += 1; | ||
if (c > r + offset || c >= cols) { | ||
r += 1; | ||
c = 0; | ||
// NOTE: not necessary to check if r is less than row here, because i | ||
// and tril_size provide the guarantee | ||
} | ||
} | ||
} | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL( | ||
tril_indices, CPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/tril_indices_kernel.h" | ||
|
||
#include <algorithm> | ||
#include <tuple> | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T> | ||
__device__ inline int resolve_root_int(int b, int cX4, int x, int32_t sign) { | ||
int bXb_cX4 = b * b - cX4; | ||
double sr = ::sqrt(static_cast<double>(bXb_cX4)); | ||
T res = ::__double2ll_rd((-b + sign * sr) / 2); | ||
if (bXb_cX4 != static_cast<int>(sr * sr)) { | ||
int llsr = ::__double2ll_rd(sr); | ||
int diff = ::__double2ll_ru( | ||
::sqrt(::fabs(static_cast<double>(bXb_cX4 - llsr * llsr)))); | ||
auto l = res > diff ? res - diff : 0; | ||
auto r = res + diff + 1; | ||
x <<= 1; | ||
while (l + 1 < r) { | ||
auto m = (l + r) >> 1; | ||
if (sign * (b + m) * m > x) { | ||
r = m; | ||
} else { | ||
l = m; | ||
} | ||
} | ||
res = l; | ||
} | ||
return res; | ||
} | ||
|
||
template <typename T> | ||
__device__ inline void get_coordinate_in_tril_trapezoid(int f, | ||
int x, | ||
T* row, | ||
T* col) { | ||
f <<= 1; // all statements use 2f, so only calculate it once here. | ||
auto b = f - 1; | ||
auto cX4 = -(x << 3); // 4 * c = 4 * (-2x) = -8x; | ||
*row = resolve_root_int<T>(b, cX4, x, 1); | ||
*col = x - ((f + *row - 1) * *row >> 1); | ||
} | ||
|
||
template <typename T> | ||
__global__ void tril_indices_kernel(T* out_data, | ||
int row_offset, | ||
int m_first_row, | ||
int col, | ||
int trapezoid_size, | ||
int tril_size) { | ||
int linear_index = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
if (linear_index < tril_size) { | ||
T r, c; | ||
if (linear_index < trapezoid_size) { | ||
// the coordinate is within the top trapezoid | ||
get_coordinate_in_tril_trapezoid<T>(m_first_row, linear_index, &r, &c); | ||
} else { | ||
// the coordinate falls in the bottom rectangle | ||
auto surplus = linear_index - trapezoid_size; | ||
// add the height of trapezoid: m_last_row (col) - m_first_row + 1 | ||
r = surplus / col + col - m_first_row + 1; | ||
c = surplus % col; | ||
} | ||
r += row_offset; | ||
|
||
out_data[linear_index] = r; | ||
out_data[linear_index + tril_size] = c; | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void TrilIndicesKernel(const Context& dev_ctx, | ||
int rows, | ||
int cols, | ||
int offset, | ||
DataType dtype, | ||
DenseTensor* out) { | ||
T* out_data = dev_ctx.template Alloc<T>(out); | ||
auto out_dims = out->dims(); | ||
int tril_size = out_dims[1]; | ||
|
||
if (tril_size > 0) { | ||
auto m_first_row = offset > 0 | ||
? std::min<int>(cols, 1 + offset) | ||
: rows + offset > 0; // the number of first row | ||
auto trapezoid_row_offset = | ||
std::max<int>(0, -offset); // index of the first row who has number | ||
auto rectangle_row_offset = trapezoid_row_offset + cols - m_first_row + | ||
1; // the length of the right-up rest matrix | ||
int rectangle_size = 0; | ||
if (rectangle_row_offset < rows) { | ||
rectangle_size = (rows - rectangle_row_offset) * cols; | ||
} // the rectangle part of lowertriangle matrix | ||
|
||
auto GetBlockGridSize = [&dev_ctx](int size) { | ||
const int block_size = | ||
std::min(size, static_cast<int>(dev_ctx.GetMaxThreadsPerBlock())); | ||
int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); | ||
const int max_blocks = | ||
std::max(((max_threads - 1) / block_size + 1), static_cast<int>(1)); | ||
const int grid_size = | ||
std::min(max_blocks, (size + block_size - 1) / block_size); | ||
return std::tuple<int, int>{grid_size, block_size}; | ||
}; | ||
|
||
std::tuple<int, int> block_grid_size = GetBlockGridSize(tril_size); | ||
|
||
tril_indices_kernel<T><<<std::get<0>(block_grid_size), | ||
std::get<1>(block_grid_size), | ||
0, | ||
dev_ctx.stream()>>>(out_data, | ||
trapezoid_row_offset, | ||
m_first_row, | ||
cols, | ||
tril_size - rectangle_size, | ||
tril_size); | ||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL( | ||
tril_indices, GPU, ALL_LAYOUT, phi::TrilIndicesKernel, int, int64_t) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/phi/core/dense_tensor.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void TrilIndicesKernel(const Context& dev_ctx, | ||
int rows, | ||
int cols, | ||
int offset, | ||
DataType dtype, | ||
DenseTensor* out); | ||
|
||
} // namespace phi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.