Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.15】add RFC for Nanmedian #89

Merged
merged 5 commits into from
Apr 13, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add c++ design
thunder95 committed Apr 8, 2022
commit e737fd90fb05a2c170528ba63f694faf0206d29c
41 changes: 33 additions & 8 deletions rfcs/APIs/20220331_api_design_for_nanmedian.md
Original file line number Diff line number Diff line change
@@ -205,23 +205,48 @@ API设计为`paddle.nanmedian(x, axis=None, keepdim=False, name=None)`
参数类型中,`axis`支持`int|tuple|list`输入, keepdim支持返回保持原来的形状。

## 底层OP设计
基于已有API组合实现,不再单独设计OP。
现有API对NAN元素支持比较有限, 需要单独设计一个OP, 支持cpu和cuda, 对最后一维度进行nanmedian算子操作.
参考Pytorch不做反向传播梯度计算。

## API实现方案
主要按下列步骤进行组合实现,实现位置为`paddle/tensor/math.py``sum`,`nansum`等方法放在一起:
1. 多个axis时,先计算出需要转置的axis序列,将目标axis数据元素转置到最后
2. 使用`paddle.transpose`获取axis上的元素
3. 使用`paddle.isnan`以及`paddle.where`得到输入Tensor的nan mask,以及指定轴的非nan值的计数值cnt.
4. 使用`paddle.sort`得到忽略nan的输入张量的排序。
5. 计算已排序张量上中位数索引值,根据总长的奇偶提取中位数的值
1. 如果axis是多个轴, 参考`paddle.quantile`算子对输入进行轴的转换操作
2. 如果axis是一个轴, 当不是最后一个轴, 先使用`paddle.transpose`获取axis上的目标元素
3. 设计cpu和gpu核函数, 对最后的轴进行nanmedian计算。
4. 将最终结果reshape到目标维度
5. `keepdim`参数的处理,对标Numpy融合到各个API当中

-`keepdim`参数的处理,对标Numpy融合到各个API当中。
## 代码实现文件路径

在文件paddle/phi/kernels/impl/nanmedian_kernel_impl.h和paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h中编写主要的正向计算逻辑:

CPU中正向计算逻辑: paddle/phi/kernels/cpu/nanmedian_kernel.cc

GPU中正向计算逻辑: paddle/phi/kernels/gpu/nanmedian_funcs.h paddle/phi/kernels/gpu/nanmedian_kernel.cu



```c++
template <typename T, typename Context>
void NanMedianKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);

```
算子注册路径:
paddle/fluid/operators/nanmedian_op.cc
函数API实现路径: python/paddle/tensor/stat.py
单元测试路径: python/paddle/fluid/tests/unittests/test_nanmedian.py
# 六、测试和验收的考量
测试考虑的case如下:
- 和numpy结果的数值的一致性, `paddle.nanmedian`,和`np.nanmdian`结果是否一致;
- 参数`axis`校验参数类型int,tuple以及list,判断axis合法,并进行边界检查;
- axis测试需覆盖None, int, tuple, list, 数量需要覆盖一维,多维,以及全部维度;
- `keepdim`参数的正确性,输出结果的正确性;
- 输入含`NaN`结果的正确性;
- 输入所有轴上都不含`NaN`结果的正确性;
@@ -233,7 +258,7 @@ API设计为`paddle.nanmedian(x, axis=None, keepdim=False, name=None)`
# 七、可行性分析及规划排期
方案主要依赖现有paddle api组合而成,工期上可以满足在当前版本周期内开发完成。
方案实施难度可控,工期上可以满足在当前版本周期内开发完成。
# 八、影响面
为独立新增API,对其他模块没有影响