Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stable Diffusion CUDA Optimizations #14428

Merged
merged 28 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6d8402e
Add benchmark
tianleiwu Jan 25, 2023
0ac952e
Add GroupNorm fusion
tianleiwu Jan 25, 2023
c1de9fa
Merge branch 'main' into tlwu/optimize_sd
tianleiwu Jan 25, 2023
29c47dc
[CUDA] Add GroupNormalization operator
tianleiwu Jan 25, 2023
5f40aa8
Add Cast for fp16 group_norm
tianleiwu Jan 26, 2023
a4c4302
Add SplitGelu fusion
tianleiwu Jan 27, 2023
f722c5a
support float type in GroupNorm
tianleiwu Jan 27, 2023
4a7bf0d
Add SplitGelu operator
tianleiwu Jan 27, 2023
98b90ca
format
tianleiwu Jan 28, 2023
ea69aec
format
tianleiwu Jan 28, 2023
9eacd84
misc
tianleiwu Jan 29, 2023
c566679
update group norm test data to NHWC
tianleiwu Jan 29, 2023
a9ebeec
Fuse Bias and SplitGelu
tianleiwu Jan 29, 2023
53a539f
update bias split gelu
tianleiwu Jan 29, 2023
a0c4957
update GroupNorm doc
tianleiwu Jan 30, 2023
82383dc
packed kv in cross attention
tianleiwu Jan 31, 2023
966b3e7
fix pyright warnings
tianleiwu Jan 31, 2023
4a8583e
Add unit test of bias split gelu
tianleiwu Jan 31, 2023
982663a
fix typo
tianleiwu Jan 31, 2023
73045bb
fix code scanning warnings
tianleiwu Jan 31, 2023
86d5795
fix code scanning warnings
tianleiwu Jan 31, 2023
efa6d4f
address review feedback
tianleiwu Jan 31, 2023
7a75ce1
Add NhwcConv
tianleiwu Feb 1, 2023
f4d4103
fix training api build error
tianleiwu Feb 1, 2023
55a7468
Add float16 test
tianleiwu Feb 1, 2023
3ff1fe6
fix type warning
tianleiwu Feb 2, 2023
b3a4c01
update op doc; exclude from hipify
tianleiwu Feb 2, 2023
1fe78af
add input checks; clean debug code
tianleiwu Feb 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ set(contrib_ops_excluded_files
"bert/tensorrt_fused_multihead_attention/*"
"bert/transformer_common.h"
"bert/transformer_common.cc"
"diffusion/group_norm.h"
"diffusion/group_norm.cc"
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/bias_split_gelu_impl.h"
"diffusion/bias_split_gelu_impl.cu"
"diffusion/bias_split_gelu.h"
"diffusion/bias_split_gelu.cc"
"diffusion/nhwc_conv.cc"
"math/complex_mul.cc"
"math/complex_mul.h"
"math/complex_mul_impl.cu"
Expand Down
98 changes: 94 additions & 4 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Do not modify directly.*
* <a href="#com.microsoft.BiasDropout">com.microsoft.BiasDropout</a>
* <a href="#com.microsoft.BiasGelu">com.microsoft.BiasGelu</a>
* <a href="#com.microsoft.BiasSoftmax">com.microsoft.BiasSoftmax</a>
* <a href="#com.microsoft.BiasSplitGelu">com.microsoft.BiasSplitGelu</a>
* <a href="#com.microsoft.BifurcationDetector">com.microsoft.BifurcationDetector</a>
* <a href="#com.microsoft.BitmaskBiasDropout">com.microsoft.BitmaskBiasDropout</a>
* <a href="#com.microsoft.BitmaskDropout">com.microsoft.BitmaskDropout</a>
Expand All @@ -34,6 +35,7 @@ Do not modify directly.*
* <a href="#com.microsoft.GemmFastGelu">com.microsoft.GemmFastGelu</a>
* <a href="#com.microsoft.GreedySearch">com.microsoft.GreedySearch</a>
* <a href="#com.microsoft.GridSample">com.microsoft.GridSample</a>
* <a href="#com.microsoft.GroupNorm">com.microsoft.GroupNorm</a>
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
Expand Down Expand Up @@ -590,6 +592,39 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.BiasSplitGelu"></a><a name="com.microsoft.biassplitgelu">**com.microsoft.BiasSplitGelu**</a>

A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left
tensor multiplies the Gelu activation result of right tensor.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Inputs

<dl>
<dt><tt>X</tt> : T</dt>
<dd>Input tensor. Dimensions are (N, S, D), where N is the batch size, S are image size, and D is hidden dimension</dd>
<dt><tt>bias</tt> : T</dt>
<dd>Bias tensor. Dimensions are (D), where D is the same hidden dimension as input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>The output tensor with dimensions (N, S, D/2)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain input X and output Y types to float tensors.</dd>
</dl>


### <a name="com.microsoft.BifurcationDetector"></a><a name="com.microsoft.bifurcationdetector">**com.microsoft.BifurcationDetector**</a>

Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
Expand Down Expand Up @@ -1811,6 +1846,61 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.GroupNorm"></a><a name="com.microsoft.groupnorm">**com.microsoft.GroupNorm**</a>

Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494).

