diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index d3b8f5ebfcc26..85246ec8bd37c 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -27,6 +27,15 @@ set(contrib_ops_excluded_files
"bert/tensorrt_fused_multihead_attention/*"
"bert/transformer_common.h"
"bert/transformer_common.cc"
+ "diffusion/group_norm.h"
+ "diffusion/group_norm.cc"
+ "diffusion/group_norm_impl.cu"
+ "diffusion/group_norm_impl.h"
+ "diffusion/bias_split_gelu_impl.h"
+ "diffusion/bias_split_gelu_impl.cu"
+ "diffusion/bias_split_gelu.h"
+ "diffusion/bias_split_gelu.cc"
+ "diffusion/nhwc_conv.cc"
"math/complex_mul.cc"
"math/complex_mul.h"
"math/complex_mul_impl.cu"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 1e6d46963cd21..8cd6d4c9e26f1 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -9,6 +9,7 @@ Do not modify directly.*
* com.microsoft.BiasDropout
* com.microsoft.BiasGelu
* com.microsoft.BiasSoftmax
+ * com.microsoft.BiasSplitGelu
* com.microsoft.BifurcationDetector
* com.microsoft.BitmaskBiasDropout
* com.microsoft.BitmaskDropout
@@ -34,6 +35,7 @@ Do not modify directly.*
* com.microsoft.GemmFastGelu
* com.microsoft.GreedySearch
* com.microsoft.GridSample
+ * com.microsoft.GroupNorm
* com.microsoft.Inverse
* com.microsoft.Irfft
* com.microsoft.LongformerAttention
@@ -590,6 +592,39 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.BiasSplitGelu**
+
+  A fusion used in diffusion models: after adding the bias, the hidden state is split into two tensors of equal size along
+  the last dimension, and the left tensor is multiplied by the Gelu activation of the right tensor.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Inputs
+
+
+- X : T
+- Input tensor. Dimensions are (N, S, D), where N is the batch size, S is the image size, and D is the hidden dimension
+- bias : T
+- Bias tensor. Dimensions are (D), where D is the same hidden dimension as the input tensor
+
+
+#### Outputs
+
+
+- Y : T
+- The output tensor with dimensions (N, S, D/2)
+
+
+#### Type Constraints
+
+
+- T : tensor(float16), tensor(float)
+- Constrain input X and output Y types to float tensors.
+
+
+
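The split-and-multiply behavior can be pinned down with a small CPU reference. The sketch below is an editorial illustration only: it is not part of the operator spec or of this patch, and the name `BiasSplitGeluRef` is made up.

```cpp
// Reference semantics of BiasSplitGelu: y[n, s, :] = left * Gelu(right), where
// left/right are the two halves of (x[n, s, :] + bias) along the last axis.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> BiasSplitGeluRef(const std::vector<float>& x,     // (N, S, D) row-major
                                    const std::vector<float>& bias,  // (D)
                                    size_t N, size_t S, size_t D) {
  const size_t half = D / 2;
  std::vector<float> y(N * S * half);  // output is (N, S, D/2)
  for (size_t row = 0; row < N * S; ++row) {
    for (size_t h = 0; h < half; ++h) {
      float left = x[row * D + h] + bias[h];
      float right = x[row * D + half + h] + bias[half + h];
      // Gelu(v) = v * 0.5 * (erf(v / sqrt(2)) + 1)
      float gelu_right = right * 0.5f * (std::erf(right / 1.41421356237f) + 1.0f);
      y[row * half + h] = left * gelu_right;
    }
  }
  return y;
}
```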
### **com.microsoft.BifurcationDetector**
Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
@@ -1811,6 +1846,61 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.GroupNorm**
+
+ Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494).
+
+ This operator transforms input according to
+ y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
+
+  The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard deviation are calculated separately over each group.
+ The weight and bias are per-channel affine transform parameter vectors of size num_channels.
+
+ The activation attribute can be used to enable activation after group normalization.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- activation : int (required)
+- Activation after group normalization: 0 for None, 1 for Swish
+- epsilon : float
+- The epsilon value to use to avoid division by zero
+- groups : int (required)
+- The number of groups of channels. It should be a divisor of the number of channels C
+
+
+#### Inputs
+
+
+- X : T
+- Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+- gamma : M
+- 1D gamma tensor for normalization with shape (C), where C is number of channels
+- beta : M
+- 1D beta tensor for normalization with shape (C), where C is number of channels
+
+
+#### Outputs
+
+
+- Y : T
+- The output tensor of the same shape as X
+
+
+#### Type Constraints
+
+
+- T : tensor(float16), tensor(float)
+- Constrain input X and output Y types to float tensors.
+- M : tensor(float)
+- Constrain gamma and beta to float tensors.
+
+
+
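A CPU reference of the formula above for the NHWC layout, provided as an editorial sketch rather than part of the spec or patch; the name `GroupNormRef` is made up:

```cpp
// Reference semantics of GroupNorm over NHWC input (N, H, W, C) with
// per-channel gamma/beta of shape (C); groups must divide C.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> GroupNormRef(const std::vector<float>& x,
                                const std::vector<float>& gamma,
                                const std::vector<float>& beta,
                                size_t N, size_t H, size_t W, size_t C,
                                size_t groups, float epsilon, bool swish) {
  std::vector<float> y(x.size());
  const size_t cpg = C / groups;  // channels per group
  const size_t hw = H * W;
  for (size_t n = 0; n < N; ++n) {
    for (size_t g = 0; g < groups; ++g) {
      // Mean and variance over every spatial position and channel of this group.
      double sum = 0.0, sum_sq = 0.0;
      for (size_t p = 0; p < hw; ++p) {
        for (size_t c = g * cpg; c < (g + 1) * cpg; ++c) {
          double v = x[(n * hw + p) * C + c];
          sum += v;
          sum_sq += v * v;
        }
      }
      const double count = static_cast<double>(hw * cpg);
      const double mean = sum / count;
      const double variance = sum_sq / count - mean * mean;
      const float inv_std = static_cast<float>(1.0 / std::sqrt(variance + epsilon));
      for (size_t p = 0; p < hw; ++p) {
        for (size_t c = g * cpg; c < (g + 1) * cpg; ++c) {
          const size_t idx = (n * hw + p) * C + c;
          float v = gamma[c] * (x[idx] - static_cast<float>(mean)) * inv_std + beta[c];
          // Optional Swish activation: v * sigmoid(v).
          y[idx] = swish ? v * (1.0f / (1.0f + std::exp(-v))) : v;
        }
      }
    }
  }
  return y;
}
```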
### **com.microsoft.Inverse**
#### Version
@@ -2132,16 +2222,16 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (4 - 5)
+#### Inputs (2 - 5)
- query : T
- Query with shape (batch_size, sequence_length, hidden_size)
- key : T
-- Key with shape (batch_size, kv_sequence_length, hidden_size)
-- value : T
+- Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
+- value (optional) : T
- Value with shape (batch_size, kv_sequence_length, v_hidden_size)
-- bias : T
+- bias (optional) : T
- Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
- key_padding_mask (optional) : M
- Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index ad571dacb20d7..7e4eb38be780b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -175,7 +175,8 @@ Do not modify directly.*
|||[11, 12]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 10]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|LpNormalization|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float)|
-|LpPool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float)|
+|LpPool|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(float)|
+|||[11, 17]|**T** = tensor(float)|
|||[2, 10]|**T** = tensor(float)|
|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
@@ -789,6 +790,7 @@ Do not modify directly.*
|BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BitmaskBiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)|
|BitmaskDropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)|
|ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -804,11 +806,13 @@ Do not modify directly.*
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
+|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
+|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)|
|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* extra_add:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)|
@@ -1086,7 +1090,8 @@ Do not modify directly.*
|Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
-|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
+|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index ce109a83720b9..8c3af05972c95 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -21,11 +21,15 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
int max_threads_per_block) {
- // query (Q) : (B, S, D)
- // key (K) : (B, L, D)
- // value (V) : (B, L, D_v)
- // bias (Q/K/V) : (D + D + D_v)
- // key_padding_mask (K/V) : (B, L) or (L)
+ // query (Q) : (B, S, D)
+ // key (K) : (B, L, D)
+ // value (V) : (B, L, D_v)
+ // bias (Q/K/V) : (D + D + D_v)
+ // key_padding_mask (K/V) : (B) or (B, L) or None
+ // When packed kv is used:
+ // key (K) : (B, L, N, 2, H)
+ // value (V) : None
+ // bias (Q/K/V) : None
const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3) {
@@ -34,15 +38,50 @@ Status CheckInputs(const T* query,
}
const auto& key_dims = key->Shape().GetDims();
- if (key_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
+ if (key_dims.size() != 3 && key_dims.size() != 5) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
key_dims.size());
}
+ if (query_dims[0] != key_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 0 (batch size)");
+ }
- const auto& bias_dims = bias->Shape().GetDims();
- if (bias_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
- bias_dims.size());
+  int batch_size = static_cast<int>(query_dims[0]);
+  int sequence_length = static_cast<int>(query_dims[1]);
+  int hidden_size = static_cast<int>(query_dims[2]);
+  int head_size = static_cast<int>(hidden_size) / num_heads;
+  int kv_sequence_length = static_cast<int>(key_dims[1]);
+
+ if (key_dims.size() == 3) {
+ if (key_dims[2] != query_dims[2]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 2 (hidden_size)");
+ }
+ } else // if (key_dims.size() == 5)
+ {
+    if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
+ }
+ if (value != nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' to be none when 'key' has packed kv format.");
+ }
+ }
+
+ if (bias != nullptr) {
+ const auto& bias_dims = bias->Shape().GetDims();
+ if (bias_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
+ bias_dims.size());
+ }
+
+ // Currently, bias is not allowed for packed KV. This constraint can be removed later.
+ // Here we assume that fusion tool will not include bias for packed KV.
+ if (value == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. ");
+ }
}
AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
@@ -61,47 +100,39 @@ Status CheckInputs(const T* query,
}
}
- if (query_dims[0] != key_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 0 (batch size)");
- }
-
- int64_t batch_size = query_dims[0];
- int64_t sequence_length = query_dims[1];
- int64_t kv_sequence_length = key_dims[1];
- int64_t q_hidden_size = query_dims[2];
- int64_t v_hidden_size = 0;
-
- const auto& value_dims = value->Shape().GetDims();
- if (value_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
- value_dims.size());
- }
+ int v_hidden_size = hidden_size;
+ if (value != nullptr) {
+ const auto& value_dims = value->Shape().GetDims();
+ if (value_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
+ value_dims.size());
+ }
- if (query_dims[0] != value_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'value' shall have same dim 0 (batch_size)");
- }
+ if (query_dims[0] != value_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 0 (batch_size)");
+ }
- if (key_dims[1] != value_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key' and 'value' shall have same same dim 1 (sequence_length)");
+    if (key_dims[1] != value_dims[1]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
+    }
+    v_hidden_size = static_cast<int>(value_dims[2]);
}
- v_hidden_size = value_dims[2];
if (parameters != nullptr) {
     AttentionParameters* output_parameters = reinterpret_cast<AttentionParameters*>(parameters);
-    output_parameters->batch_size = static_cast<int>(batch_size);
-    output_parameters->sequence_length = static_cast<int>(sequence_length);
+ output_parameters->batch_size = batch_size;
+ output_parameters->sequence_length = sequence_length;
output_parameters->past_sequence_length = 0;
-    output_parameters->kv_sequence_length = static_cast<int>(kv_sequence_length);
-    output_parameters->total_sequence_length = static_cast<int>(kv_sequence_length);
+ output_parameters->kv_sequence_length = kv_sequence_length;
+ output_parameters->total_sequence_length = kv_sequence_length;
output_parameters->max_sequence_length = 0;
output_parameters->input_hidden_size = 0;
-    output_parameters->hidden_size = static_cast<int>(q_hidden_size);
-    output_parameters->v_hidden_size = static_cast<int>(v_hidden_size);
-    output_parameters->head_size = static_cast<int>(q_hidden_size) / num_heads;
-    output_parameters->v_head_size = static_cast<int>(v_hidden_size) / num_heads;
+ output_parameters->hidden_size = hidden_size;
+ output_parameters->v_hidden_size = v_hidden_size;
+ output_parameters->head_size = hidden_size / num_heads;
+ output_parameters->v_head_size = v_hidden_size / num_heads;
output_parameters->num_heads = num_heads;
output_parameters->is_unidirectional = false;
output_parameters->past_present_share_buffer = false;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index cf1d99688546a..630c533c47323 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -9,10 +9,6 @@
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
-#ifndef NDEBUG
-//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search
-#endif
-
namespace onnxruntime {
namespace concurrency {
@@ -57,14 +53,14 @@ struct IBeamSearchCpuState {
 template <typename T>
struct IGreedySearchState {
-  gsl::span<int32_t> sequences_space;           // shape (2, batch_size, max_length)
-  gsl::span<int32_t> sequence_lengths;          // shape (batch_size)
-  gsl::span<int32_t> next_positions;            // shape (batch_size, num_beams). Next position value for position_ids.
-  gsl::span<bool> eos_meet;                     // shape (batch_size)
-  gsl::span<T> next_token_scores;               // shape (batch_size, vocab_size)
-  gsl::span<int32_t> next_tokens;               // shape (batch_size)
-  gsl::span<T> temp_topk_scores_buffer;         // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1 (GPU only)
-  gsl::span<int32_t> temp_topk_tokens_buffer;   // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1(GPU only)
+  gsl::span<int32_t> sequences_space;          // shape (2, batch_size, max_length)
+  gsl::span<int32_t> sequence_lengths;         // shape (batch_size)
+  gsl::span<int32_t> next_positions;           // shape (batch_size, num_beams). Next position value for position_ids.
+  gsl::span<bool> eos_meet;                    // shape (batch_size)
+  gsl::span<T> next_token_scores;              // shape (batch_size, vocab_size)
+  gsl::span<int32_t> next_tokens;              // shape (batch_size)
+  gsl::span<T> temp_topk_scores_buffer;        // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1 (GPU only)
+  gsl::span<int32_t> temp_topk_tokens_buffer;  // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1(GPU only)
   gsl::span<T> topk_scores_buffer;             // shape (batch_size), output buffer for topk stage 2 (GPU only)
   gsl::span<int32_t> topk_tokens_buffer;       // shape (batch_size), output buffer for topk stage 2 (GPU only)
};
@@ -167,6 +163,26 @@ struct IGenerationParameters {
bool custom_sampling = false;
};
+// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)
+#ifdef DEBUG_GENERATION
+#define DUMP_TENSOR_LEVEL 2
+#else
+#define DUMP_TENSOR_LEVEL 0  // change it to 1 or 2 if you want to enable dumping for code not in generation.
+#endif
+
+#if DUMP_TENSOR_LEVEL > 0
+#define DUMP_TENSOR_INIT() transformers::CudaTensorConsoleDumper dumper
+#define DUMP_TENSOR(...) dumper.Print(__VA_ARGS__)
+#else
+#define DUMP_TENSOR_INIT()
+#define DUMP_TENSOR(...)
+#endif
+#if DUMP_TENSOR_LEVEL > 1
+#define DUMP_TENSOR_D(...) dumper.Print(__VA_ARGS__)
+#else
+#define DUMP_TENSOR_D(...)
+#endif
+
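The two macros above differ only in the level that enables them: `DUMP_TENSOR` fires whenever `DUMP_TENSOR_LEVEL > 0`, while `DUMP_TENSOR_D` additionally requires level > 1 (intermediate tensors). Below is a standalone sketch of the same gating pattern, with `std::printf` standing in for the ORT-internal `CudaTensorConsoleDumper`; it is purely an editorial illustration, not code from this patch.

```cpp
// Minimal mimic of the level-gated dump macros (illustration only).
#include <cstdio>

#define DUMP_LEVEL 2  // 0: off, 1: final outputs, 2: also intermediate tensors

#if DUMP_LEVEL > 0
#define DUMP(...) std::printf(__VA_ARGS__)
#else
#define DUMP(...)
#endif

#if DUMP_LEVEL > 1
#define DUMP_D(...) std::printf(__VA_ARGS__)
#else
#define DUMP_D(...)
#endif

int main() {
  DUMP("final output is dumped at level >= 1\n");
  DUMP_D("intermediate tensor is dumped only at level >= 2\n");
  return 0;
}
```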
class IConsoleDumper {
public:
IConsoleDumper() : is_enabled_(true) {}
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
index b7eebb9d48785..e86736726c224 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -366,6 +366,39 @@ __global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* outp
}
}
+template <typename T>
+__global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output) {
+ // Format 4 to unpack TRT packed input format for memory efficient attention.
+ // Input: BxSxNxMxH
+ // Output: MxBxSxNxH
+ // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
+ int n = threadIdx.y;
+ int s = blockIdx.x;
+ int b = blockIdx.y;
+ int m = blockIdx.z; // matrix id
+
+ const int head_size = blockDim.x;
+ const int num_heads = blockDim.y;
+
+ const int sequence_length = gridDim.x;
+ const int batch_size = gridDim.y;
+ const int H = head_size;
+ const int NH = num_heads * head_size;
+ const int NHS = NH * sequence_length;
+
+ int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M;
+ const int out_offset = n * head_size + s * NH + b * NHS + m * NHS * batch_size;
+
+ const int h = threadIdx.x;
+ if (h < head_size) {
+ if (biases != nullptr) {
+ output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
+ } else {
+ output[out_offset + h] = input[in_offset + h];
+ }
+ }
+}
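The offset arithmetic above can be cross-checked against a plain CPU unpack of BxSxNxMxH into MxBxSxNxH. This is an editorial sketch only; the name `AddBiasUnpackRef` is made up and nothing here is part of the patch.

```cpp
// CPU check of the format-4 unpack: input BxSxNxMxH -> output MxBxSxNxH,
// adding the per-(m, n, h) bias when one is provided.
#include <cstddef>
#include <vector>

void AddBiasUnpackRef(int B, int S, int N, int M, int H,
                      const std::vector<float>& input,
                      const std::vector<float>* biases,  // size M*N*H, or nullptr
                      std::vector<float>& output) {
  output.resize(static_cast<size_t>(M) * B * S * N * H);
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int n = 0; n < N; ++n)
        for (int m = 0; m < M; ++m)
          for (int h = 0; h < H; ++h) {
            // Same element as in_offset / out_offset in the kernel above.
            size_t in = (((static_cast<size_t>(b) * S + s) * N + n) * M + m) * H + h;
            size_t out = (((static_cast<size_t>(m) * B + b) * S + s) * N + n) * H + h;
            float bias = (biases != nullptr)
                             ? (*biases)[(static_cast<size_t>(m) * N + n) * H + h]
                             : 0.0f;
            output[out] = input[in] + bias;
          }
}
```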
+
 template <typename T>
__global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) {
// Format 3 for cutlass memory efficient attention
@@ -481,7 +514,6 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const
}
}
-
 template <typename T>
void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
@@ -506,7 +538,9 @@ void InvokeAddBiasTranspose(
ORT_ENFORCE(total_matrix_count == 3);
AddBiasTransposeCutlass<<>>(input, biases, output, v_head_size);
}
- } else { // format == 0
+ } else if (format == 4) { // format == 4
+ AddBiasUnpack<<>>(total_matrix_count, input, biases, output);
+ } else { // format == 0
AddBiasTranspose<<>>(input, biases, output);
}
} else {
@@ -528,6 +562,8 @@ void InvokeAddBiasTranspose(
} else {
ORT_THROW("AddBiasTranspose (format 3) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size");
}
+ } else if (format == 4) { // format == 4
+ ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block");
} else { // format 0
AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output);
}
@@ -551,7 +587,7 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2,
qkv_add_bias2, H_v, total_matrix_count);
- } else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) {
+ } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
const int H = qk_head_size / 2;
const int H_v = v_head_size / 2;
     const half2* input2 = reinterpret_cast<const half2*>(input);
@@ -610,7 +646,6 @@ void InvokeAddBiasTransposeTrt(
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const T* biases, const T* query, const T* key, const T* value, T* output,
bool is_cross_attention, int kv_sequence_length) {
-
if (!is_cross_attention) {
ORT_ENFORCE(sequence_length == kv_sequence_length);
constexpr int num_matrices = 3;
@@ -696,52 +731,51 @@ void LaunchAddBiasTransposeTrt(
}
}
-
 template <typename T>
void InvokeAddBias(
cudaStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int kv_sequence_length,
const int num_heads, const int head_size, const int v_head_size,
const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v) {
- constexpr int num_matrices = 1;
- // Q
- {
- const dim3 grid(sequence_length, batch_size, num_matrices);
- if (head_size * num_heads <= max_threads_per_block) {
- const dim3 block(head_size, num_heads, 1);
- AddBiasTransposeTrt<<>>(query, biases, q);
- } else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
- AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q);
- }
+ constexpr int num_matrices = 1;
+ // Q
+ {
+ const dim3 grid(sequence_length, batch_size, num_matrices);
+ if (head_size * num_heads <= max_threads_per_block) {
+ const dim3 block(head_size, num_heads, 1);
+ AddBiasTransposeTrt<<>>(query, biases, q);
+ } else {
+ const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q);
}
- // K
- {
- const dim3 grid(kv_sequence_length, batch_size, num_matrices);
- const T* biases_k = biases + num_heads * head_size;
+ }
+ // K
+ {
+ const dim3 grid(kv_sequence_length, batch_size, num_matrices);
+ const T* biases_k = biases + num_heads * head_size;
- if (head_size * num_heads <= max_threads_per_block) {
- const dim3 block(head_size, num_heads, 1);
- AddBiasTransposeTrt<<>>(key, biases_k, k);
- } else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
- AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k);
- }
+ if (head_size * num_heads <= max_threads_per_block) {
+ const dim3 block(head_size, num_heads, 1);
+ AddBiasTransposeTrt<<>>(key, biases_k, k);
+ } else {
+ const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k);
}
+ }
- // V
- {
- const dim3 grid(kv_sequence_length, batch_size, num_matrices);
+ // V
+ {
+ const dim3 grid(kv_sequence_length, batch_size, num_matrices);
- const T* biases_v = biases + 2 * num_heads * head_size;
- if (v_head_size * num_heads <= max_threads_per_block) {
- const dim3 block(v_head_size, num_heads, 1);
- AddBiasTransposeTrt<<>>(value, biases_v, v);
- } else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
- AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v);
- }
+ const T* biases_v = biases + 2 * num_heads * head_size;
+ if (v_head_size * num_heads <= max_threads_per_block) {
+ const dim3 block(v_head_size, num_heads, 1);
+ AddBiasTransposeTrt<<>>(value, biases_v, v);
+ } else {
+ const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v);
}
+ }
}
template <>
@@ -750,7 +784,7 @@ void LaunchAddBias(
const int batch_size, const int sequence_length, const int kv_sequence_length,
const int num_heads, const int head_size, const int v_head_size,
const float* biases, const float* query, const float* key, const float* value, float* q, float* k, float* v) {
-if (0 == (head_size % 4) && 0 == (v_head_size % 4)) {
+ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) {
const int H = head_size / 4;
const int H_v = v_head_size / 4;
     const float4* query2 = reinterpret_cast<const float4*>(query);
@@ -761,8 +795,8 @@ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) {
     float4* k2 = reinterpret_cast<float4*>(k);
     float4* v2 = reinterpret_cast<float4*>(v);
InvokeAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v,
- biases2, query2, key2, value2, q2, k2, v2);
+ batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v,
+ biases2, query2, key2, value2, q2, k2, v2);
} else if (0 == (head_size & 1) && 0 == (v_head_size & 1)) {
const int H = head_size / 2;
const int H_v = v_head_size / 2;
@@ -774,14 +808,13 @@ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) {
     float2* k2 = reinterpret_cast<float2*>(k);
     float2* v2 = reinterpret_cast<float2*>(v);
InvokeAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v,
- biases2, query2, key2, value2, q2, k2, v2);
+ batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v,
+ biases2, query2, key2, value2, q2, k2, v2);
} else {
InvokeAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length, num_heads, head_size, v_head_size,
- biases, query, key, value, q, k, v);
+ batch_size, sequence_length, kv_sequence_length, num_heads, head_size, v_head_size,
+ biases, query, key, value, q, k, v);
}
-
}
template <>
@@ -790,8 +823,7 @@ void LaunchAddBias(
const int batch_size, const int sequence_length, const int kv_sequence_length,
const int num_heads, const int head_size, const int v_head_size,
const half* biases, const half* query, const half* key, const half* value, half* q, half* k, half* v) {
-
- if (0 == (head_size % 4) && 0 == (v_head_size % 4)) {
+ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) {
const int H = head_size / 4;
const int H_v = v_head_size / 4;
     const Half4* query2 = reinterpret_cast<const Half4*>(query);
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
index 8cc36637054e7..a2c3265284a4d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
@@ -24,6 +24,10 @@ namespace cuda {
// format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3)
// input: (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (num_matrices, batch_size, sequence_length, num_heads, head_size)
+// format 4: (requires qk_head_size = v_head_size)
+// input: (batch_size, sequence_length, num_heads, num_matrices, head_size)
+// output: (num_matrices, batch_size, sequence_length, num_heads, head_size)
+
 template <typename T>
void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 187f1bb37edc5..8c7ef9f919519 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -48,21 +48,6 @@ using namespace cub;
#define CHECK_CUDA(expr) CUDA_RETURN_IF_ERROR(expr)
#define CUDA_MEMORY_ALIGNMENT 256
-#define DUMP_ATTENTION_LEVEL 0
-#if DUMP_ATTENTION_LEVEL > 1
-#define DUMP_ATTENTION_INIT() transformers::CudaTensorConsoleDumper dumper
-#define DUMP_ATTENTION(...) dumper.Print(__VA_ARGS__)
-#define DUMP_ATTENTION_D(...) dumper.Print(__VA_ARGS__)
-#elif DUMP_ATTENTION_LEVEL > 0
-#define DUMP_ATTENTION_INIT() transformers::CudaTensorConsoleDumper dumper
-#define DUMP_ATTENTION(...) dumper.Print(__VA_ARGS__)
-#define DUMP_ATTENTION_D(...)
-#else
-#define DUMP_ATTENTION_INIT()
-#define DUMP_ATTENTION(...)
-#define DUMP_ATTENTION_D(...)
-#endif
-
namespace onnxruntime {
namespace contrib {
namespace cuda {
@@ -283,7 +268,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
// Default format for memory efficient attention.
   // When there is past state, the format shall be BxNxSxH, so we disable memory efficient attention when there is past.
- DUMP_ATTENTION_INIT();
+ DUMP_TENSOR_INIT();
if (nullptr != data.gemm_buffer) {
if (data.bias == nullptr) {
// For quantized attention, bias has been added so only need transpose here.
@@ -317,15 +302,42 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
data.gemm_buffer, data.bias, qkv,
true, v_head_size, qkv_add_bias, 3);
}
- } else { // gemm_buffer == nullptr
+ } else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv
+ // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
+ // CheckInputs verified this constraint.
+ assert(data.bias == nullptr);
+ assert(qk_head_size == v_head_size);
+
+ DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size);
+
+ if (use_memory_efficient_attention) {
+ // unpack kv to BSNH. Note that there is no bias so we need not output query to q.
+ constexpr int format = 4;
+ T* qkv_add_bias = nullptr;
+ const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
+ LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, kv_bias, k,
+ true, v_head_size, qkv_add_bias, 2);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ } else {
+ if (data.fused_cross_attention_kernel == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.");
+ }
+
+ qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+ }
+ } else { // gemm_buffer == nullptr and not packed kv
assert(data.query != nullptr && data.key != nullptr && data.value != nullptr && data.bias != nullptr);
- DUMP_ATTENTION_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size);
- DUMP_ATTENTION_D("query_bias", data.bias, num_heads, qk_head_size);
- DUMP_ATTENTION_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size);
- DUMP_ATTENTION_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
- DUMP_ATTENTION_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size);
- DUMP_ATTENTION_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
+ DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size);
+ DUMP_TENSOR_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
+ DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
@@ -347,9 +359,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, q, k, v);
- DUMP_ATTENTION_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
- DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
- DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
#endif
@@ -362,7 +374,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
- DUMP_ATTENTION_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
+ DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
qkv_format = AttentionQkvFormat::QKV_BSN3H;
} else { // unfused kernel
@@ -387,9 +399,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
data.value, data.bias + 2 * num_heads * qk_head_size, v,
true, -1);
- DUMP_ATTENTION_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size);
- DUMP_ATTENTION_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size);
- DUMP_ATTENTION_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size);
+ DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size);
+ DUMP_TENSOR_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
}
@@ -419,8 +431,7 @@ Status QkvToContext(
void* fused_runner = data.fused_runner;
// At most one fused kernel is enabled.
- assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) +
- int(data.fused_cross_attention_kernel != nullptr) <= 1);
+ assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
const int batches = batch_size * num_heads;
const int size_per_batch_q = sequence_length * qk_head_size;
@@ -481,7 +492,7 @@ Status QkvToContext(
ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent(
stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length,
batch_size, qk_head_size, num_heads, max_threads_per_block,
- use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer
+ use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer
data.gemm_buffer, data.present));
present_size_per_batch_k = parameters.max_sequence_length * qk_head_size;
@@ -491,7 +502,7 @@ Status QkvToContext(
}
// Q, K and V are ready now
- DUMP_ATTENTION_INIT();
+ DUMP_TENSOR_INIT();
if (data.fused_cross_attention_kernel != nullptr) {
assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H);
@@ -499,7 +510,7 @@ Status QkvToContext(
LaunchTrtSequenceOffset(q_sequence_offset, nullptr, batch_size, sequence_length, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
- DUMP_ATTENTION_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1);
+ DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1);
// We only enable fused cross attention when there is no key padding mask.
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
@@ -509,26 +520,34 @@ Status QkvToContext(
LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, kv_sequence_length, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
- DUMP_ATTENTION_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1);
+ DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1);
FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel =
         reinterpret_cast<FusedMultiHeadCrossAttentionKernel const*>(data.fused_cross_attention_kernel);
+ // When there is no bias, we can directly use q and packed kv from inputs. TODO: not need qkv in workspace.
+ void const* query = q;
+ void const* packed_kv = k;
+ if (data.value == nullptr && data.bias == nullptr) {
+ query = data.query;
+ packed_kv = data.key;
+ }
+
run_fused_cross_attention(
- q, // Q
- k, // packed KV
- q_sequence_offset, // cumulated sequence length of Q
- kv_sequence_offset, // cumulated sequence length of KV
- data.output, // output
- cross_attention_kernel, // kernels
- batch_size, // batch size
- num_heads, // number of heads
- qk_head_size, // head size of Q/K/V
- sequence_length, // sequence length of Q
- kv_sequence_length, // sequence length of KV
+ query, // Q
+ packed_kv, // packed KV
+ q_sequence_offset, // cumulated sequence length of Q
+ kv_sequence_offset, // cumulated sequence length of KV
+ data.output, // output
+ cross_attention_kernel, // kernels
+ batch_size, // batch size
+ num_heads, // number of heads
+ qk_head_size, // head size of Q/K/V
+ sequence_length, // sequence length of Q
+ kv_sequence_length, // sequence length of KV
stream);
- DUMP_ATTENTION("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size);
return Status::OK();
}
@@ -554,11 +573,11 @@ Status QkvToContext(
if (use_fused_kernel) {
assert(qkv_format == AttentionQkvFormat::QKV_BSN3H);
fused_fp16_runner->run(qkv, sequence_offset, data.output, stream);
- DUMP_ATTENTION("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
} else {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
- DUMP_ATTENTION("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size);
}
return Status::OK();
}
@@ -570,6 +589,13 @@ Status QkvToContext(
assert(data.mask_index == nullptr);
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ const void* query = q;
+ const void* key = k;
+ const void* value = v;
+ if (data.gemm_buffer == nullptr && data.value == nullptr) { // packed KV
+ query = data.query;
+ }
+
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
p.is_half = sizeof(T) == 2;
@@ -582,15 +608,15 @@ Status QkvToContext(
p.causal = parameters.is_unidirectional;
p.cu_seqlens_q = nullptr;
p.cu_seqlens_k = nullptr;
- p.query = q;
- p.key = k;
- p.value = v;
+ p.query = query;
+ p.key = key;
+ p.value = value;
p.output = data.output;
p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr;
p.stream = stream;
run_memory_efficient_attention(p);
- DUMP_ATTENTION("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size);
return Status::OK();
}
#endif
@@ -610,7 +636,7 @@ Status QkvToContext(
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size))
- : parameters.scale;
+ : parameters.scale;
float alpha = use_raw_attention_mask ? one : scale;
cublasSetStream(cublas, stream);
@@ -622,7 +648,7 @@ Status QkvToContext(
q, qk_head_size, sequence_length * qk_head_size,
&zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
- DUMP_ATTENTION_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length);
+ DUMP_TENSOR_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length);
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
sequence_length, total_sequence_length);
@@ -656,7 +682,7 @@ Status QkvToContext(
scratch1, scratch2, parameters.is_unidirectional));
}
- DUMP_ATTENTION_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length);
+ DUMP_TENSOR_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length);
// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = qkv;
@@ -670,7 +696,7 @@ Status QkvToContext(
// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, temp_output, data.output);
- DUMP_ATTENTION("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
return result;
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index c7e5d34e1691b..93e5e59ed00ae 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -94,6 +94,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
bool use_fused_cross_attention = !disable_fused_cross_attention_ &&
nullptr == key_padding_mask &&
+ (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV
parameters.hidden_size == parameters.v_hidden_size &&
has_fused_cross_attention_kernel(sm, parameters.head_size,
parameters.kv_sequence_length);
@@ -111,6 +112,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
bool use_fused_runner = !disable_fused_runner_ &&
fused_cross_attention_kernel == nullptr &&
+ value != nullptr && // fused runner requires packed qkv instead of packed kv
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
@@ -162,10 +164,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
   AttentionData<CudaT> data;
data.gemm_buffer = nullptr;
-  data.bias = reinterpret_cast<const CudaT*>(bias->Data<T>());
+  data.bias = (nullptr == bias) ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
   data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
   data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
-  data.value = reinterpret_cast<const CudaT*>(value->Data<T>());
+  data.value = (nullptr == value) ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
   data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data<int>();
   data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span<const int64_t>() : key_padding_mask->Shape().GetDims();
data.past = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 38bcbc298b939..a239e528af148 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -19,6 +19,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
@@ -71,6 +73,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler);
@@ -144,6 +149,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -192,6 +199,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
new file mode 100644
index 0000000000000..2b13cdbd803ef
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
@@ -0,0 +1,76 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/diffusion/bias_split_gelu.h"
+#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ BiasSplitGelu, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      BiasSplitGelu<T>);
+
+REGISTER_KERNEL_TYPED(MLFloat16);
+REGISTER_KERNEL_TYPED(float);
+
+using namespace ONNX_NAMESPACE;
+
+template <typename T>
+BiasSplitGelu<T>::BiasSplitGelu(const OpKernelInfo& op_info) : CudaKernel(op_info) {
+}
+
+template <typename T>
+Status BiasSplitGelu<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+
+ const auto& input_dims = input->Shape().GetDims();
+ if (input_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "input is expected to have 3 dimensions, got ", input_dims.size());
+ }
+
+ if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]);
+ }
+
+  const Tensor* bias = context->Input<Tensor>(1);
+ const auto& bias_dims = bias->Shape().GetDims();
+ if (bias_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "bias is expected to have 1 dimension, got ", bias_dims.size());
+ }
+ if (bias_dims[0] != input_dims[2]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "last dimension of input and bias are not the same");
+ }
+
+ TensorShapeVector output_shape = input->Shape().AsShapeVector();
+ output_shape[2] = input_dims[2] / 2;
+ Tensor* output = context->Output(0, output_shape);
+
+  typedef typename ToCudaType<T>::MappedType CudaT;
+  const int32_t grid_size = static_cast<int32_t>(input_dims[0] * input_dims[1]);
+  const int32_t half_hidden_size = static_cast<int32_t>(input_dims[2] / 2);
+  LaunchBiasSplitGeluKernel<CudaT>(Stream(context), grid_size, half_hidden_size,
+                                   reinterpret_cast<const CudaT*>(input->Data<T>()),
+                                   reinterpret_cast<const CudaT*>(bias->Data<T>()),
+                                   reinterpret_cast<CudaT*>(output->MutableData<T>()));
+
+ CUDA_RETURN_IF_ERROR(cudaPeekAtLastError());
+ return Status::OK();
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h
new file mode 100644
index 0000000000000..feec45600bbce
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class BiasSplitGelu final : public CudaKernel {
+ public:
+ BiasSplitGelu(const OpKernelInfo& op_kernel_info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
new file mode 100644
index 0000000000000..3cb95dad26b36
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// The CUDA kernel is modified from SplitGelu plugin of TensorRT 8.5.
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T, int32_t HHS, int32_t TPB>
+__global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) {
+ int32_t index_input = blockIdx.x * HHS * 2 + threadIdx.x;
+ int32_t index_output = blockIdx.x * HHS + threadIdx.x;
+ int32_t index_bias = threadIdx.x;
+
+#pragma unroll
+ for (int32_t i = 0; i < HHS / TPB; ++i) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
+ auto value_left = (float)(input[index_input] + bias[index_bias]);
+ auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]);
+#else
+ auto value_left = (float)(input[index_input]) + (float)(bias[index_bias]);
+ auto value_right = (float)(input[index_input + HHS]) + (float)(bias[index_bias + HHS]);
+#endif
+ // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0)
+ float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f);
+ float result = value_left * gelu_right;
+ output[index_output] = static_cast(result);
+ index_input += TPB;
+ index_output += TPB;
+ index_bias += TPB;
+ }
+ return;
+}
+
+template <typename T>
+void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
+ T const* input, T const* bias, T* output) {
+ constexpr int32_t TPB = 256; // thread per block
+ switch (half_hidden_size) {
+    case 1280:
+      (biasSplitGeluKernel<T, 1280, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
+      break;
+    case 2560:
+      (biasSplitGeluKernel<T, 2560, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
+      break;
+    case 5120:
+      (biasSplitGeluKernel<T, 5120, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
+      break;
+ default:
+ ORT_NOT_IMPLEMENTED("Not implemented");
+ }
+}
+
+template __global__ void biasSplitGeluKernel<float, 1280, 256>(float const*, float const*, float*);
+template __global__ void biasSplitGeluKernel<float, 2560, 256>(float const*, float const*, float*);
+template __global__ void biasSplitGeluKernel<float, 5120, 256>(float const*, float const*, float*);
+template __global__ void biasSplitGeluKernel<half, 1280, 256>(half const*, half const*, half*);
+template __global__ void biasSplitGeluKernel<half, 2560, 256>(half const*, half const*, half*);
+template __global__ void biasSplitGeluKernel<half, 5120, 256>(half const*, half const*, half*);
+
+template void LaunchBiasSplitGeluKernel<float>(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
+ float const* input, float const* bias, float* output);
+
+template void LaunchBiasSplitGeluKernel<half>(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
+ half const* input, half const* bias, half* output);
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
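The kernel above parallelizes a simple per-row computation: for every (batch, position) pair the hidden dimension D is split in half, the bias is added to both halves, and the left half is scaled by Gelu of the right half. The following host-side reference is a minimal sketch for a float input flattened to (N*S, D); the function name and the use of std::vector are illustrative assumptions, not part of this change.

#include <cmath>
#include <cstdint>
#include <vector>

// CPU reference of BiasSplitGelu: input is (rows, d), bias is (d),
// output is (rows, d/2) with output[r][i] = (left + bias) * Gelu(right + bias).
std::vector<float> BiasSplitGeluReference(const std::vector<float>& input,
                                          const std::vector<float>& bias,
                                          int64_t rows, int64_t d) {
  const int64_t hhs = d / 2;  // half hidden size, matching HHS in the kernel
  std::vector<float> output(static_cast<size_t>(rows * hhs));
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t i = 0; i < hhs; ++i) {
      float left = input[r * d + i] + bias[i];
      float right = input[r * d + hhs + i] + bias[hhs + i];
      float gelu_right = right * 0.5f * (std::erf(right / 1.41421356237f) + 1.0f);
      output[r * hhs + i] = left * gelu_right;
    }
  }
  return output;
}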
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h
new file mode 100644
index 0000000000000..a04201bd12e3c
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/common/status.h"
+#include <cuda_runtime.h>
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
+ T const* input, T const* bias, T* output);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
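Callers of this launcher are expected to pass grid_size = N * S (one thread block per (batch, position) row, as implied by the blockIdx.x indexing in the kernel) and half_hidden_size = D / 2. A hypothetical float invocation, with placeholder device pointers, might look like:

// Hypothetical launch for X of shape (N, S, D) = (2, 4096, 2560):
//   LaunchBiasSplitGeluKernel<float>(stream, /*grid_size*/ 2 * 4096, /*half_hidden_size*/ 1280,
//                                    x_gpu, bias_gpu, y_gpu);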
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
new file mode 100644
index 0000000000000..36a2bd11257d6
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -0,0 +1,129 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/diffusion/group_norm.h"
+#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define GROUP_NORM_TYPES float, MLFloat16
+
+ONNX_OPERATOR_KERNEL_EX(
+ GroupNorm, kMSDomain, 1, kCudaExecutionProvider,
+ (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
+
+using namespace ONNX_NAMESPACE;
+
+namespace {
+template <typename T>
+struct DispatchGroupNorm {
+ Status operator()(cudaStream_t stream,
+ Tensor* output,
+ const Tensor* input,
+ const Tensor* gamma,
+ const Tensor* beta,
+ void* workspace,
+ float epsilon,
+ int batch_size,
+ int num_channels,
+ int height,
+ int width,
+ int num_groups,
+ bool use_swish_activation) {
+ typedef typename ToCudaType<T>::MappedType CudaT;
+ return LaunchGroupNormKernel<CudaT>(
+ stream,
+ reinterpret_cast<CudaT*>(output->MutableData<T>()),
+ reinterpret_cast<const CudaT*>(input->Data<T>()),
+ gamma->Data<float>(),
+ beta->Data<float>(),
+ workspace,
+ epsilon,
+ batch_size,
+ num_channels,
+ height,
+ width,
+ num_groups,
+ use_swish_activation);
+ }
+};
+
+} // namespace
+
+GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
+ epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f);
+ ORT_ENFORCE(epsilon_ >= 0);
+
+ int64_t num_groups;
+ ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK());
+ ORT_ENFORCE(num_groups >= 0);
+ num_groups_ = static_cast<int>(num_groups);
+
+ int64_t activation;
+ ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK());
+ ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish
+ use_swish_activation_ = (activation == 1);
+}
+
+Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
+ const Tensor* input = context->Input<Tensor>(0);
+ const Tensor* gamma = context->Input<Tensor>(1);
+ const Tensor* beta = context->Input<Tensor>(2);
+ Tensor* output = context->Output(0, input->Shape());
+
+ const auto& input_dims = input->Shape().GetDims();
+ if (input_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "input is expected to have 4 dimensions, got ", input_dims.size());
+ }
+
+ const auto& gamma_dims = gamma->Shape().GetDims();
+ if (gamma_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "gamma is expected to have 1 dimension, got ", gamma_dims.size());
+ }
+ if (gamma_dims[0] != input_dims[3]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Number of channels in gamma and input does not match");
+ }
+
+ const auto& beta_dims = beta->Shape().GetDims();
+ if (beta_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "beta is expected to have 1 dimension, got ", beta_dims.size());
+ }
+ if (beta_dims[0] != input_dims[3]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Number of channels in beta and input does not match");
+ }
+
+ // Input and output format is NHWC
+ int batch_size = static_cast<int>(input_dims[0]);
+ int num_channels = static_cast<int>(input_dims[3]);
+ int height = static_cast<int>(input_dims[1]);
+ int width = static_cast<int>(input_dims[2]);
+
+ if (num_channels % num_groups_ != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "number of channels should be divisiable by num_groups");
+ }
+
+ auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());
+
+ utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
+ return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, input, gamma, beta, workspace.get(),
+ epsilon_,
+ batch_size,
+ num_channels,
+ height,
+ width,
+ num_groups_,
+ use_swish_activation_);
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
new file mode 100644
index 0000000000000..8578a1642198f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+class GroupNorm final : public CudaKernel {
+ public:
+ GroupNorm(const OpKernelInfo& op_kernel_info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ bool use_swish_activation_;
+ float epsilon_;
+ int num_groups_;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
new file mode 100644
index 0000000000000..01ba078b4be77
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -0,0 +1,475 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5
+#include <cub/cub.cuh>
+#include <cuda_fp16.h>
+#include <cmath>
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
+#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+static inline int32_t divUp(int32_t m, int32_t n) {
+ return (m + n - 1) / n;
+}
+
+static inline __device__ __host__ float sigmoid(float x) {
+ return 1.F / (1.F + expf(-x));
+}
+
+struct GroupSums {
+ // Is it the 1st element of the group?
+ int32_t flag;
+ // The sum.
+ float sum;
+ // The sum of squares.
+ float sumSq;
+};
+
+struct GroupSumsOp {
+ inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) {
+ GroupSums dst;
+ dst.sum = b.flag ? b.sum : (a.sum + b.sum);
+ dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
+ dst.flag = a.flag + b.flag;
+ return dst;
+ }
+};
+
+template <typename T>
+struct GroupNormNHWCParams {
+ // The output buffer. Layout NHWC.
+ T* dst;
+ // The input buffer. Layout NHWC.
+ T const* src;
+ // The gamma scaling factor.
+ float const* gamma;
+ // The beta term to add in GN.
+ float const* beta;
+ // The temporary buffer to do the global parallel reduction. Size:
+ // BLOCKS_PER_BATCH x C x 2.
+ float* redBuffer;
+
+ // The number of instances in the batch.
+ int32_t n;
+ // The height and width of each activation map.
+ int32_t h;
+ int32_t w;
+ // The number of channels.
+ int32_t c;
+ // The number of groups.
+ int32_t groups;
+ // Do we apply the Swish activation function?
+ bool withSwish;
+
+ // Precomputed values and parameters to control the execution of the kernels.
+
+ // The number of activations per instance (h * w) and the number of
+ // activations per block.
+ int32_t hw;
+ int32_t hwPerBlock;
+ // The number of channels per group and blocks per activation in the C
+ // dimension.
+ int32_t cPerBlock;
+ int32_t cPerGroup;
+
+ // The precomputed stride between instances.
+ int32_t hwc;
+ // The inverse of hwc in floats (to compute mean/var).
+ float invHWC;
+ // The precomputed number of groups per block.
+ int32_t groupsPerBlock;
+};
+
+template <typename T>
+inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq);
+
+template <>
+inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) {
+ // Fetch two channels per thread.
+ __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
+
+ float2 f2 = __half22float2(h2);
+
+ // Update the sum.
+ sum += f2.x + f2.y;
+
+ // Update the sum of squares.
+ sumSq += f2.x * f2.x + f2.y * f2.y;
+}
+
+template <>
+inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) {
+ // Fetch two channels per thread.
+ float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
+
+ // Update the sum.
+ sum += f2.x + f2.y;
+
+ // Update the sum of squares.
+ sumSq += f2.x * f2.x + f2.y * f2.y;
+}
+
+template <typename T, int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
+ // The object in charge of doing the sums for the different blocks.
+ typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
+
+ // Allocate shared memory for BlockScan.
+ __shared__ typename BlockScan::TempStorage tempStorage;
+ // Allocate shared memory for the groups. We could reduce the amount of shared
+ // memory reserved.
+ __shared__ float2 smem[tTHREADS_PER_BLOCK];
+
+ // The instance in the batch.
+ int32_t ni = blockIdx.z;
+ // The channel loaded by that thread (2 channels per thread for F16x2).
+ int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+
+ // The first activation loaded by that block.
+ int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+ // The last activation loaded by that block.
+ int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+ // The sums.
+ float sum = 0.F;
+ float sumSq = 0.F;
+
+ // Iterate over the activations to compute the sums.
+ if (ci < params.c) {
+ for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+ // The offset.
+ int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;
+ UpdateSum(params.src, offset, sum, sumSq);
+ }
+ }
+
+ // The group that thread works on and the channel in the group (modulus).
+ int32_t gi = threadIdx.x * 2 / params.cPerGroup;
+ int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
+
+ // The data for the summations.
+ GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
+
+ // Do the segmented scan.
+ GroupSums out;
+ BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
+
+ // Store the results for the groups in shared memory (to produce coalesced
+ // stores later).
+ if (cj == params.cPerGroup - 2) { //2 channels per thread
+ smem[gi] = make_float2(out.sum, out.sumSq);
+ }
+
+ // Make sure the data is in shared memory.
+ __syncthreads();
+
+ // The global group index.
+ int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
+
+ // Threads that have nothing left to do, exit.
+ if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
+ return;
+ }
+
+ // The first threads (those storing to global memory, load the values).
+ float2 sums = smem[threadIdx.x];
+
+ // Store to global memory.
+ atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
+ atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
+}
+
+template <typename T>
+void groupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
+ // Make sure the values are as we expect.
+ ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0);
+ // Make sure a group does not span multiple blocks.
+ ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
+
+ dim3 grid;
+
+ // The number of blocks to compute all the channels.
+ grid.x = params.c / params.cPerBlock;
+ // The number of blocks to compute all the activations in a given instance.
+ grid.y = divUp(params.hw, params.hwPerBlock);
+ // The number of instances.
+ grid.z = params.n;
+
+ switch (params.cPerBlock) {
+ case 320:
+ groupNormNHWCSumKernel<T, 160><<<grid, 160, 0, stream>>>(params);
+ break;
+ case 480:
+ groupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
+ break;
+ case 256:
+ groupNormNHWCSumKernel<T, 128><<<grid, 128, 0, stream>>>(params);
+ break;
+ case 128:
+ groupNormNHWCSumKernel<T, 64><<<grid, 64, 0, stream>>>(params);
+ break;
+ default:
+ ORT_NOT_IMPLEMENTED("Not implemented");
+ }
+}
+
+template <typename T>
+__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish);
+
+template <>
+__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev,
+ float2& gammaF2, float2& betaF2, bool swish) {
+ // Fetch two channels per thread.
+ __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
+
+ // Extract the two half values.
+ float2 f2 = __half22float2(h2);
+
+ // Normalize the channels.
+ f2.x = (f2.x - mean) * invStdDev;
+ f2.y = (f2.y - mean) * invStdDev;
+
+ // Scale by gamma and add beta.
+ f2.x = gammaF2.x * f2.x + betaF2.x;
+ f2.y = gammaF2.y * f2.y + betaF2.y;
+
+ // Apply Swish if needed.
+ if (swish) {
+ f2.x = f2.x * sigmoid(f2.x);
+ f2.y = f2.y * sigmoid(f2.y);
+ }
+
+ *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2);
+}
+
+template <>
+__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev,
+ float2& gammaF2, float2& betaF2, bool swish) {
+ // Fetch two channels per thread.
+ float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
+
+ // Normalize the channels.
+ f2.x = (f2.x - mean) * invStdDev;
+ f2.y = (f2.y - mean) * invStdDev;
+
+ // Scale by gamma and add beta.
+ f2.x = gammaF2.x * f2.x + betaF2.x;
+ f2.y = gammaF2.y * f2.y + betaF2.y;
+
+ // Apply Swish if needed.
+ if (swish) {
+ f2.x = f2.x * sigmoid(f2.x);
+ f2.y = f2.y * sigmoid(f2.y);
+ }
+
+ *reinterpret_cast<float2*>(&dst[offset]) = f2;
+}
+
+template <typename T, int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
+ // The channel loaded by that thread (2 channels per thread for F16x2).
+ int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+ if (ci >= params.c) {
+ return;
+ }
+
+ // The instance in the batch.
+ int32_t ni = blockIdx.z;
+
+ // The group that thread works on and the channel in the group (modulus).
+ int32_t gi = ci / params.cPerGroup;
+
+ // Load the sum and sum of squares for the group.
+ float sum = 0.F, sumSq = 0.F;
+ if (gi < params.groups) {
+ sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
+ sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
+ }
+
+ // Load gamma/beta.
+ float2 gammaF2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
+ float2 betaF2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
+
+ // Compute the mean.
+ float mean = sum * params.invHWC;
+ // Compute the variance.
+ float var = sumSq * params.invHWC - (mean * mean);
+ // Compute the inverse of the stddev.
+ float invStdDev = var <= 0.F ? 1.F : rsqrtf(var);
+
+ // The first activation loaded by that block.
+ int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+ // The last activation loaded by that block.
+ int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+ // Iterate over the activations to compute the sums.
+ for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+ // The src/dst offset.
+ int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
+
+ // Fetch two channels per thread.
+ computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish);
+ }
+}
+
+template <typename T>
+void groupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t stream) {
+ // Make sure the dimensions are aligned with what we expect.
+ ORT_ENFORCE(params.c % params.cPerBlock == 0);
+ // Make sure a group does not span multiple blocks.
+ ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
+
+ dim3 grid;
+
+ // The number of blocks to compute all the channels.
+ grid.x = params.c / params.cPerBlock;
+ // The number of blocks to compute all the activations in a given instance.
+ grid.y = divUp(params.hw, params.hwPerBlock);
+ // The number of instances.
+ grid.z = params.n;
+
+ switch (params.cPerBlock) {
+ case 320:
+ groupNormNHWCScaleKernel<T, 160><<<grid, 160, 0, stream>>>(params);
+ break;
+ case 480:
+ groupNormNHWCScaleKernel<T, 256><<<grid, 256, 0, stream>>>(params);
+ break;
+ case 256:
+ groupNormNHWCScaleKernel<T, 128><<<grid, 128, 0, stream>>>(params);
+ break;
+ case 128:
+ groupNormNHWCScaleKernel<T, 64><<<grid, 64, 0, stream>>>(params);
+ break;
+ default:
+ ORT_NOT_IMPLEMENTED("Not implemented");
+ }
+}
+
+int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
+ int32_t maxDivisor = -1;
+ for (int32_t i = 1; i <= std::sqrt(n); i++) {
+ if (n % i == 0) {
+ int32_t divisor1 = n / i;
+ int32_t divisor2 = i;
+
+ if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
+ maxDivisor = divisor1;
+ }
+ if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
+ maxDivisor = divisor2;
+ }
+ }
+ }
+ return maxDivisor;
+}
+
+template <typename T>
+Status LaunchGroupNormKernel(
+ cudaStream_t stream,
+ T* output,
+ const T* input,
+ const float* gamma,
+ const float* beta,
+ void* workspace,
+ float epsilon,
+ int batch_size,
+ int num_channels,
+ int height,
+ int width,
+ int num_groups,
+ bool use_swish_activation) {
+ if (batch_size > static_cast<int>(kMaxGroupNormBatchSize)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+ "only support batch_size <= 32. Got", batch_size);
+ }
+
+ if (num_groups != static_cast<int>(kGroupNormNumberOfGroups)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+ "only num_groups=32 is supported. Got", num_groups);
+ }
+
+ GroupNormNHWCParams<T> params;
+ int32_t cPerBlock = 320;
+ int32_t maxBlocksPerHW = 1024;
+ switch (num_channels) {
+ case 960:
+ case 1920:
+ cPerBlock = 480;
+ break;
+ case 512:
+ case 256:
+ cPerBlock = 256;
+ break;
+ case 128:
+ cPerBlock = 128;
+ break;
+ default:
+ cPerBlock = 320;
+ }
+
+ params.withSwish = use_swish_activation;
+ params.dst = output;
+ params.src = input;
+ params.gamma = gamma;
+ params.beta = beta;
+ params.redBuffer = reinterpret_cast<float*>(workspace);
+ params.n = batch_size;
+ params.h = height;
+ params.w = width;
+ params.c = num_channels;
+ params.groups = num_groups;
+ params.hw = params.h * params.w;
+ const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW);
+ params.hwPerBlock = divUp(params.hw, blocksPerHW);
+ params.cPerBlock = cPerBlock;
+ params.cPerGroup = params.c / params.groups;
+ params.hwc = params.hw * params.c;
+ params.invHWC = 1.F / (float)(params.hw * params.cPerGroup);
+ params.groupsPerBlock = cPerBlock / params.cPerGroup;
+
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR("input", input, batch_size, num_channels, height * width);
+ DUMP_TENSOR("gamma", gamma, 1, num_channels);
+ DUMP_TENSOR("beta", beta, 1, num_channels);
+ cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream);
+ groupNormNHWCSum(params, stream);
+ DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2);
+ CUDA_RETURN_IF_ERROR(cudaGetLastError());
+ groupNormNHWCScale(params, stream);
+ CUDA_RETURN_IF_ERROR(cudaGetLastError());
+ DUMP_TENSOR("output", output, batch_size, num_channels, height * width);
+ return Status::OK();
+}
+
+template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output,
+ const half* input, const float* gamma, const float* beta, void* workspace,
+ float epsilon, int batch_size, int num_channels,
+ int height, int width, int num_groups, bool swish);
+
+template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output,
+ const float* input, const float* gamma, const float* beta, void* workspace,
+ float epsilon, int batch_size, int num_channels,
+ int height, int width, int num_groups, bool swish);
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
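To make the launch configuration concrete, here is the parameter derivation for one representative NHWC activation shape; the shape is an example chosen for illustration, not a constraint beyond those already enforced above.

// Example: batch_size = 2, num_channels = 320, height = width = 64, num_groups = 32
//   cPerBlock      = 320                               (default branch of the num_channels switch)
//   cPerGroup      = 320 / 32 = 10
//   groupsPerBlock = 320 / 10 = 32
//   hw             = 64 * 64  = 4096
//   blocksPerHW    = findMaxDivisor(4096, 1024) = 512  (largest divisor of 4096 below 1024)
//   hwPerBlock     = divUp(4096, 512) = 8
//   grid           = {c / cPerBlock, divUp(hw, hwPerBlock), n} = {1, 512, 2}
//   Both groupNormNHWCSumKernel and groupNormNHWCScaleKernel launch with 160 threads (2 channels per thread).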
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
new file mode 100644
index 0000000000000..c7e9245050ee6
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/common/status.h"
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+constexpr size_t kMaxGroupNormBatchSize = 32;
+constexpr size_t kGroupNormNumberOfGroups = 32;
+
+constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
+ // Two buffers for sum and squared sum
+ return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
+}
+
+template <typename T>
+Status LaunchGroupNormKernel(
+ cudaStream_t stream,
+ T* output, // normalized output tensor
+ const T* input, // input tensor
+ const float* gamma, // gamma (also known as weight or scale)
+ const float* beta, // beta (also known as bias)
+ void* workspace, // Work space
+ float epsilon, // epsilon used in normalization
+ int batch_size, // N
+ int num_channels, // C
+ int height, // H
+ int width, // W
+ int num_groups, // number of groups
+ bool use_swish_activation // Whether there is Swish activation after group normalization
+);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
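The fixed workspace size above corresponds to two running statistics (sum and sum of squares) per (batch, group) pair, allocated for the supported maxima of 32 batches and 32 groups. A compile-time check such as the following, shown only to spell out the arithmetic, would hold:

// 2 floats x 32 batches x 32 groups = 8192 bytes
static_assert(GetGroupNormWorkspaceSizeInBytes() == 2 * sizeof(float) * 32 * 32,
              "GroupNorm workspace holds {sum, sum_sq} per (batch, group) pair");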
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc
new file mode 100644
index 0000000000000..79f0a18ba515f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/span_utils.h"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "core/providers/cuda/tensor/slice.h"
+#include "core/providers/cuda/nn/conv.h"
+
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ NhwcConv, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+ Conv<T, true>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc
index 39c3bb282d912..48881ddca4063 100644
--- a/onnxruntime/contrib_ops/cuda/fused_conv.cc
+++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc
@@ -9,10 +9,10 @@ namespace contrib {
namespace cuda {
template <typename T>
-class FusedConv : public onnxruntime::cuda::Conv<T> {
+class FusedConv : public onnxruntime::cuda::Conv<T, false> {
public:
- using Base = onnxruntime::cuda::Conv<T>;
- FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv<T>(info) {
+ using Base = onnxruntime::cuda::Conv<T, false>;
+ FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv<T, false>(info) {
std::string activation;
if (info.GetAttr("activation", &activation) == Status::OK() &&
MapMode(activation) == Status::OK() &&
diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
index 6c0f7f69c58a1..741f9ac259da1 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -11,7 +11,7 @@ namespace contrib {
namespace cuda {
namespace transformers {
-#ifdef DEBUG_GENERATION
+#if DUMP_TENSOR_LEVEL > 0
template
class PinnedHostBuffer {
public:
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index b4ad4d64e7ddb..68e3985651123 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -127,32 +127,41 @@ void RestorePaddingTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx)
void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
// Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
- // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
- // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
+ // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, kv_sequence_length, num_heads, 2, head_size)
+ // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or nullptr
// Output 0 has shape (batch_size, sequence_length, v_hidden_size)
// Type inference
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
- if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) {
+ if (hasInputShape(ctx, 0)) {
auto& query_shape = getInputShape(ctx, 0);
auto& query_dims = query_shape.dim();
if (query_dims.size() != 3) {
fail_shape_inference("Inputs 0 (query) shall be 3 dimensions");
}
- auto& value_shape = getInputShape(ctx, 2);
- auto& value_dims = value_shape.dim();
- if (value_dims.size() != 3) {
- fail_shape_inference("Inputs 2 (value) shall be 3 dimensions");
+ if (hasInputShape(ctx, 2)) {
+ auto& value_shape = getInputShape(ctx, 2);
+ auto& value_dims = value_shape.dim();
+ if (value_dims.size() != 3) {
+ fail_shape_inference("Inputs 2 (value) shall be 3 dimensions");
+ }
+
+ ONNX_NAMESPACE::TensorShapeProto output_shape;
+ *output_shape.add_dim() = query_dims[0];
+ *output_shape.add_dim() = query_dims[1];
+ *output_shape.add_dim() = value_dims[2];
+ updateOutputShape(ctx, 0, output_shape);
}
- ONNX_NAMESPACE::TensorShapeProto output_shape;
- *output_shape.add_dim() = query_dims[0];
- *output_shape.add_dim() = query_dims[1];
- *output_shape.add_dim() = value_dims[2];
- updateOutputShape(ctx, 0, output_shape);
+ if (hasInputShape(ctx, 1)) {
+ auto& key_shape = getInputShape(ctx, 1);
+ if (key_shape.dim().size() == 5) {
+ ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx);
+ }
+ }
}
}
@@ -287,16 +296,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"T")
.Input(1,
"key",
- "Key with shape (batch_size, kv_sequence_length, hidden_size)",
+ "Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)",
"T")
.Input(2,
"value",
"Value with shape (batch_size, kv_sequence_length, v_hidden_size)",
- "T")
+ "T",
+ OpSchema::Optional)
.Input(3,
"bias",
"Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection",
- "T")
+ "T",
+ OpSchema::Optional)
.Input(4,
"key_padding_mask",
"Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)",
diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
new file mode 100644
index 0000000000000..14a267357371d
--- /dev/null
+++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
@@ -0,0 +1,115 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/graph/constants.h"
+#include "core/graph/contrib_ops/contrib_defs.h"
+#include "core/graph/contrib_ops/onnx_function_util.h"
+#include "core/graph/contrib_ops/shape_inference_functions.h"
+
+// Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from
+// ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build
+#if defined(_WIN32) && !defined(NDEBUG)
+#pragma warning(disable : 26426)
+#endif
+
+namespace onnxruntime {
+namespace contrib {
+using ONNX_NAMESPACE::AttributeProto;
+using ONNX_NAMESPACE::OpSchema;
+using ONNX_NAMESPACE::TensorShapeProto;
+#ifndef NDEBUG
+using ONNX_NAMESPACE::DbgOperatorSetTracker;
+#endif
+
+constexpr const char* GroupNorm_ver1_doc = R"DOC(
+Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494).
+
+This operator transforms input according to
+ y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
+
+The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard deviation are calculated separately over each group.
+The weight and bias are per-channel affine transform parameter vectors of size num_channels.
+
+The activation attribute can be used to enable activation after group normalization.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ GroupNorm, 1,
+ OpSchema()
+ .SetDoc(GroupNorm_ver1_doc)
+ .Attr("epsilon", "The epsilon value to use to avoid division by zero", AttributeProto::FLOAT, static_cast(1e-5))
+ .Attr("groups",
+ "The number of groups of channels. It should be a divisor of the number of channels C",
+ AttributeProto::INT)
+ .Attr("activation",
+ "Activation after group normalization: 0 for None, 1 for Swish",
+ AttributeProto::INT)
+ .Input(0,
+ "X",
+ "Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data",
+ "T")
+ .Input(1,
+ "gamma",
+ "1D gamma tensor for normalization with shape (C), where C is number of channels",
+ "M")
+ .Input(2,
+ "beta",
+ "1D beta tensor for normalization with shape (C), where C is number of channels",
+ "M")
+ .Output(0,
+ "Y",
+ "The output tensor of the same shape as X",
+ "T")
+ .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.")
+ .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.")
+ .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
+
+constexpr const char* BiasSplitGelu_ver1_doc = R"DOC(
+A fusion used in diffusion models: after the bias is added, the hidden state is split into two tensors of equal size, and
+the left tensor is multiplied by the Gelu activation of the right tensor.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ BiasSplitGelu, 1,
+ OpSchema()
+ .SetDoc(BiasSplitGelu_ver1_doc)
+ .Input(0,
+ "X",
+ "Input tensor. Dimensions are (N, S, D), where N is the batch size, S are image size, and D is hidden dimension",
+ "T")
+ .Input(1,
+ "bias",
+ "Bias tensor. Dimensions are (D), where D is the same hidden dimension as input tensor",
+ "T")
+ .Output(0,
+ "Y",
+ "The output tensor with dimensions (N, S, D/2)",
+ "T")
+ .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) {
+ auto& input_shape = getInputShape(ctx, 0);
+ if (input_shape.dim().size() != 3) {
+ fail_shape_inference("input shall be 3 dimensions");
+ }
+
+ auto& bias_shape = getInputShape(ctx, 1);
+ if (bias_shape.dim().size() != 1) {
+ fail_shape_inference("bias shall be 1 dimension");
+ }
+
+ TensorShapeProto output_shape;
+ *output_shape.add_dim() = input_shape.dim(0);
+ *output_shape.add_dim() = input_shape.dim(1);
+ if (bias_shape.dim(0).has_dim_value()) {
+ output_shape.add_dim()->set_dim_value(bias_shape.dim(0).dim_value() / 2);
+ } else {
+ output_shape.add_dim();
+ }
+
+ updateOutputShape(ctx, 0, output_shape);
+ }
+ }));
+} // namespace contrib
+} // namespace onnxruntime
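As a concrete instance of the BiasSplitGelu shape-inference rule defined above (the numbers are illustrative, not taken from this change):

// X:    (N, S, D)   = (2, 4096, 2560)
// bias: (D)         = (2560)
// Y:    (N, S, D/2) = (2, 4096, 1280)   // dims 0 and 1 copied from X, last dim = bias length / 2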
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 1f0af31a4bdd0..a511d01fe1624 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -49,6 +49,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BeamSearch);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasDropout);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSoftmax);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BifurcationDetector);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CDist);
@@ -69,6 +70,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupNorm);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite);
@@ -135,6 +137,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasDropout)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu)>());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu)>());
fn(GetOpSchema