[Hackathon No.5] Add bucketize to Paddle #177

Closed
wants to merge 3 commits
Changes from 1 commit
update: modified part 3 and part 6
PommesPeter committed Jul 11, 2022
commit 8413595cb21b7404c73cd333d17a555b31df8397
200 changes: 7 additions & 193 deletions rfcs/APIs/20220709_api_design_for_bucketize.md
@@ -46,140 +46,16 @@ PyTorch provides a `torch.bucketize` API; its full signature is `torch.bucketize(input

Implementation code:
Contributor:

Only show the code of the key logic, to avoid taking up too much space.

Contributor (Author):

OK, I'll revise it.

Contributor (Author):

Done, please review.

```cpp
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <ATen/Parallel.h>
#include <ATen/native/BucketizationUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

/* Implement a numpy-like searchsorted and a TF-like bucketize function running on CPU
 *
 * - torch.searchsorted(sorted_sequence, values, right=False, side='left', out_int32=False, sorter=None)
 *   sorted_sequence - N*D or 1D (apply to all values) tensor containing sorted sequences in the last dimension
 *   values          - N*D tensor or a Scalar (when sorted_sequence is 1D) containing the search values
 *   right           - corresponding to lower bound if False and upper bound if True
 *   side            - (preferred to right) corresponding to lower bound if 'left' and upper bound if 'right'
 *   out_int32       - the output tensor is int64_t type if False and int (32-bit) type if True
 *   sorter          - if provided, sorted_sequence may be unsorted and the sorted order is given by this tensor
 *
 * - torch.bucketize(values, boundaries, right=False, out_int32=False)
 *   values     - N*D tensor or a Scalar containing the search values
 *   boundaries - 1D tensor containing a sorted sequence
 *   right      - corresponding to lower bound if False and upper bound if True
 *   out_int32  - the output tensor is int64_t type if False and int (32-bit) type if True
 *
 * - Restrictions are defined in searchsorted_pre_check()
 */

namespace at {
namespace native {

namespace {

// minimal size for searchsorted_cpu_contiguous to run parallel (multithread)
constexpr int64_t SEARCHSORTED_GRAIN_SIZE = 200;

// customized lower_bound func to ensure the lower bound of 'nan', 'inf' etc. is the end of the boundary,
// and so that we can properly handle a sorter argument.
// std::lower_bound cannot be used here since its customized comparator needs strict weak ordering,
// and the customized comparator requires both arguments to have the same type, which wouldn't
// happen when comparing a val of input_t to an indexer value from a sorter of int64
template<typename input_t>
int64_t cus_lower_bound(int64_t start, int64_t end, const input_t val, const input_t* bd, const int64_t* sort) {
  // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset,
  // i.e. the second row of a 3x3 tensor starts at element 3, but sorter's second row only contains 0, 1, or 2
  const int64_t orig_start = start;
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const input_t mid_val = sort ? bd[sort[mid] + orig_start] : bd[mid];
    if (!(mid_val >= val)) {
      start = mid + 1;
    }
    else {
      end = mid;
    }
  }
  return start;
}
// ...

// customized upper_bound func to ensure we can properly handle a sorter argument.
// std::upper_bound cannot be used here since its customized comparator requires both arguments to have the
// same type, which wouldn't happen when comparing a val of input_t to an indexer value from a sorter of int64
template<typename input_t>
int64_t cus_upper_bound(int64_t start, int64_t end, const input_t val, const input_t* bd, const int64_t* sort) {
  // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset,
  // i.e. the second row of a 3x3 tensor starts at element 3, but sorter's second row only contains 0, 1, or 2
  const int64_t orig_start = start;
  while (start < end) {
    const int64_t mid = start + ((end - start) >> 1);
    const input_t mid_val = sort ? bd[sort[mid] + orig_start] : bd[mid];
    if (!(mid_val > val)) {
      start = mid + 1;
    }
    else {
      end = mid;
    }
  }
  return start;
}

template<typename input_t, typename output_t>
void searchsorted_cpu_contiguous(Tensor& result, const Tensor& input, const Tensor& boundaries, const bool& right, const Tensor& sorter) {
  int64_t numel_in = input.numel();
  bool is_scalar_input = input.dim() == 0 && numel_in == 1;
  // innermost dim size of input and boundaries
  int64_t idim_in = is_scalar_input ? 1 : input.sizes().back();
  int64_t idim_bd = boundaries.sizes().back();

  const input_t *data_in = input.data_ptr<input_t>();
  const input_t *data_bd = boundaries.data_ptr<input_t>();
  const int64_t *data_st = sorter.defined() ? sorter.data_ptr<int64_t>() : nullptr;
  output_t *data_out = result.data_ptr<output_t>();

  bool is_1d_boundaries = boundaries.dim() == 1;
  at::parallel_for(0, numel_in, SEARCHSORTED_GRAIN_SIZE, [&](int64_t start, int64_t end) {
    for (const auto i : c10::irange(start, end)) {
      // If boundaries tensor is 1d, we always search the entire boundary tensor
      int64_t start_bd = is_1d_boundaries ? 0 : i / idim_in * idim_bd;
      int64_t end_bd = start_bd + idim_bd;

      int64_t pos = !right ?
        cus_lower_bound(start_bd, end_bd, data_in[i], data_bd, data_st) - start_bd :
        cus_upper_bound(start_bd, end_bd, data_in[i], data_bd, data_st) - start_bd;

      // type conversion might happen here
      data_out[i] = pos;
    }
  });
}

void dispatch(Tensor& result, const Tensor& input, const Tensor& boundaries, bool out_int32, bool right, const Tensor& sorter) {
  if (!out_int32) {
    AT_DISPATCH_ALL_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        input.scalar_type(),
        "searchsorted_out_cpu",
        [&] {
          searchsorted_cpu_contiguous<scalar_t, int64_t>(
              result, input, boundaries, right, sorter);
        });
  }
  else {
    AT_DISPATCH_ALL_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        input.scalar_type(),
        "searchsorted_out_cpu",
        [&] {
          searchsorted_cpu_contiguous<scalar_t, int>(
              result, input, boundaries, right, sorter);
        });
  }
}

} // anonymous namespace
// ...

Tensor& searchsorted_out_cpu(
    const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const c10::optional<c10::string_view> side_opt,
    const c10::optional<Tensor>& sorter_opt,
    Tensor& result) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> sorter_maybe_owned = at::borrow_from_optional_tensor(sorter_opt);
  const Tensor& sorter = *sorter_maybe_owned;

  searchsorted_pre_check(sorted_sequence, self, result, out_int32, right, side_opt, sorter);
  resize_output(result, self.sizes());

  // we have two inputs to set right, pre_check checks that they aren't set to opposites
  bool is_right = side_opt ? *side_opt == "right" : right;

  if (self.numel() == 0) {
    return result;
  }

  // for non-contiguous result tensors, we write the output to a contiguous copy so we can
  // later copy back, maintaining the original result tensor
  Tensor out = result;
  if (!result.is_contiguous()) {
    out = result.contiguous();
  }

  // ...
    dispatch(out, final_input, final_boundaries, out_int32, is_right, final_sorter);
  }

  // if result is non-contiguous, we wrote the answer to a copied version, so we copy back to the original result tensor
  if (!result.is_contiguous()) {
    result.copy_(out);
  }
  return result;
}

Tensor searchsorted_cpu(
    const Tensor& sorted_sequence,
    const Tensor& self,
    bool out_int32,
    bool right,
    const c10::optional<c10::string_view> side_opt,
    const c10::optional<Tensor>& sorter_opt) {
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);
  Tensor result = at::empty({0}, options, MemoryFormat::Contiguous);
  at::native::searchsorted_out_cpu(sorted_sequence, self, out_int32, right, side_opt, sorter_opt, result);
  return result;
}

Tensor searchsorted_cpu(
    const Tensor& sorted_sequence,
    const Scalar& self,
    bool out_int32,
    bool right,
    const c10::optional<c10::string_view> side_opt,
    const c10::optional<Tensor>& sorter_opt) {
  const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device());
  return searchsorted_cpu(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt);
}

Tensor& bucketize_out_cpu(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right, Tensor& result) {
  TORCH_CHECK(boundaries.dim() == 1, "boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")");
  at::native::searchsorted_out_cpu(boundaries, self, out_int32, right, nullopt, nullopt, result);
  return result;
}
// ...

} // namespace native
} // namespace at
```
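For reference, a short usage sketch of the public `torch.bucketize` API backed by the kernel above; the expected outputs follow the lower-bound (`right=False`) and upper-bound (`right=True`) semantics described in the header comment:

```python
import torch

boundaries = torch.tensor([1, 3, 5, 7, 9])   # 1-D sorted boundaries
v = torch.tensor([[3, 6, 9], [3, 6, 9]])

# right=False (default): index of the first boundary >= value (lower bound),
# analogous to Python's bisect.bisect_left
print(torch.bucketize(v, boundaries))
# tensor([[1, 3, 4],
#         [1, 3, 4]])

# right=True: index of the first boundary > value (upper bound),
# analogous to bisect.bisect_right
print(torch.bucketize(v, boundaries, right=True))
# tensor([[2, 3, 5],
#         [2, 3, 5]])

# out_int32=True returns int32 indices instead of the default int64
print(torch.bucketize(v, boundaries, out_int32=True).dtype)  # torch.int32
```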
@@ -303,42 +151,6 @@
```python
def bucketize(x: common_types.ConsistentTensorType,
              num_buckets: int,
              epsilon: Optional[float] = None,
              weights: Optional[tf.Tensor] = None,
              elementwise: bool = False,
              name: Optional[str] = None) -> common_types.ConsistentTensorType:
  """Returns a bucketized column, with a bucket index assigned to each input.

  Args:
    x: A numeric input `Tensor` or `CompositeTensor` whose values should be
      mapped to buckets. For a `CompositeTensor` only non-missing values will
      be included in the quantiles computation, and the result of `bucketize`
      will be a `CompositeTensor` with non-missing values mapped to buckets. If
      elementwise=True then `x` must be dense.
    num_buckets: Values in the input `x` are divided into approximately
      equal-sized buckets, where the number of buckets is `num_buckets`.
    epsilon: (Optional) Error tolerance, typically a small fraction close to
      zero. If a value is not specified by the caller, a suitable value is
      computed based on experimental results. For `num_buckets` less than 100,
      the value of 0.01 is chosen to handle a dataset of up to ~1 trillion input
      data values. If `num_buckets` is larger, then epsilon is set to
      (1/`num_buckets`) to enforce a stricter error tolerance, because more
      buckets will result in smaller range for each bucket, and so we want the
      boundaries to be less fuzzy. See analyzers.quantiles() for details.
    weights: (Optional) Weights tensor for the quantiles. Tensor must have the
      same shape as x.
    elementwise: (Optional) If true, bucketize each element of the tensor
      independently.
    name: (Optional) A name for this operation.

  Returns:
    A `Tensor` of the same shape as `x`, with each element in the
    returned tensor representing the bucketized value. Bucketized value is
    in the range [0, actual_num_buckets). Sometimes the actual number of buckets
    can be different than num_buckets hint, for example in case the number of
    distinct values is smaller than num_buckets, or in cases where the
    input values are not uniformly distributed.
    NaN values are mapped to the last bucket. Values with NaN weights are
    ignored in bucket boundaries calculation.

  Raises:
    TypeError: If num_buckets is not an int.
    ValueError: If value of num_buckets is not > 1.
    ValueError: If elementwise=True and x is a `CompositeTensor`.
  """
  with tf.compat.v1.name_scope(name, 'bucketize'):
    if not isinstance(num_buckets, int):
      raise TypeError('num_buckets must be an int, got %s' % type(num_buckets))
    # ...
```
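To make the quantile-bucketing behavior described in this docstring concrete, here is an illustrative numpy sketch. The helper `quantile_bucketize` is hypothetical and is not how TensorFlow Transform implements `bucketize` (the real implementation computes approximate quantiles over the full dataset), but it shows the same idea: boundaries are placed at roughly equal-mass quantiles, then values are assigned by searchsorted:

```python
import numpy as np

def quantile_bucketize(x, num_buckets):
    # Hypothetical helper: choose num_buckets - 1 interior quantiles as
    # boundaries, then assign each value the index of its bucket.
    interior = np.linspace(0, 1, num_buckets + 1)[1:-1]
    boundaries = np.quantile(x, interior)
    return np.searchsorted(boundaries, x, side='right')

x = np.random.randn(10_000)
buckets = quantile_bucketize(x, num_buckets=4)
print(np.bincount(buckets))  # roughly equal counts, e.g. ~2500 per bucket
```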
@@ -433,9 +245,11 @@ paddle.bucketize(

The test cases to consider are as follows:

- Consistency of the output numerical results, using numpy as the reference (see the test sketch after this list)
- Correctness of the output when `right` is True and when it is False
- Correctness of the output dtype when `out_int32` is True and when it is False
- Type checking of `x`: an exception should be raised if it is not a Tensor
- Dimension checking of `sorted_sequence`: this API only handles the case where `sorted_sequence` is 1-D, so the input must be constrained accordingly
- Correctness of the output when `right` is not passed
- Correctness of the output when `out_int32` is not passed
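A minimal sketch of the numpy-based consistency check, assuming the signature proposed in this RFC, `paddle.bucketize(x, sorted_sequence, out_int32=False, right=False, name=None)`:

```python
import numpy as np
import paddle

def ref_bucketize(x, sorted_sequence, right=False):
    # numpy reference: `right` maps onto np.searchsorted's `side` argument
    return np.searchsorted(sorted_sequence, x, side='right' if right else 'left')

sorted_sequence = np.array([2.0, 4.0, 8.0, 16.0])
x = np.array([[0.0, 8.0, 4.0, 16.0],
              [-1.0, 9.0, 4.0, 16.0]])

for right in (False, True):
    expected = ref_bucketize(x, sorted_sequence, right=right)
    out = paddle.bucketize(paddle.to_tensor(x),
                           paddle.to_tensor(sorted_sequence),
                           right=right)
    np.testing.assert_array_equal(out.numpy(), expected)
```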
Contributor:

We should consider more test cases, e.g., some error cases; refer to the unit test specification.

Contributor (Author):

OK, I'll revise it.

Contributor (Author):

Done, please review.