This operator transforms input according to
y = gamma * (x - mean) / sqrt(variance + epsilon) + beta

The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard-deviation are calculated separately over the each group.
The weight and bias are per-channel affine transform parameter vectors of size num_channels.

The activation attribute can be used to enable activation after group normalization.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>activation</tt> : int (required)</dt>
<dd>Activation after group normalization: 0 for None, 1 for Swish</dd>
<dt><tt>epsilon</tt> : float</dt>
<dd>The epsilon value to use to avoid division by zero</dd>
<dt><tt>groups</tt> : int (required)</dt>
<dd>The number of groups of channels. It should be a divisor of the number of channels C</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> : T</dt>
<dd>Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data</dd>
<dt><tt>gamma</tt> : M</dt>
<dd>1D gamma tensor for normalization with shape (C), where C is number of channels</dd>
<dt><tt>beta</tt> : M</dt>
<dd>1D beta tensor for normalization with shape (C), where C is number of channels</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>The output tensor of the same shape as X</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain input X and output Y types to float tensors.</dd>
<dt><tt>M</tt> : tensor(float)</dt>
<dd>Constrain gamma and beta to float tensors.</dd>
</dl>


### <a name="com.microsoft.Inverse"></a><a name="com.microsoft.inverse">**com.microsoft.Inverse**</a>

#### Version
Expand Down Expand Up @@ -2132,16 +2222,16 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of attention heads</dd>
</dl>

#### Inputs (4 - 5)
#### Inputs (2 - 5)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size)</dd>
<dt><tt>value</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
<dt><tt>bias</tt> : T</dt>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
Expand Down
9 changes: 7 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ Do not modify directly.*
|||[11, 12]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 10]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|LpNormalization|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float)|
|LpPool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float)|
|LpPool|*in* X:**T**<br> *out* Y:**T**|18+|**T** = tensor(float)|
|||[11, 17]|**T** = tensor(float)|
|||[2, 10]|**T** = tensor(float)|
|MatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
Expand Down Expand Up @@ -789,6 +790,7 @@ Do not modify directly.*
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|BiasSoftmax|*in* data:**T**<br> *in* bias:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|BiasSplitGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BitmaskBiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)<br/> **T3** = tensor(uint32)|
|BitmaskDropout|*in* data:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)<br/> **T3** = tensor(uint32)|
|ComplexMul|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
Expand All @@ -804,11 +806,13 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* extra_add:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
Expand Down Expand Up @@ -1086,7 +1090,8 @@ Do not modify directly.*
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
Loading