Stable Diffusion CUDA Optimizations (#14428)
### Description

Add stable diffusion CUDA kernel optimizations.

The following are included:
(1) GroupNorm operator. This kernel is from TensorRT 8.5.
(2) BiasSplitGelu operator. This kernel is modified from the SplitGelu kernel of
TensorRT 8.5; we added bias to SplitGelu.
(3) NhwcConv operator. This adds support for the NHWC format (the ONNX Conv
operator uses NCHW).
(4) Update MultiHeadAttention (packed KV and no bias) for cross attention. This
avoids transposing KV for the TRT fused cross-attention kernel (see the layout
sketch after this list).
(5) Optimization and benchmark scripts.
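For item (4), the packed KV layout accepted by the updated MultiHeadAttention can be illustrated with a minimal NumPy sketch (the `pack_kv` helper below is hypothetical and only shows the layout; it is not part of this PR):

```python
import numpy as np

def pack_kv(key: np.ndarray, value: np.ndarray, num_heads: int) -> np.ndarray:
    """Pack separate K and V of shape (B, L, hidden_size) into packed KV of
    shape (B, L, num_heads, 2, head_size); value and bias inputs are then omitted."""
    b, l, hidden_size = key.shape
    head_size = hidden_size // num_heads
    k = key.reshape(b, l, num_heads, head_size)
    v = value.reshape(b, l, num_heads, head_size)
    # Stack K and V next to each other for every head, giving the "2" axis.
    return np.stack([k, v], axis=3)  # (B, L, num_heads, 2, head_size)
```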

Not included:
(1) Script to convert Conv to NhwcConv in the ONNX graph.
(2) Update of symbolic shape inference for NhwcConv.
(3) SeqLen2Spatial operator.
(4) Documentation.

Limitations: the GroupNorm, BiasSplitGelu and NhwcConv kernels are implemented
for stable diffusion usage and might not be applicable to arbitrary input sizes
or dimensions. For example, BiasSplitGelu requires the hidden size to be 2560,
5120 or 10240, and NhwcConv assumes 4D input/weight.
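As a purely illustrative guard (not part of this PR), an optimization script could check these constraints before rewriting a graph to use the fused kernels:

```python
def can_use_bias_split_gelu(hidden_size: int) -> bool:
    # BiasSplitGelu kernel only supports these hidden sizes.
    return hidden_size in (2560, 5120, 10240)

def can_use_nhwc_conv(input_rank: int, weight_rank: int) -> bool:
    # NhwcConv assumes 4D input and weight tensors.
    return input_rank == 4 and weight_rank == 4
```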

There is a minor increase in binary size. For SM=75 only, the Python package
wheel grows by (33757 KB - 33640 KB) = 117 KB. It is possible to move NHWC from
a template parameter to the constructor to reduce binary size (with a slight
performance cost).

Note: for RTX 4090/4080/4070 Ti, build with CUDA 11.8 and the latest cuDNN to
get the best performance.
tianleiwu authored and rui-ren committed Feb 3, 2023
1 parent 9b176dd commit f435199
Showing 44 changed files with 3,263 additions and 328 deletions.
9 changes: 9 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -27,6 +27,15 @@ set(contrib_ops_excluded_files
"bert/tensorrt_fused_multihead_attention/*"
"bert/transformer_common.h"
"bert/transformer_common.cc"
"diffusion/group_norm.h"
"diffusion/group_norm.cc"
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/bias_split_gelu_impl.h"
"diffusion/bias_split_gelu_impl.cu"
"diffusion/bias_split_gelu.h"
"diffusion/bias_split_gelu.cc"
"diffusion/nhwc_conv.cc"
"math/complex_mul.cc"
"math/complex_mul.h"
"math/complex_mul_impl.cu"
98 changes: 94 additions & 4 deletions docs/ContribOperators.md
@@ -9,6 +9,7 @@ Do not modify directly.*
* <a href="#com.microsoft.BiasDropout">com.microsoft.BiasDropout</a>
* <a href="#com.microsoft.BiasGelu">com.microsoft.BiasGelu</a>
* <a href="#com.microsoft.BiasSoftmax">com.microsoft.BiasSoftmax</a>
* <a href="#com.microsoft.BiasSplitGelu">com.microsoft.BiasSplitGelu</a>
* <a href="#com.microsoft.BifurcationDetector">com.microsoft.BifurcationDetector</a>
* <a href="#com.microsoft.BitmaskBiasDropout">com.microsoft.BitmaskBiasDropout</a>
* <a href="#com.microsoft.BitmaskDropout">com.microsoft.BitmaskDropout</a>
@@ -34,6 +35,7 @@ Do not modify directly.*
* <a href="#com.microsoft.GemmFastGelu">com.microsoft.GemmFastGelu</a>
* <a href="#com.microsoft.GreedySearch">com.microsoft.GreedySearch</a>
* <a href="#com.microsoft.GridSample">com.microsoft.GridSample</a>
* <a href="#com.microsoft.GroupNorm">com.microsoft.GroupNorm</a>
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
@@ -590,6 +592,39 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.BiasSplitGelu"></a><a name="com.microsoft.biassplitgelu">**com.microsoft.BiasSplitGelu**</a>

A fusion used in diffusion models: after the bias is added, the hidden state is split into two tensors of equal size, and the left
tensor is multiplied by the Gelu activation of the right tensor.
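A NumPy reference sketch of these semantics (assuming the split is along the hidden dimension D; the exact Gelu variant used by the CUDA kernel may differ):

```python
import numpy as np
from scipy.special import erf

def bias_split_gelu(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """x: (N, S, D), bias: (D) -> output: (N, S, D/2)."""
    h = x + bias                                   # add bias
    left, right = np.split(h, 2, axis=-1)          # two (N, S, D/2) halves
    gelu = 0.5 * right * (1.0 + erf(right / np.sqrt(2.0)))
    return left * gelu
```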

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Inputs

<dl>
<dt><tt>X</tt> : T</dt>
<dd>Input tensor. Dimensions are (N, S, D), where N is the batch size, S is the image size, and D is the hidden dimension</dd>
<dt><tt>bias</tt> : T</dt>
<dd>Bias tensor. Dimensions are (D), where D is the same hidden dimension as input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>The output tensor with dimensions (N, S, D/2)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain input X and output Y types to float tensors.</dd>
</dl>


### <a name="com.microsoft.BifurcationDetector"></a><a name="com.microsoft.bifurcationdetector">**com.microsoft.BifurcationDetector**</a>

Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
@@ -1811,6 +1846,61 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.GroupNorm"></a><a name="com.microsoft.groupnorm">**com.microsoft.GroupNorm**</a>

Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494).

This operator transforms input according to
y = gamma * (x - mean) / sqrt(variance + epsilon) + beta

The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard deviation are calculated separately over each group.
The gamma and beta inputs are per-channel affine transform parameter vectors of size num_channels.

The activation attribute can be used to enable activation after group normalization.
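A NumPy reference sketch of the formula above for NHWC input (assumptions: statistics are computed per sample and group over H, W and the group's channels; the epsilon default and the Swish formulation shown here are illustrative):

```python
import numpy as np

def group_norm_nhwc(x, gamma, beta, groups, epsilon=1e-5, activation=0):
    """x: (N, H, W, C); gamma, beta: (C); returns a tensor of the same shape as x."""
    n, h, w, c = x.shape
    g = x.reshape(n, h, w, groups, c // groups)
    mean = g.mean(axis=(1, 2, 4), keepdims=True)     # per (sample, group) statistics
    var = g.var(axis=(1, 2, 4), keepdims=True)
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)
    y = y * gamma + beta                             # per-channel affine transform
    if activation == 1:                              # Swish: y * sigmoid(y)
        y = y * (1.0 / (1.0 + np.exp(-y)))
    return y
```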

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>activation</tt> : int (required)</dt>
<dd>Activation after group normalization: 0 for None, 1 for Swish</dd>
<dt><tt>epsilon</tt> : float</dt>
<dd>The epsilon value to use to avoid division by zero</dd>
<dt><tt>groups</tt> : int (required)</dt>
<dd>The number of groups of channels. It should be a divisor of the number of channels C</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> : T</dt>
<dd>Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data</dd>
<dt><tt>gamma</tt> : M</dt>
<dd>1D gamma tensor for normalization with shape (C), where C is number of channels</dd>
<dt><tt>beta</tt> : M</dt>
<dd>1D beta tensor for normalization with shape (C), where C is number of channels</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>The output tensor of the same shape as X</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain input X and output Y types to float tensors.</dd>
<dt><tt>M</tt> : tensor(float)</dt>
<dd>Constrain gamma and beta to float tensors.</dd>
</dl>


### <a name="com.microsoft.Inverse"></a><a name="com.microsoft.inverse">**com.microsoft.Inverse**</a>

#### Version
@@ -2132,16 +2222,16 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Number of attention heads</dd>
</dl>

#### Inputs (4 - 5)
#### Inputs (2 - 5)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size)</dd>
<dt><tt>value</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
<dt><tt>bias</tt> : T</dt>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
3 changes: 3 additions & 0 deletions docs/OperatorKernels.md
@@ -790,6 +790,7 @@ Do not modify directly.*
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|BiasSoftmax|*in* data:**T**<br> *in* bias:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|BiasSplitGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BitmaskBiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)<br/> **T3** = tensor(uint32)|
|BitmaskDropout|*in* data:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)<br/> **T3** = tensor(uint32)|
|ComplexMul|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -805,11 +806,13 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* extra_add:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
117 changes: 74 additions & 43 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -21,11 +21,15 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
int max_threads_per_block) {
// query (Q) : (B, S, D)
// key (K) : (B, L, D)
// value (V) : (B, L, D_v)
// bias (Q/K/V) : (D + D + D_v)
// key_padding_mask (K/V) : (B, L) or (L)
// query (Q) : (B, S, D)
// key (K) : (B, L, D)
// value (V) : (B, L, D_v)
// bias (Q/K/V) : (D + D + D_v)
// key_padding_mask (K/V) : (B) or (B, L) or None
// When packed kv is used:
// key (K) : (B, L, N, 2, H)
// value (V) : None
// bias (Q/K/V) : None

const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3) {
@@ -34,15 +38,50 @@
}

const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
if (key_dims.size() != 3 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

const auto& bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
bias_dims.size());
int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
int hidden_size = static_cast<int>(query_dims[2]);
int head_size = static_cast<int>(hidden_size) / num_heads;
int kv_sequence_length = static_cast<int>(key_dims[1]);

if (key_dims.size() == 3) {
if (key_dims[2] != query_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 2 (hidden_size)");
}
} else // if (key_dims.size() == 5)
{
if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
}
}

if (bias != nullptr) {
const auto& bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
bias_dims.size());
}

// Currently, bias is not allowed for packed KV. This constraint can be removed later.
// Here we assume that fusion tool will not include bias for packed KV.
if (value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. ");
}
}

AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
@@ -61,47 +100,39 @@
}
}

if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

int64_t batch_size = query_dims[0];
int64_t sequence_length = query_dims[1];
int64_t kv_sequence_length = key_dims[1];
int64_t q_hidden_size = query_dims[2];
int64_t v_hidden_size = 0;

const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}
int v_hidden_size = hidden_size;
if (value != nullptr) {
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}

if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}
if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}

if (key_dims[1] != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same same dim 1 (sequence_length)");
if (key_dims[1] != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)");
}
v_hidden_size = static_cast<int>(value_dims[2]);
}
v_hidden_size = value_dims[2];

if (parameters != nullptr) {
AttentionParameters* output_parameters = reinterpret_cast<AttentionParameters*>(parameters);
output_parameters->batch_size = static_cast<int>(batch_size);
output_parameters->sequence_length = static_cast<int>(sequence_length);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length;
output_parameters->past_sequence_length = 0;
output_parameters->kv_sequence_length = static_cast<int>(kv_sequence_length);
output_parameters->total_sequence_length = static_cast<int>(kv_sequence_length);
output_parameters->kv_sequence_length = kv_sequence_length;
output_parameters->total_sequence_length = kv_sequence_length;
output_parameters->max_sequence_length = 0;
output_parameters->input_hidden_size = 0;
output_parameters->hidden_size = static_cast<int>(q_hidden_size);
output_parameters->v_hidden_size = static_cast<int>(v_hidden_size);
output_parameters->head_size = static_cast<int>(q_hidden_size) / num_heads;
output_parameters->v_head_size = static_cast<int>(v_hidden_size) / num_heads;
output_parameters->hidden_size = hidden_size;
output_parameters->v_hidden_size = v_hidden_size;
output_parameters->head_size = hidden_size / num_heads;
output_parameters->v_head_size = v_hidden_size / num_heads;
output_parameters->num_heads = num_heads;
output_parameters->is_unidirectional = false;
output_parameters->past_present_share_buffer = false;