Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.30】 #34

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Update api_design_for_tripletmargindistanceloss.md
yangguohao authored Mar 22, 2022
commit ae1e7fb4bb91a8a84e167002678425d9b94f3cc9
9 changes: 4 additions & 5 deletions rfcs/APIs/api_design_for_tripletmargindistanceloss.md
Original file line number Diff line number Diff line change
@@ -159,16 +159,16 @@ def triplet_loss(queries, positives, negatives, margin=0.1):
- `padde.nn.functional.triplet_margin_with_distance_loss(input, positive, negative, distance_function=None, margin=1.0, swap=False, reduction='mean', name=None) -> Tensor`
## 底层OP设计
## API实现方案
distance functions可以采用paddle.nn.PairWiseDistance来进行实现

1. 检查参数

1. 检查 reduction 有效性(同其余 functional loss 中的实现)
2. 检查输入的 dtype(含 `input``positive``negative`)(同其余 functional loss 中的实现)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

输入损失的不一定为(N,dim),需要检查维度之后进行维度转换

3. 检查输入的`input``positive``negative`维度是否相同
3. 用reshape方法进行转换为维度,[batch_size,dim],并检查参数维度是否相同。

shiyutang marked this conversation as resolved.
Show resolved Hide resolved
2. 计算

1. 用户可传入distance_function参数,如果未指定则使用 `paddle.nn.PairWiseDistance` 分别计算得到正锚点与样本和负锚点与样本的距离。
1. 用户可传入distance_function参数,如果未指定则使用 `paddle.linalg.norm` 分别计算得到正锚点与样本和负锚点与样本的距离。
2. `swap` 参数判断:正锚点和负锚点间距离,并与负锚点与样本间距离进行比较,取更小的距离作为负锚点与样本间的距离。
3. 通过 `paddle.clip` 实现公式所示求出得 loss。

@@ -180,8 +180,7 @@ distance functions可以采用paddle.nn.PairWiseDistance来进行实现
- 2.CPU、GPU下计算一致。
- 3.各reduction下计算一致
- 4.各参数输入有效。
- 5.反向梯度的正确性。
-

# 七、可行性分析和排期规划
方案主要依赖现有paddle api组合而成,可以满足在当前版本周期内开发完成。