Skip to content

Commit

Permalink
fix axis examples
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Aug 5, 2022
1 parent 0072ec3 commit a09e913
Showing 1 changed file with 45 additions and 7 deletions.
52 changes: 45 additions & 7 deletions rfcs/CINN/APIs/20220802_cinn_api_design_one_hot.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@

该算子根据输入的索引(`indices`),返回一个 Tensor,该 Tensor 将索引的位置标注为用户指定的一个值(`on_value`),非索引的位置标注为另一个值(`off_value`)。最常见的是 `1/0` 标值(`on_value=1``off_value=0`

```text
> 以下示例输出由 tvm.relay.one_hot 运算得到,由于完整代码过于复杂,因此使用伪码描述主要计算逻辑
```python
indices = [0, 2, 2] # 输入索引,shape = [3]
on_value = 1 # 索引位置的值
off_value = 0 # 非索引位置的值
Expand All @@ -36,7 +38,8 @@ one_hot(
indices,
on_value=on_value,
off_value=off_value,
depth,
depth=depth,
axis=axis,
dtype=dtype
)
# shape = [3, 3]
Expand All @@ -47,13 +50,14 @@ one_hot(

`depth` 用于表示总类别个数,用于表示在需要填充的轴上一共要填充多少个数据

```text
```python
# 其余参数不变
one_hot(
indices, # shape = [3]
on_value=on_value,
off_value=off_value,
depth=5,
axis=axis,
dtype=dtype
)
# shape = [3, 5]
Expand All @@ -64,13 +68,14 @@ one_hot(

`on_value``off_value` 索引位置的值和非索引位置的值,均只支持 0 维数据(Scalar)

```text
```python
# 其余参数不变
one_hot(
indices, # shape = [3]
on_value=233, # shape = []
off_value=-233, # shape = []
depth=depth,
axis=axis,
dtype=dtype
)
# shape = [3, 3]
Expand All @@ -83,13 +88,14 @@ one_hot(

`axis` 的范围为 `[-1, indices.ndim]`

```text
```python
# 其余参数不变
one_hot(
indices, # shape = [3]
on_value=on_value,
off_value=off_value,
depth=depth,
axis=0,
dtype=dtype
)
# shape = [3, 3]
Expand All @@ -98,9 +104,41 @@ one_hot(
# [0. 1. 1.]]
```

`indices` 支持任意维度 Tensor,输出数据维度为 `<indices outer dimensions> x depth x <indices inner dimensions>`
`indices` 支持任意维度 Tensor(包含 0 维数据),输出数据维度为 `<indices outer dimensions> x depth x <indices inner dimensions>`

```python
indices = 1
one_hot(
indices, # shape = []
on_value=on_value,
off_value=off_value,
depth=depth,
axis=axis,
dtype=dtype
)
# shape = [3]
# [0. 1. 0.]
```

> **Note**
>
> 由于现在 CINN 不支持 0 维数据,因此在 CINN 的实现中也是无法支持 0 维数据的,但 one_hot 功能本身应当支持 0 维数据,这与 `on_value``off_value` 相似,可以在 CINN 支持 0 维数据后再作考虑
下面是一个比较复杂的例子,输入维度为 `[A, B, C, D, E, F]`,如果 `axis = 2`,则,输出维度为 `[A, B, depth, C, D, E, F]`

比如输入维度为 `[A, B, C, D, E, F]`,如果 `axis = 2`,则,输出维度为 `[A, B, depth, C, D, E, F]`
```python
indices = np.random.randint(10, size=[2, 3, 4, 5, 6, 7])
one_hot(
indices, # shape = [2, 3, 4, 5, 6, 7]
on_value=on_value,
off_value=off_value,
depth=10,
axis=2,
dtype=dtype
)
# shape = [2, 3, 10, 4, 5, 6, 7]
# 具体输出太多,不作完整展示,shape 是直接通过 relay.one_hot 打印得到,可复现
```

### 4、意义

Expand Down

0 comments on commit a09e913

Please sign in to comment.