diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index d3b8f5ebfcc26..85246ec8bd37c 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -27,6 +27,15 @@ set(contrib_ops_excluded_files "bert/tensorrt_fused_multihead_attention/*" "bert/transformer_common.h" "bert/transformer_common.cc" + "diffusion/group_norm.h" + "diffusion/group_norm.cc" + "diffusion/group_norm_impl.cu" + "diffusion/group_norm_impl.h" + "diffusion/bias_split_gelu_impl.h" + "diffusion/bias_split_gelu_impl.cu" + "diffusion/bias_split_gelu.h" + "diffusion/bias_split_gelu.cc" + "diffusion/nhwc_conv.cc" "math/complex_mul.cc" "math/complex_mul.h" "math/complex_mul_impl.cu" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1e6d46963cd21..8cd6d4c9e26f1 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -9,6 +9,7 @@ Do not modify directly.* * com.microsoft.BiasDropout * com.microsoft.BiasGelu * com.microsoft.BiasSoftmax + * com.microsoft.BiasSplitGelu * com.microsoft.BifurcationDetector * com.microsoft.BitmaskBiasDropout * com.microsoft.BitmaskDropout @@ -34,6 +35,7 @@ Do not modify directly.* * com.microsoft.GemmFastGelu * com.microsoft.GreedySearch * com.microsoft.GridSample + * com.microsoft.GroupNorm * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention @@ -590,6 +592,39 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.BiasSplitGelu** + + A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left + tensor multiplies the Gelu activation result of right tensor. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Inputs + +
+
X : T
+
Input tensor. Dimensions are (N, S, D), where N is the batch size, S is the image size, and D is the hidden dimension
+
bias : T
+
Bias tensor. Dimensions are (D), where D is the same hidden dimension as the input tensor
+
+ +#### Outputs + +
+
Y : T
+
The output tensor with dimensions (N, S, D/2)
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X and output Y types to float tensors.
+
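
As a worked illustration of the fusion described above, here is a minimal CPU reference sketch in C++ (the helper name `BiasSplitGeluRef` and the tiny shapes in `main` are hypothetical and for illustration only; the actual implementation added by this PR is the CUDA kernel in `bias_split_gelu_impl.cu`):

```cpp
// Minimal CPU sketch of BiasSplitGelu semantics (hypothetical helper, not the CUDA kernel).
// x: (N, S, D) row-major, bias: (D); returns y: (N, S, D/2).
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> BiasSplitGeluRef(const std::vector<float>& x,
                                    const std::vector<float>& bias,
                                    int n, int s, int d) {
  const int half = d / 2;
  std::vector<float> y(static_cast<size_t>(n) * s * half);
  for (int row = 0; row < n * s; ++row) {
    const float* x_row = x.data() + static_cast<size_t>(row) * d;
    float* y_row = y.data() + static_cast<size_t>(row) * half;
    for (int i = 0; i < half; ++i) {
      const float left = x_row[i] + bias[i];                 // left half after bias add
      const float right = x_row[i + half] + bias[i + half];  // right half after bias add
      // Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1), applied to the right half only.
      const float gelu_right = right * 0.5f * (std::erf(right / 1.41421356237f) + 1.0f);
      y_row[i] = left * gelu_right;
    }
  }
  return y;
}

int main() {
  // Tiny example with N=1, S=1, D=4; the CUDA kernel itself targets D in {2560, 5120, 10240}.
  const std::vector<float> x = {1.0f, 2.0f, 3.0f, -1.0f};
  const std::vector<float> bias = {0.0f, 0.0f, 0.0f, 0.0f};
  const auto y = BiasSplitGeluRef(x, bias, 1, 1, 4);
  std::printf("%f %f\n", y[0], y[1]);
  return 0;
}
```

The restriction to hidden sizes 2560, 5120, and 10240 comes from the kernel dispatch in `bias_split_gelu_impl.cu`, which specializes on half-hidden sizes 1280, 2560, and 5120.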
+ + ### **com.microsoft.BifurcationDetector** Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens, @@ -1811,6 +1846,61 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GroupNorm** + + Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494). + + This operator transforms input according to + y = gamma * (x - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard-deviation are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for Swish
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is the number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is the number of channels
+
+ +#### Outputs + +
+
Y : T
+
The output tensor of the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X and output Y types to float tensors.
+
M : tensor(float)
+
Constrain gamma and beta to float tensors.
+
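
To make the normalization formula above concrete, the following is a small NHWC reference sketch in C++ (the helper name `GroupNormNhwcRef` is hypothetical; the implementation registered by this PR is the CUDA kernel in `group_norm_impl.cu`, adapted from the TensorRT 8.5 GroupNorm plugin):

```cpp
// CPU sketch of GroupNorm over NHWC input with optional Swish (hypothetical helper,
// not the registered CUDA kernel). x, y: (N, H, W, C); gamma, beta: (C).
#include <cmath>
#include <vector>

void GroupNormNhwcRef(const std::vector<float>& x, std::vector<float>& y,
                      const std::vector<float>& gamma, const std::vector<float>& beta,
                      int n, int h, int w, int c, int groups, float epsilon, bool swish) {
  const int channels_per_group = c / groups;  // C must be divisible by groups
  const int hw = h * w;
  y.resize(x.size());
  for (int b = 0; b < n; ++b) {
    for (int g = 0; g < groups; ++g) {
      // Mean and variance are computed over (H, W, channels_per_group) per (batch, group).
      double sum = 0.0;
      double sum_sq = 0.0;
      for (int p = 0; p < hw; ++p) {
        for (int ci = 0; ci < channels_per_group; ++ci) {
          const int channel = g * channels_per_group + ci;
          const double v = x[(static_cast<size_t>(b) * hw + p) * c + channel];
          sum += v;
          sum_sq += v * v;
        }
      }
      const double count = static_cast<double>(hw) * channels_per_group;
      const double mean = sum / count;
      const double variance = sum_sq / count - mean * mean;
      const float inv_std = static_cast<float>(1.0 / std::sqrt(variance + epsilon));
      for (int p = 0; p < hw; ++p) {
        for (int ci = 0; ci < channels_per_group; ++ci) {
          const int channel = g * channels_per_group + ci;
          const size_t idx = (static_cast<size_t>(b) * hw + p) * c + channel;
          float v = (x[idx] - static_cast<float>(mean)) * inv_std;
          v = v * gamma[channel] + beta[channel];  // per-channel affine transform
          if (swish) {
            v = v * (1.0f / (1.0f + std::exp(-v)));  // Swish: x * sigmoid(x)
          }
          y[idx] = v;
        }
      }
    }
  }
}
```

The CUDA kernel computes the same per-group statistics in parallel, using the `redBuffer` workspace (sized BLOCKS_PER_BATCH x C x 2) described near the end of this diff for the global reduction.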
+ + ### **com.microsoft.Inverse** #### Version @@ -2132,16 +2222,16 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (4 - 5)
+#### Inputs (2 - 5)
query : T
Query with shape (batch_size, sequence_length, hidden_size)
key : T
-
Key with shape (batch_size, kv_sequence_length, hidden_size)
-
value : T
+
Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
+
value (optional) : T
Value with shape (batch_size, kv_sequence_length, v_hidden_size)
-
bias : T
+
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)
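
To make the new packed-KV layout concrete: when `key` is supplied as (batch_size, kv_sequence_length, num_heads, 2, head_size) and `value`/`bias` are absent, K and V are simply the two slices along the fourth axis. The sketch below shows only the index arithmetic (illustrative C++; `UnpackKv` is a hypothetical helper, not ONNX Runtime code):

```cpp
// Illustrative unpacking of packed KV (B, L, N, 2, H) into separate K and V buffers
// laid out as BSNH; hypothetical helper showing the index arithmetic only.
#include <cstddef>
#include <vector>

void UnpackKv(const std::vector<float>& packed_kv,  // (B, L, N, 2, H)
              std::vector<float>& k,                // (B, L, N, H)
              std::vector<float>& v,                // (B, L, N, H)
              int batch, int kv_seq_len, int num_heads, int head_size) {
  const size_t bsnh = static_cast<size_t>(batch) * kv_seq_len * num_heads * head_size;
  k.resize(bsnh);
  v.resize(bsnh);
  for (int b = 0; b < batch; ++b) {
    for (int s = 0; s < kv_seq_len; ++s) {
      for (int n = 0; n < num_heads; ++n) {
        for (int h = 0; h < head_size; ++h) {
          // Packed offset of element (b, s, n, m, h) is ((((b*L + s)*N + n)*2 + m)*H + h.
          const size_t base =
              ((((static_cast<size_t>(b) * kv_seq_len) + s) * num_heads + n) * 2) * head_size + h;
          const size_t out =
              (((static_cast<size_t>(b) * kv_seq_len) + s) * num_heads + n) * head_size + h;
          k[out] = packed_kv[base];              // m == 0 slice is K
          v[out] = packed_kv[base + head_size];  // m == 1 slice is V
        }
      }
    }
  }
}
```

In the CUDA changes later in this diff, such unpacking (format 4 of `AddBiasTranspose`) is only performed for the memory-efficient attention path; the fused TensorRT cross-attention kernel consumes the packed KV buffer directly.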
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ad571dacb20d7..7e4eb38be780b 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -175,7 +175,8 @@ Do not modify directly.* |||[11, 12]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |LpNormalization|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float)| -|LpPool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float)| +|LpPool|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(float)| +|||[11, 17]|**T** = tensor(float)| |||[2, 10]|**T** = tensor(float)| |MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| @@ -789,6 +790,7 @@ Do not modify directly.* |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BitmaskBiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)| |BitmaskDropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)| |ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -804,11 +806,13 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| +|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| +|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* extra_add:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| @@ -1086,7 +1090,8 @@ Do not modify directly.* |Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index ce109a83720b9..8c3af05972c95 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -21,11 +21,15 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, int max_threads_per_block) { - // query (Q) : (B, S, D) - // key (K) : (B, L, D) - // value (V) : (B, L, D_v) - // bias (Q/K/V) : (D + D + D_v) - // key_padding_mask (K/V) : (B, L) or (L) + // query (Q) : (B, S, D) + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // bias (Q/K/V) : (D + D + D_v) + // key_padding_mask (K/V) : (B) or (B, L) or None + // When packed kv is used: + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None const auto& query_dims = query->Shape().GetDims(); if (query_dims.size() != 3) { @@ -34,15 +38,50 @@ Status CheckInputs(const T* query, } const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + if (key_dims.size() != 3 && key_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ", key_dims.size()); } + if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ", - bias_dims.size()); + int batch_size = static_cast(query_dims[0]); + int sequence_length = static_cast(query_dims[1]); + int hidden_size = static_cast(query_dims[2]); + int head_size = static_cast(hidden_size) / num_heads; + int kv_sequence_length = static_cast(key_dims[1]); + + if (key_dims.size() == 3) { + if (key_dims[2] != query_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + } + } else // if (key_dims.size() == 5) + { + if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); + } + } + + if (bias != nullptr) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ", + bias_dims.size()); + } + + // Currently, bias is not allowed for packed KV. This constraint can be removed later. + // Here we assume that fusion tool will not include bias for packed KV. 
+ if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. "); + } } AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; @@ -61,47 +100,39 @@ Status CheckInputs(const T* query, } } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - int64_t batch_size = query_dims[0]; - int64_t sequence_length = query_dims[1]; - int64_t kv_sequence_length = key_dims[1]; - int64_t q_hidden_size = query_dims[2]; - int64_t v_hidden_size = 0; - - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } + int v_hidden_size = hidden_size; + if (value != nullptr) { + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } - if (key_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have same same dim 1 (sequence_length)"); + if (key_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)"); + } + v_hidden_size = static_cast(value_dims[2]); } - v_hidden_size = value_dims[2]; if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); - output_parameters->batch_size = static_cast(batch_size); - output_parameters->sequence_length = static_cast(sequence_length); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; output_parameters->past_sequence_length = 0; - output_parameters->kv_sequence_length = static_cast(kv_sequence_length); - output_parameters->total_sequence_length = static_cast(kv_sequence_length); + output_parameters->kv_sequence_length = kv_sequence_length; + output_parameters->total_sequence_length = kv_sequence_length; output_parameters->max_sequence_length = 0; output_parameters->input_hidden_size = 0; - output_parameters->hidden_size = static_cast(q_hidden_size); - output_parameters->v_hidden_size = static_cast(v_hidden_size); - output_parameters->head_size = static_cast(q_hidden_size) / num_heads; - output_parameters->v_head_size = static_cast(v_hidden_size) / num_heads; + output_parameters->hidden_size = hidden_size; + output_parameters->v_hidden_size = v_hidden_size; + output_parameters->head_size = hidden_size / num_heads; + output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; output_parameters->is_unidirectional = false; output_parameters->past_present_share_buffer = false; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cf1d99688546a..630c533c47323 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ 
b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -9,10 +9,6 @@ #include "core/framework/allocator.h" #include "core/framework/ort_value.h" -#ifndef NDEBUG -//#define DEBUG_GENERATION 1 // uncomment it for debugging beam search -#endif - namespace onnxruntime { namespace concurrency { @@ -57,14 +53,14 @@ struct IBeamSearchCpuState { template struct IGreedySearchState { - gsl::span sequences_space; // shape (2, batch_size, max_length) - gsl::span sequence_lengths; // shape (batch_size) - gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. - gsl::span eos_meet; // shape (batch_size) - gsl::span next_token_scores; // shape (batch_size, vocab_size) - gsl::span next_tokens; // shape (batch_size) - gsl::span temp_topk_scores_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1 (GPU only) - gsl::span temp_topk_tokens_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1(GPU only) + gsl::span sequences_space; // shape (2, batch_size, max_length) + gsl::span sequence_lengths; // shape (batch_size) + gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. + gsl::span eos_meet; // shape (batch_size) + gsl::span next_token_scores; // shape (batch_size, vocab_size) + gsl::span next_tokens; // shape (batch_size) + gsl::span temp_topk_scores_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1 (GPU only) + gsl::span temp_topk_tokens_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1(GPU only) gsl::span topk_scores_buffer; // shape (batch_size), output buffer for topk stage 2 (GPU only) gsl::span topk_tokens_buffer; // shape (batch_size), output buffer for topk stage 2 (GPU only) }; @@ -167,6 +163,26 @@ struct IGenerationParameters { bool custom_sampling = false; }; +// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc) +#ifdef DEBUG_GENERATION +#define DUMP_TENSOR_LEVEL 2 +#else +#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation. +#endif + +#if DUMP_TENSOR_LEVEL > 0 +#define DUMP_TENSOR_INIT() transformers::CudaTensorConsoleDumper dumper +#define DUMP_TENSOR(...) dumper.Print(__VA_ARGS__) +#else +#define DUMP_TENSOR_INIT() +#define DUMP_TENSOR(...) +#endif +#if DUMP_TENSOR_LEVEL > 1 +#define DUMP_TENSOR_D(...) dumper.Print(__VA_ARGS__) +#else +#define DUMP_TENSOR_D(...) +#endif + class IConsoleDumper { public: IConsoleDumper() : is_enabled_(true) {} diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index b7eebb9d48785..e86736726c224 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -366,6 +366,39 @@ __global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* outp } } +template +__global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output) { + // Format 4 to unpack TRT packed input format for memory efficient attention. 
+ // Input: BxSxNxMxH + // Output: MxBxSxNxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const int NH = num_heads * head_size; + const int NHS = NH * sequence_length; + + int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M; + const int out_offset = n * head_size + s * NH + b * NHS + m * NHS * batch_size; + + const int h = threadIdx.x; + if (h < head_size) { + if (biases != nullptr) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + } else { + output[out_offset + h] = input[in_offset + h]; + } + } +} + template __global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) { // Format 3 for cutlass memory efficient attention @@ -481,7 +514,6 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const } } - template void InvokeAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, @@ -506,7 +538,9 @@ void InvokeAddBiasTranspose( ORT_ENFORCE(total_matrix_count == 3); AddBiasTransposeCutlass<<>>(input, biases, output, v_head_size); } - } else { // format == 0 + } else if (format == 4) { // format == 4 + AddBiasUnpack<<>>(total_matrix_count, input, biases, output); + } else { // format == 0 AddBiasTranspose<<>>(input, biases, output); } } else { @@ -528,6 +562,8 @@ void InvokeAddBiasTranspose( } else { ORT_THROW("AddBiasTranspose (format 3) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size"); } + } else if (format == 4) { // format == 4 + ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block"); } else { // format 0 AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } @@ -551,7 +587,7 @@ void LaunchAddBiasTranspose( InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, H, input2, biases2, output2, qkv_add_bias2, H_v, total_matrix_count); - } else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) { + } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) { const int H = qk_head_size / 2; const int H_v = v_head_size / 2; const half2* input2 = reinterpret_cast(input); @@ -610,7 +646,6 @@ void InvokeAddBiasTransposeTrt( const int batch_size, const int sequence_length, const int num_heads, const int head_size, const T* biases, const T* query, const T* key, const T* value, T* output, bool is_cross_attention, int kv_sequence_length) { - if (!is_cross_attention) { ORT_ENFORCE(sequence_length == kv_sequence_length); constexpr int num_matrices = 3; @@ -696,52 +731,51 @@ void LaunchAddBiasTransposeTrt( } } - template void InvokeAddBias( cudaStream_t stream, const int max_threads_per_block, const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v) { - constexpr int num_matrices = 1; - // Q - { - const dim3 grid(sequence_length, batch_size, num_matrices); - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(query, 
biases, q); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); - } + constexpr int num_matrices = 1; + // Q + { + const dim3 grid(sequence_length, batch_size, num_matrices); + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(query, biases, q); + } else { + const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); } - // K - { - const dim3 grid(kv_sequence_length, batch_size, num_matrices); - const T* biases_k = biases + num_heads * head_size; + } + // K + { + const dim3 grid(kv_sequence_length, batch_size, num_matrices); + const T* biases_k = biases + num_heads * head_size; - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(key, biases_k, k); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k); - } + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(key, biases_k, k); + } else { + const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k); } + } - // V - { - const dim3 grid(kv_sequence_length, batch_size, num_matrices); + // V + { + const dim3 grid(kv_sequence_length, batch_size, num_matrices); - const T* biases_v = biases + 2 * num_heads * head_size; - if (v_head_size * num_heads <= max_threads_per_block) { - const dim3 block(v_head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(value, biases_v, v); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v); - } + const T* biases_v = biases + 2 * num_heads * head_size; + if (v_head_size * num_heads <= max_threads_per_block) { + const dim3 block(v_head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(value, biases_v, v); + } else { + const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v); } + } } template <> @@ -750,7 +784,7 @@ void LaunchAddBias( const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const float* biases, const float* query, const float* key, const float* value, float* q, float* k, float* v) { -if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { + if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { const int H = head_size / 4; const int H_v = v_head_size / 4; const float4* query2 = reinterpret_cast(query); @@ -761,8 +795,8 @@ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { float4* k2 = reinterpret_cast(k); float4* v2 = reinterpret_cast(v); InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, - biases2, query2, key2, value2, q2, k2, v2); + batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, + biases2, query2, key2, value2, q2, k2, v2); } else if (0 == (head_size & 1) && 0 == (v_head_size & 1)) { const int H = head_size / 2; const int H_v = v_head_size / 2; @@ -774,14 +808,13 @@ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { float2* k2 = reinterpret_cast(k); float2* v2 = reinterpret_cast(v); 
InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, - biases2, query2, key2, value2, q2, k2, v2); + batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, + biases2, query2, key2, value2, q2, k2, v2); } else { InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, num_heads, head_size, v_head_size, - biases, query, key, value, q, k, v); + batch_size, sequence_length, kv_sequence_length, num_heads, head_size, v_head_size, + biases, query, key, value, q, k, v); } - } template <> @@ -790,8 +823,7 @@ void LaunchAddBias( const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const half* biases, const half* query, const half* key, const half* value, half* q, half* k, half* v) { - - if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { + if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { const int H = head_size / 4; const int H_v = v_head_size / 4; const Half4* query2 = reinterpret_cast(query); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index 8cc36637054e7..a2c3265284a4d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -24,6 +24,10 @@ namespace cuda { // format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (batch_size, sequence_length, num_matrices, num_heads, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) +// format 4: (requires qk_head_size = v_head_size) +// input: (batch_size, sequence_length, num_heads, num_matrices, head_size) +// output: (num_matrices, batch_size, sequence_length, num_heads, head_size) + template void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 187f1bb37edc5..8c7ef9f919519 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -48,21 +48,6 @@ using namespace cub; #define CHECK_CUDA(expr) CUDA_RETURN_IF_ERROR(expr) #define CUDA_MEMORY_ALIGNMENT 256 -#define DUMP_ATTENTION_LEVEL 0 -#if DUMP_ATTENTION_LEVEL > 1 -#define DUMP_ATTENTION_INIT() transformers::CudaTensorConsoleDumper dumper -#define DUMP_ATTENTION(...) dumper.Print(__VA_ARGS__) -#define DUMP_ATTENTION_D(...) dumper.Print(__VA_ARGS__) -#elif DUMP_ATTENTION_LEVEL > 0 -#define DUMP_ATTENTION_INIT() transformers::CudaTensorConsoleDumper dumper -#define DUMP_ATTENTION(...) dumper.Print(__VA_ARGS__) -#define DUMP_ATTENTION_D(...) -#else -#define DUMP_ATTENTION_INIT() -#define DUMP_ATTENTION(...) -#define DUMP_ATTENTION_D(...) -#endif - namespace onnxruntime { namespace contrib { namespace cuda { @@ -283,7 +268,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, // Default format for memory efficient attention. // When there is past state, the format shal be BxNxSxH, so we disable memory efficient attention when there is past. - DUMP_ATTENTION_INIT(); + DUMP_TENSOR_INIT(); if (nullptr != data.gemm_buffer) { if (data.bias == nullptr) { // For quantized attention, bias has been added so only need transpose here. 
@@ -317,15 +302,42 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, 3); } - } else { // gemm_buffer == nullptr + } else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv + // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. + // CheckInputs verified this constraint. + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + + if (use_memory_efficient_attention) { + // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, k, + true, v_head_size, qkv_add_bias, 2); + DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (data.fused_cross_attention_kernel == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } + } else { // gemm_buffer == nullptr and not packed kv assert(data.query != nullptr && data.key != nullptr && data.value != nullptr && data.bias != nullptr); - DUMP_ATTENTION_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_ATTENTION_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_ATTENTION_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); - DUMP_ATTENTION_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); if (data.fused_cross_attention_kernel != nullptr) { assert(qk_head_size == v_head_size); @@ -347,9 +359,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, num_heads, qk_head_size, v_head_size, data.bias, data.query, data.key, data.value, q, k, v); - DUMP_ATTENTION_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size 
* kv_sequence_length, num_heads, v_head_size); qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } #endif @@ -362,7 +374,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, batch_size, sequence_length, num_heads, qk_head_size, data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_ATTENTION_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); + DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); qkv_format = AttentionQkvFormat::QKV_BSN3H; } else { // unfused kernel @@ -387,9 +399,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.value, data.bias + 2 * num_heads * qk_head_size, v, true, -1); - DUMP_ATTENTION_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_ATTENTION_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size); - DUMP_ATTENTION_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size); + DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size); qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } } @@ -419,8 +431,7 @@ Status QkvToContext( void* fused_runner = data.fused_runner; // At most one fused kernel is enabled. - assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr) <= 1); + assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1); const int batches = batch_size * num_heads; const int size_per_batch_q = sequence_length * qk_head_size; @@ -481,7 +492,7 @@ Status QkvToContext( ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, - use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer + use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer data.gemm_buffer, data.present)); present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; @@ -491,7 +502,7 @@ Status QkvToContext( } // Q, K and V are ready now - DUMP_ATTENTION_INIT(); + DUMP_TENSOR_INIT(); if (data.fused_cross_attention_kernel != nullptr) { assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); @@ -499,7 +510,7 @@ Status QkvToContext( LaunchTrtSequenceOffset(q_sequence_offset, nullptr, batch_size, sequence_length, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_ATTENTION_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. 
@@ -509,26 +520,34 @@ Status QkvToContext( LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, kv_sequence_length, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_ATTENTION_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); + DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); + // When there is no bias, we can directly use q and packed kv from inputs. TODO: not need qkv in workspace. + void const* query = q; + void const* packed_kv = k; + if (data.value == nullptr && data.bias == nullptr) { + query = data.query; + packed_kv = data.key; + } + run_fused_cross_attention( - q, // Q - k, // packed KV - q_sequence_offset, // cumulated sequence length of Q - kv_sequence_offset, // cumulated sequence length of KV - data.output, // output - cross_attention_kernel, // kernels - batch_size, // batch size - num_heads, // number of heads - qk_head_size, // head size of Q/K/V - sequence_length, // sequence length of Q - kv_sequence_length, // sequence length of KV + query, // Q + packed_kv, // packed KV + q_sequence_offset, // cumulated sequence length of Q + kv_sequence_offset, // cumulated sequence length of KV + data.output, // output + cross_attention_kernel, // kernels + batch_size, // batch size + num_heads, // number of heads + qk_head_size, // head size of Q/K/V + sequence_length, // sequence length of Q + kv_sequence_length, // sequence length of KV stream); - DUMP_ATTENTION("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); return Status::OK(); } @@ -554,11 +573,11 @@ Status QkvToContext( if (use_fused_kernel) { assert(qkv_format == AttentionQkvFormat::QKV_BSN3H); fused_fp16_runner->run(qkv, sequence_offset, data.output, stream); - DUMP_ATTENTION("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); } else { assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_ATTENTION("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size); } return Status::OK(); } @@ -570,6 +589,13 @@ Status QkvToContext( assert(data.mask_index == nullptr); assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + const void* query = q; + const void* key = k; + const void* value = v; + if (data.gemm_buffer == nullptr && data.value == nullptr) { // packed KV + query = data.query; + } + MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; @@ -582,15 +608,15 @@ Status QkvToContext( p.causal = parameters.is_unidirectional; p.cu_seqlens_q = nullptr; p.cu_seqlens_k = nullptr; - p.query = q; - p.key = k; - p.value = v; + p.query = query; + p.key = key; + p.value = value; p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? 
scratch1 : nullptr; p.stream = stream; run_memory_efficient_attention(p); - DUMP_ATTENTION("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size); return Status::OK(); } #endif @@ -610,7 +636,7 @@ Status QkvToContext( // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; + : parameters.scale; float alpha = use_raw_attention_mask ? one : scale; cublasSetStream(cublas, stream); @@ -622,7 +648,7 @@ Status QkvToContext( q, qk_head_size, sequence_length * qk_head_size, &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_ATTENTION_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); @@ -656,7 +682,7 @@ Status QkvToContext( scratch1, scratch2, parameters.is_unidirectional)); } - DUMP_ATTENTION_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v T* temp_output = qkv; @@ -670,7 +696,7 @@ Status QkvToContext( // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, max_threads_per_block, false, temp_output, data.output); - DUMP_ATTENTION("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size); return result; } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index c7e5d34e1691b..93e5e59ed00ae 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -94,6 +94,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && + (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV parameters.hidden_size == parameters.v_hidden_size && has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length); @@ -111,6 +112,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_runner = !disable_fused_runner_ && fused_cross_attention_kernel == nullptr && + value != nullptr && // fused runner requires packed qkv instead of packed kv (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && @@ -162,10 +164,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = nullptr; - data.bias = reinterpret_cast(bias->Data()); + data.bias = (nullptr == bias) ? 
nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 38bcbc298b939..a239e528af148 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -19,6 +19,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); @@ -71,6 +73,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); @@ -144,6 +149,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -192,6 +199,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc new file mode 100644 index 0000000000000..2b13cdbd803ef --- /dev/null +++ 
b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/bias_split_gelu.h" +#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BiasSplitGelu, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + BiasSplitGelu); + +REGISTER_KERNEL_TYPED(MLFloat16); +REGISTER_KERNEL_TYPED(float); + +using namespace ONNX_NAMESPACE; + +template +BiasSplitGelu::BiasSplitGelu(const OpKernelInfo& op_info) : CudaKernel(op_info) { +} + +template +Status BiasSplitGelu::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 dimensions, got ", input_dims.size()); + } + + if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]); + } + + const Tensor* bias = context->Input(1); + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimensions, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of input and bias are not the same"); + } + + TensorShapeVector output_shape = input->Shape().AsShapeVector(); + output_shape[2] = input_dims[2] / 2; + Tensor* output = context->Output(0, output_shape); + + typedef typename ToCudaType::MappedType CudaT; + const int32_t grid_size = static_cast(input_dims[0] * input_dims[1]); + const int32_t half_hidden_size = static_cast(input_dims[2] / 2); + LaunchBiasSplitGeluKernel(Stream(context), grid_size, half_hidden_size, + reinterpret_cast(input->Data()), + reinterpret_cast(bias->Data()), + reinterpret_cast(output->MutableData())); + + CUDA_RETURN_IF_ERROR(cudaPeekAtLastError()); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h new file mode 100644 index 0000000000000..feec45600bbce --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class BiasSplitGelu final : public CudaKernel { + public: + BiasSplitGelu(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu new file mode 100644 index 0000000000000..3cb95dad26b36 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// The CUDA kernel is modified from SplitGelu plugin of TensorRT 8.5. +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) { + int32_t index_input = blockIdx.x * HHS * 2 + threadIdx.x; + int32_t index_output = blockIdx.x * HHS + threadIdx.x; + int32_t index_bias = threadIdx.x; + +#pragma unroll + for (int32_t i = 0; i < HHS / TPB; ++i) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + auto value_left = (float)(input[index_input] + bias[index_bias]); + auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]); +#else + auto value_left = (float)(input[index_input]) + (float)(bias[index_bias]); + auto value_right = (float)(input[index_input + HHS]) + (float)(bias[index_bias + HHS]); +#endif + // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0) + float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f); + float result = value_left * gelu_right; + output[index_output] = static_cast(result); + index_input += TPB; + index_output += TPB; + index_bias += TPB; + } + return; +} + +template +void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + T const* input, T const* bias, T* output) { + constexpr int32_t TPB = 256; // thread per block + switch (half_hidden_size) { + case 1280: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 2560: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 5120: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float 
const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); + +template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + float const* input, float const* bias, float* output); + +template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + half const* input, half const* bias, half* output); +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h new file mode 100644 index 0000000000000..a04201bd12e3c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/common/status.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + T const* input, T const* bias, T* output); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc new file mode 100644 index 0000000000000..36a2bd11257d6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/group_norm.h" +#include "contrib_ops/cuda/diffusion/group_norm_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define GROUP_NORM_TYPES float, MLFloat16 + +ONNX_OPERATOR_KERNEL_EX( + GroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + +using namespace ONNX_NAMESPACE; + +namespace { +template +struct DispatchGroupNorm { + Status operator()(cudaStream_t stream, + Tensor* output, + const Tensor* input, + const Tensor* gamma, + const Tensor* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_swish_activation) { + typedef typename ToCudaType::MappedType CudaT; + return LaunchGroupNormKernel( + stream, + reinterpret_cast(output->MutableData()), + reinterpret_cast(input->Data()), + gamma->Data(), + beta->Data(), + workspace, + epsilon, + batch_size, + num_channels, + height, + width, + num_groups, + use_swish_activation); + } +}; + +} // namespace + +GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); + ORT_ENFORCE(epsilon_ >= 0); + + int64_t num_groups; + ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); + ORT_ENFORCE(num_groups >= 0); + num_groups_ = static_cast(num_groups); + + int64_t activation; + ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); + ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish + use_swish_activation_ = (activation == 1); +} + +Status GroupNorm::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* gamma = context->Input(1); + const Tensor* beta = context->Input(2); + Tensor* output = context->Output(0, input->Shape()); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 4 dimensions, got ", input_dims.size()); + } + + const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != input_dims[3]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in gamma and input does not match"); + } + + const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to have 1 dimension, got ", beta_dims.size()); + } + if (beta_dims[0] != input_dims[3]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in beta and input does not match"); + } + + // Input and output format is NHWC + int batch_size = static_cast(input_dims[0]); + int num_channels = static_cast(input_dims[3]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + + utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); + return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + 
epsilon_, + batch_size, + num_channels, + height, + width, + num_groups_, + use_swish_activation_); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h new file mode 100644 index 0000000000000..8578a1642198f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +class GroupNorm final : public CudaKernel { + public: + GroupNorm(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + bool use_swish_activation_; + float epsilon_; + int num_groups_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu new file mode 100644 index 0000000000000..01ba078b4be77 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -0,0 +1,475 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/group_norm_impl.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +static inline int32_t divUp(int32_t m, int32_t n) { + return (m + n - 1) / n; +} + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. + float sum; + // The sum of squares. + float sumSq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +template +struct GroupNormNHWCParams { + // The output buffer. Layout NHWC. + T* dst; + // The input buffer. Layout NHWC. + T const* src; + // The gamma scaling factor. + float const* gamma; + // The beta term to add in GN. + float const* beta; + // The temporary buffer to do the global parallel reduction. Size: + // BLOCKS_PER_BATCH x C x 2. + float* redBuffer; + + // The number of instances in the batch. + int32_t n; + // The height and width of each activation map. + int32_t h; + int32_t w; + // The number of channels. 
+ int32_t c; + // The number of groups. + int32_t groups; + // Do we apply the Swish activation function? + bool withSwish; + + // Precomputed values and parameters to control the execution of the kernels. + + // The number of activations per instance (h * w) and the number of + // activations per block. + int32_t hw; + int32_t hwPerBlock; + // The number of channels per group and blocks per activation in the C + // dimension. + int32_t cPerBlock; + int32_t cPerGroup; + + // The precomputed stride between instances. + int32_t hwc; + // The inverse of hwc in floats (to compute mean/var). + float invHWC; + // The precomputed number of groups per block. + int32_t groupsPerBlock; +}; + +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); + +template <> +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + float2 f2 = __half22float2(h2); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sumSq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sumSq += f2.x * f2.x + f2.y * f2.y; +} + +template +__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { + // The object in charge of doing the sums for the different blocks. + typedef cub::BlockScan BlockScan; + + // Allocate shared memory for BlockScan. + __shared__ typename BlockScan::TempStorage tempStorage; + // Allocate shared memory for the groups. We could reduce the amount of shared + // memory reserved. + __shared__ float2 smem[tTHREADS_PER_BLOCK]; + + // The instance in the batch. + int32_t ni = blockIdx.z; + // The channel loaded by that thread (2 channels per thread for F16x2). + int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The first activation loaded by that block. + int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + + // The sums. + float sum = 0.F; + float sumSq = 0.F; + + // Iterate over the activations to compute the sums. + if (ci < params.c) { + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The offset. + int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; + UpdateSum(params.src, offset, sum, sumSq); + } + } + + // The group that thread works on and the channel in the group (modulus). + int32_t gi = threadIdx.x * 2 / params.cPerGroup; + int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + + // The data for the summations. + GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + + // Do the segmented scan. + GroupSums out; + BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + + // Store the results for the groups in shared memory (to produce coalesced + // stores later). + if (cj == params.cPerGroup - 2) { //2 channels per thread + smem[gi] = make_float2(out.sum, out.sumSq); + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The global group index. + int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; + + // Threads that have nothing left to do, exit. 
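+  // Only the first groupsPerBlock threads of this block still hold a per-group partial (sum, sumSq)
+  // in shared memory; they add it into the global redBuffer below so that partial sums from all
+  // blocks covering this instance's H*W are combined per group.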
+ if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + return; + } + + // The first threads (those storing to global memory, load the values). + float2 sums = smem[threadIdx.x]; + + // Store to global memory. + atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); + atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); +} + +template +void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { + // Make sure the values are as we expect. + ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); + // Make sure a group does not span multiple blocks. + ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); + + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = params.c / params.cPerBlock; + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. + grid.z = params.n; + + switch (params.cPerBlock) { + case 320: + groupNormNHWCSumKernel<<>>(params); + break; + case 480: + groupNormNHWCSumKernel<<>>(params); + break; + case 256: + groupNormNHWCSumKernel<<>>(params); + break; + case 128: + groupNormNHWCSumKernel<<>>(params); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template +__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); + +template <> +__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, + float2& gammaF2, float2& betaF2, bool swish) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Swish if needed. + if (swish) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); +} + +template <> +__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, + float2& gammaF2, float2& betaF2, bool swish) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Swish if needed. + if (swish) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast(&dst[offset]) = f2; +} + +template +__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread (2 channels per thread for F16x2). + int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + if (ci >= params.c) { + return; + } + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The group that thread works on and the channel in the group (modulus). + int32_t gi = ci / params.cPerGroup; + + // Load the sum and sum of squares for the group. + float sum = 0.F, sumSq = 0.F; + if (gi < params.groups) { + sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; + sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + } + + // Load gamma/beta. 
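+  // ci addresses two consecutive channels per thread, so gamma and beta are fetched as float2 pairs.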
+ float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + + // Compute the mean. + float mean = sum * params.invHWC; + // Compute the variance. + float var = sumSq * params.invHWC - (mean * mean); + // Compute the inverse of the stddev. + float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + + // The first activation loaded by that block. + int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + + // Iterate over the activations to compute the sums. + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The src/dst offset. + int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; + + // Fetch two channels per thread. + computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + } +} + +template +void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { + // Make sure the dimensions are aligned with what we expect. + ORT_ENFORCE(params.c % params.cPerBlock == 0); + // Make sure a group does not span multiple blocks. + ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); + + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = params.c / params.cPerBlock; + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. + grid.z = params.n; + + switch (params.cPerBlock) { + case 320: + groupNormNHWCScaleKernel<<>>(params); + break; + case 480: + groupNormNHWCScaleKernel<<>>(params); + break; + case 256: + groupNormNHWCScaleKernel<<>>(params); + break; + case 128: + groupNormNHWCScaleKernel<<>>(params); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { + int32_t maxDivisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { + maxDivisor = divisor1; + } + if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { + maxDivisor = divisor2; + } + } + } + return maxDivisor; +} + +template +Status LaunchGroupNormKernel( + cudaStream_t stream, + T* output, + const T* input, + const float* gamma, + const float* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_swish_activation) { + if (batch_size > static_cast(kMaxGroupNormBatchSize)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "only support batch_size <= 32. Got", batch_size); + } + + if (num_groups != static_cast(kGroupNormNumberOfGroups)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "only num_groups=32 is supported. 
Got", num_groups); + } + + GroupNormNHWCParams params; + int32_t cPerBlock = 320; + int32_t maxBlocksPerHW = 1024; + switch (num_channels) { + case 960: + case 1920: + cPerBlock = 480; + break; + case 512: + case 256: + cPerBlock = 256; + break; + case 128: + cPerBlock = 128; + break; + default: + cPerBlock = 320; + } + + params.withSwish = use_swish_activation; + params.dst = output; + params.src = input; + params.gamma = gamma; + params.beta = beta; + params.redBuffer = reinterpret_cast(workspace); + params.n = batch_size; + params.h = height; + params.w = width; + params.c = num_channels; + params.groups = num_groups; + params.hw = params.h * params.w; + const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); + params.hwPerBlock = divUp(params.hw, blocksPerHW); + params.cPerBlock = cPerBlock; + params.cPerGroup = params.c / params.groups; + params.hwc = params.hw * params.c; + params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); + params.groupsPerBlock = cPerBlock / params.cPerGroup; + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("input", input, batch_size, num_channels, height * width); + DUMP_TENSOR("gamma", gamma, 1, num_channels); + DUMP_TENSOR("beta", beta, 1, num_channels); + cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); + groupNormNHWCSum(params, stream); + DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + groupNormNHWCScale(params, stream); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); +} + +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, + const half* input, const float* gamma, const float* beta, void* workspace, + float epsilon, int batch_size, int num_channels, + int height, int width, int num_groups, bool swish); + +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, + const float* input, const float* gamma, const float* beta, void* workspace, + float epsilon, int batch_size, int num_channels, + int height, int width, int num_groups, bool swish); +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h new file mode 100644 index 0000000000000..c7e9245050ee6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
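The kernels above implement GroupNorm on NHWC input in two passes: a sum pass that accumulates per-group sums and squared sums with a block scan plus atomicAdd, and a scale pass that normalizes, applies the per-channel affine transform, and optionally applies Swish. A minimal NumPy sketch of the same computation, with illustrative names and the epsilon from the operator definition:

```python
import numpy as np

def group_norm_nhwc_reference(x, gamma, beta, num_groups=32, epsilon=1e-5, swish=False):
    """Reference for com.microsoft.GroupNorm on NHWC input (illustrative).

    x: (N, H, W, C), gamma/beta: (C,); C must be divisible by num_groups.
    """
    n, h, w, c = x.shape
    assert c % num_groups == 0
    g = x.reshape(n, h * w, num_groups, c // num_groups)
    mean = g.mean(axis=(1, 3), keepdims=True)          # per (instance, group)
    var = g.var(axis=(1, 3), keepdims=True)            # biased variance, as in the kernel
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)
    y = y * gamma + beta                               # per-channel affine transform
    if swish:
        y = y * (1.0 / (1.0 + np.exp(-y)))             # Swish(x) = x * sigmoid(x)
    return y

y = group_norm_nhwc_reference(np.random.randn(1, 32, 32, 320).astype(np.float32),
                              np.ones(320, np.float32), np.zeros(320, np.float32), swish=True)
assert y.shape == (1, 32, 32, 320)
```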
+ +#pragma once +#include "core/common/common.h" +#include "core/common/status.h" +#include +#include +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +constexpr size_t kMaxGroupNormBatchSize = 32; +constexpr size_t kGroupNormNumberOfGroups = 32; + +constexpr size_t GetGroupNormWorkspaceSizeInBytes() { + // Two buffers for sum and squared sum + return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; +} + +template +Status LaunchGroupNormKernel( + cudaStream_t stream, + T* output, // normalized output tensor + const T* input, // input tensor + const float* gamma, // gamma (also known as weight or scale) + const float* beta, // beta (also known as bias) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_swish_activation // Whether there is Swish activation after group normalization +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc new file mode 100644 index 0000000000000..79f0a18ba515f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/span_utils.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/nn/conv.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + NhwcConv, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Conv); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc index 39c3bb282d912..48881ddca4063 100644 --- a/onnxruntime/contrib_ops/cuda/fused_conv.cc +++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc @@ -9,10 +9,10 @@ namespace contrib { namespace cuda { template -class FusedConv : public onnxruntime::cuda::Conv { +class FusedConv : public onnxruntime::cuda::Conv { public: - using Base = onnxruntime::cuda::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv(info) { + using Base = onnxruntime::cuda::Conv; + FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv(info) { std::string activation; if (info.GetAttr("activation", &activation) == Status::OK() && MapMode(activation) == Status::OK() && diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc index 6c0f7f69c58a1..741f9ac259da1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc @@ -11,7 +11,7 @@ namespace contrib { namespace cuda { namespace transformers { -#ifdef DEBUG_GENERATION +#if DUMP_TENSOR_LEVEL > 0 template class PinnedHostBuffer { public: diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 
b4ad4d64e7ddb..68e3985651123 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -127,32 +127,41 @@ void RestorePaddingTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) - // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, kv_sequence_length, num_heads, 2, head_size) + // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or nullptr // Output 0 has shape (batch_size, sequence_length, v_hidden_size) // Type inference ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); // Shape inference - if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + if (hasInputShape(ctx, 0)) { auto& query_shape = getInputShape(ctx, 0); auto& query_dims = query_shape.dim(); if (query_dims.size() != 3) { fail_shape_inference("Inputs 0 (query) shall be 3 dimensions"); } - auto& value_shape = getInputShape(ctx, 2); - auto& value_dims = value_shape.dim(); - if (value_dims.size() != 3) { - fail_shape_inference("Inputs 2 (value) shall be 3 dimensions"); + if (hasInputShape(ctx, 2)) { + auto& value_shape = getInputShape(ctx, 2); + auto& value_dims = value_shape.dim(); + if (value_dims.size() != 3) { + fail_shape_inference("Inputs 2 (value) shall be 3 dimensions"); + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + *output_shape.add_dim() = value_dims[2]; + updateOutputShape(ctx, 0, output_shape); } - ONNX_NAMESPACE::TensorShapeProto output_shape; - *output_shape.add_dim() = query_dims[0]; - *output_shape.add_dim() = query_dims[1]; - *output_shape.add_dim() = value_dims[2]; - updateOutputShape(ctx, 0, output_shape); + if (hasInputShape(ctx, 1)) { + auto& key_shape = getInputShape(ctx, 1); + if (key_shape.dim().size() == 5) { + ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx); + } + } } } @@ -287,16 +296,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Input(1, "key", - "Key with shape (batch_size, kv_sequence_length, hidden_size)", + "Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)", "T") .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, v_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "bias", "Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection", - "T") + "T", + OpSchema::Optional) .Input(4, "key_padding_mask", "Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)", diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc new file mode 100644 index 0000000000000..14a267357371d --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
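The packed-KV form of MultiHeadAttention introduced above passes key as a (batch_size, kv_sequence_length, num_heads, 2, head_size) tensor and leaves value empty. A small NumPy sketch of how separate K and V projections map to that layout, assuming the usual per-head split of the hidden dimension; the helper name is illustrative:

```python
import numpy as np

def pack_kv(key, value, num_heads):
    """Pack separate K/V of shape (B, S_kv, hidden) into (B, S_kv, N, 2, H) (illustrative)."""
    b, s, hidden = key.shape
    assert value.shape == key.shape and hidden % num_heads == 0
    head_size = hidden // num_heads
    k = key.reshape(b, s, num_heads, 1, head_size)
    v = value.reshape(b, s, num_heads, 1, head_size)
    return np.concatenate([k, v], axis=3)   # index 0 along axis 3 is K, index 1 is V

packed = pack_kv(np.random.randn(2, 77, 768), np.random.randn(2, 77, 768), num_heads=12)
assert packed.shape == (2, 77, 12, 2, 64)

# Unpacking recovers the original projections.
k2 = packed[..., 0, :].reshape(2, 77, 768)
v2 = packed[..., 1, :].reshape(2, 77, 768)
```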
+ +#include "core/graph/constants.h" +#include "core/graph/contrib_ops/contrib_defs.h" +#include "core/graph/contrib_ops/onnx_function_util.h" +#include "core/graph/contrib_ops/shape_inference_functions.h" + +// Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from +// ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build +#if defined(_WIN32) && !defined(NDEBUG) +#pragma warning(disable : 26426) +#endif + +namespace onnxruntime { +namespace contrib { +using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::OpSchema; +using ONNX_NAMESPACE::TensorShapeProto; +#ifndef NDEBUG +using ONNX_NAMESPACE::DbgOperatorSetTracker; +#endif + +constexpr const char* GroupNorm_ver1_doc = R"DOC( +Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494). + +This operator transforms input according to + y = gamma * (x - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard-deviation are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + GroupNorm, 1, + OpSchema() + .SetDoc(GroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for Swish", + AttributeProto::INT) + .Input(0, + "X", + "Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.") + .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + +constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( +A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left +tensor multiplies the Gelu activation result of right tensor. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + BiasSplitGelu, 1, + OpSchema() + .SetDoc(BiasSplitGelu_ver1_doc) + .Input(0, + "X", + "Input tensor. Dimensions are (N, S, D), where N is the batch size, S are image size, and D is hidden dimension", + "T") + .Input(1, + "bias", + "Bias tensor. 
Dimensions are (D), where D is the same hidden dimension as input tensor", + "T") + .Output(0, + "Y", + "The output tensor with dimensions (N, S, D/2)", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { + auto& input_shape = getInputShape(ctx, 0); + if (input_shape.dim().size() != 3) { + fail_shape_inference("input shall be 3 dimensions"); + } + + auto& bias_shape = getInputShape(ctx, 1); + if (bias_shape.dim().size() != 1) { + fail_shape_inference("bias shall be 1 dimension"); + } + + TensorShapeProto output_shape; + *output_shape.add_dim() = input_shape.dim(0); + *output_shape.add_dim() = input_shape.dim(1); + if (bias_shape.dim(0).has_dim_value()) { + output_shape.add_dim()->set_dim_value(bias_shape.dim(0).dim_value() / 2); + } else { + output_shape.add_dim(); + } + + updateOutputShape(ctx, 0, output_shape); + } + })); +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 1f0af31a4bdd0..a511d01fe1624 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -49,6 +49,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BeamSearch); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasDropout); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSoftmax); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BifurcationDetector); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CDist); @@ -69,6 +70,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupNorm); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite); @@ -135,6 +137,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -155,6 +158,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/providers/cpu/nn/conv_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_attributes.h index 51a1e7acafe11..b31030acc52c1 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_attributes.h @@ -73,7 +73,7 @@ struct ConvAttributes { ~ConvAttributes() = default; - Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape) const { + Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape, bool weight_channels_last = false) const { if (kernel_shape_specified) { kernel_shape = 
kernel_shape_; if (kernel_shape.size() + 2 != weight_shape.NumDimensions()) { @@ -82,15 +82,20 @@ struct ConvAttributes { " W: ", weight_shape.ToString().c_str()); } for (size_t i = 0; i < kernel_shape.size(); ++i) { - if (kernel_shape[i] != weight_shape[i + 2]) { + if (kernel_shape[i] != weight_shape[i + (weight_channels_last ? 1 : 2)]) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.", " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", weight_shape.ToString().c_str()); + " W: ", weight_shape.ToString().c_str(), + " channels_last: ", weight_channels_last); } } } else { auto weight_dims = weight_shape.GetDims(); - kernel_shape.assign(weight_dims.begin() + 2, weight_dims.end()); + if (weight_channels_last) { + kernel_shape.assign(weight_dims.begin() + 1, weight_dims.end() - 1); + } else { + kernel_shape.assign(weight_dims.begin() + 2, weight_dims.end()); + } } return Status::OK(); @@ -98,7 +103,8 @@ struct ConvAttributes { Status ValidateInputShape(const TensorShape& input_shape, const TensorShape& weight_shape, - bool channels_last = false) const { + bool input_channels_last = false, + bool weight_channels_last = false) const { if (input_shape.NumDimensions() != weight_shape.NumDimensions()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "X num_dims does not match W num_dims.", " X: ", input_shape.ToString().c_str(), @@ -106,9 +112,9 @@ struct ConvAttributes { } const int64_t M = weight_shape[0]; - const int64_t C = channels_last ? input_shape.GetDims().back() : input_shape[1]; + const int64_t C = input_channels_last ? input_shape.GetDims().back() : input_shape[1]; - if (C != weight_shape[1] * group) { + if (C != (weight_channels_last ? weight_shape.GetDims().back() : weight_shape[1]) * group) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input channels C is not equal to kernel channels * group.", " C: ", C, " kernel channels: ", weight_shape[1], diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index d62a651880a85..4c9cbbe605a7a 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -42,6 +42,12 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat return Status::OK(); } +Status CudnnTensor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w) { + ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); + CUDNN_RETURN_IF_ERROR(cudnnSetTensor4dDescriptor(tensor_, format, dataType, n, c, h, w)); + return Status::OK(); +} + Status CudnnTensor::Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode) { ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); CUDNN_RETURN_IF_ERROR(cudnnDeriveBNTensorDescriptor(tensor_, x_desc, mode)); @@ -113,15 +119,23 @@ Status CudnnFilterDescriptor::Set(gsl::span filter_dims, cudnnDat return Status::OK(); } +Status CudnnFilterDescriptor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int k, int c, int h, int w) { + if (!desc_) + CUDNN_RETURN_IF_ERROR(cudnnCreateFilterDescriptor(&desc_)); + + CUDNN_RETURN_IF_ERROR(cudnnSetFilter4dDescriptor(desc_, dataType, format, k, c, h, w)); + return Status::OK(); +} + template cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN engine currently supports only single/double/half/int8/uint8 precision data types. 
Got:", - typeid(ElemType).name()); + typeid(ElemType).name()); // Not reachable but GCC complains return CUDNN_DATA_FLOAT; } -template<> +template <> cudnnDataType_t CudnnTensor::GetDataType() { return CUDNN_DATA_FLOAT; } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index f104373b9413a..ba75ab4f2c029 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -18,6 +18,8 @@ class CudnnTensor final { Status Set(gsl::span input_dims, cudnnDataType_t dataType); Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode); + // Set 4D tensor format (for NHWC) + Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w); operator cudnnTensorDescriptor_t() const { return tensor_; } @@ -58,6 +60,9 @@ class CudnnFilterDescriptor final { Status Set(gsl::span filter_dims, cudnnDataType_t data_typ); + // Set 4D filter where k is output channels, c is input channels, h and w is rows and columns per filter. + Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int k, int c, int h, int w); + operator cudnnFilterDescriptor_t() const { return desc_; } private: diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index f1590bc51388d..b0df77db96744 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -20,7 +20,7 @@ namespace cuda { T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ + Conv); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Conv, \ kOnnxDomain, \ @@ -28,14 +28,14 @@ namespace cuda { T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); + Conv); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) REGISTER_KERNEL_TYPED(MLFloat16) -template -const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { +template +const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { CUDNN_CONVOLUTION_FWD_ALGO_GEMM, CUDNN_CONVOLUTION_FWD_ALGO_FFT, CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, @@ -52,7 +52,7 @@ cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, const CudnnConvState& s, const cudnnConvolutionFwdAlgo_t* algo, int n_algo) { - // TODO: get maximum available size from memory areana + // TODO: get maximum available size from memory arena size_t free, total; CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); // Assuming 10% of fragmentation @@ -86,8 +86,8 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream, return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); } -template -Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { +template +Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { //set X const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); @@ -99,6 +99,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const const TensorShape& w_shape = W->Shape(); auto w_dims = w_shape.AsShapeVector(); s_.w_data = reinterpret_cast(W->Data()); + + // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. 
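+  // (cudnnSetTensor4dDescriptor / cudnnSetFilter4dDescriptor, used by the NHWC Set overloads added
+  // earlier in this change, can only describe 4-D tensors.)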
+ constexpr bool channels_last = NHWC; + if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + } + // set B if (context->InputCount() >= 3) { const Tensor* B = context->Input(2); @@ -125,48 +132,60 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const s_.cached_benchmark_results.clear(); } - const int64_t N = X->Shape()[0]; - const int64_t M = W->Shape()[0]; - - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); - auto rank = kernel_shape.size(); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, channels_last)); + + const size_t kernel_rank = kernel_shape.size(); + ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { - pads.resize(rank * 2, 0); + pads.resize(kernel_rank * 2, 0); } TensorShapeVector dilations(conv_attrs_.dilations); if (dilations.empty()) { - dilations.resize(rank, 1); + dilations.resize(kernel_rank, 1); } TensorShapeVector strides(conv_attrs_.strides); if (strides.empty()) { - strides.resize(rank, 1); + strides.resize(kernel_rank, 1); } TensorShapeVector y_dims; - y_dims.reserve(2 + rank); // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C' - y_dims.insert(y_dims.begin(), {N, M}); + y_dims.reserve(2 + kernel_rank); // add 2 to account for 'N' and 'C' - TensorShapeVector y_dims_with_adjusted_pads; - y_dims_with_adjusted_pads.reserve(2 + rank); // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C' - y_dims_with_adjusted_pads.insert(y_dims_with_adjusted_pads.begin(), {N, M}); + const int64_t N = X->Shape()[0]; + const int64_t M = W->Shape()[0]; + if (channels_last) { + y_dims.push_back(N); + } else { + y_dims.insert(y_dims.begin(), {N, M}); + } bool post_slicing_required = false; TensorShapeVector slice_starts; - slice_starts.reserve(rank); + slice_starts.reserve(kernel_rank); TensorShapeVector slice_ends; - slice_ends.reserve(rank); + slice_ends.reserve(kernel_rank); TensorShapeVector slice_axes; - slice_axes.reserve(rank); + slice_axes.reserve(kernel_rank); + + const size_t spatial_dim_start = channels_last ? 1 : 2; + const size_t spatial_dim_end = spatial_dim_start + kernel_rank; + TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(x_shape.Slice(2), kernel_shape, + TensorShapeVector y_dims_with_adjusted_pads(y_dims); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, post_slicing_required, slice_starts, slice_ends, slice_axes)); + if (channels_last) { + y_dims.push_back(M); + y_dims_with_adjusted_pads.push_back(M); + } + ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); s_.y_dims = gsl::make_span(y_dims); s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; @@ -190,7 +209,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; TensorShapeVector y_dims_cudnn = !post_slicing_required ? 
y_dims : y_dims_with_adjusted_pads; - if (rank < 2) { + if (kernel_rank < 2) { // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] // especially for EXHAUSTIVE algo search which may result in a better algo selection. // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to @@ -203,7 +222,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); w_dims.insert(w_dims.begin() + 2, 1); - pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.begin(), 0); kernel_shape.insert(kernel_shape.begin(), 1); strides.insert(strides.begin(), 1); @@ -212,7 +231,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const x_dims_cudnn.push_back(1); y_dims_cudnn.push_back(1); w_dims.push_back(1); - pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.end(), 0); kernel_shape.push_back(1); strides.push_back(1); @@ -220,16 +239,43 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const } } - if (w_dims_changed) - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + if (w_dims_changed) { + if (!channels_last) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(w_dims[0]), + static_cast(w_dims[3]), + static_cast(w_dims[1]), + static_cast(w_dims[2]))); + } + } // We must delay returning early until here so that the weight dims have been cached properly if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + if (channels_last) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(x_dims_cudnn[0]), + static_cast(x_dims_cudnn[3]), + static_cast(x_dims_cudnn[1]), + static_cast(x_dims_cudnn[2]))); + + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(y_dims_cudnn[0]), + static_cast(y_dims_cudnn[3]), + static_cast(y_dims_cudnn[1]), + static_cast(y_dims_cudnn[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + } + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); @@ -331,8 +377,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const return Status::OK(); } -template -Status Conv::ComputeInternal(OpKernelContext* context) const { +template +Status Conv::ComputeInternal(OpKernelContext* context) const { std::lock_guard lock(s_.mutex); ORT_RETURN_IF_ERROR(UpdateState(context)); if (s_.Y->Shape().Size() == 0) { @@ -367,7 +413,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { s_.slice_ends, s_.slice_axes, s_.element_size)); } return Status::OK(); -} // namespace cuda +} CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() : desc_(nullptr) { } @@ -424,5 +470,11 @@ Status CudnnConvolutionDescriptor::Set( return Status::OK(); } +#ifndef DISABLE_CONTRIB_OPS +// template instantiation for NhwcConv 
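+// (Channels-last specializations for float and MLFloat16, matching the NhwcConv registrations in
+// contrib_ops/cuda/diffusion/nhwc_conv.cc.)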
+template class Conv; +template class Conv; +#endif + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index ae179de0070b0..07825b93204ca 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -177,7 +177,9 @@ enum : size_t { AlgoSearchWorkspaceSize = 32 * 1024 * 1024, }; -template +// ONNX Conv operator uses NCHW format for input, weights and output. +// NhwcConv contrib ops uses NHWC format: last dimension of input, weights and output are channels. +template class Conv : public CudaKernel { public: using CudaT = typename ToCudaType::MappedType; diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ed94a01f562ef..689235b630d94 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -200,6 +200,8 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "PythonOp": self._infer_PythonOp, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + "GroupNorm": self._infer_GroupNorm, + "BiasSplitGelu": self._infer_BiasSplitGelu, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -434,6 +436,8 @@ def _onnx_infer_single_node(self, node): "SkipLayerNormalization", "PythonOp", "MultiHeadAttention", + "GroupNorm", + "BiasSplitGelu", ] if not skip_infer: @@ -1963,53 +1967,62 @@ def _infer_ZipMap(self, node): def _infer_Attention(self, node): shape = self._get_shape(node, 0) shape_bias = self._get_shape(node, 2) - assert len(shape) == 3 and len(shape_bias) == 1 - qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") - if qkv_hidden_sizes_attr is not None: - assert len(qkv_hidden_sizes_attr) == 3 - shape[2] = int(qkv_hidden_sizes_attr[2]) - else: - shape[2] = int(shape_bias[0] / 3) - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) - - if len(node.output) > 1: - # input shape: (batch_size, sequence_length, hidden_size) - # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) - # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) - # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length - input_shape = self._get_shape(node, 0) - past_shape = self._get_shape(node, 4) - mask_shape = self._get_shape(node, 3) - if len(past_shape) == 5: - if len(mask_shape) in [2, 3]: - past_shape[3] = mask_shape[-1] - elif isinstance(input_shape[1], int) and isinstance(past_shape[3], int): - past_shape[3] = input_shape[1] + past_shape[3] - else: - past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" - vi = self.known_vi_[node.output[1]] - vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + if shape and len(shape) == 3 and shape_bias and len(shape_bias) == 1: + qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[2] = int(qkv_hidden_sizes_attr[2]) + elif isinstance(shape_bias[0], int): + shape[2] = int(shape_bias[0] / 3) + output_dtype = 
self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + + if len(node.output) > 1: + # input shape: (batch_size, sequence_length, hidden_size) + # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) + # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) + # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length + input_shape = self._get_shape(node, 0) + past_shape = self._get_shape(node, 4) + mask_shape = self._get_shape(node, 3) + if past_shape and len(past_shape) == 5: + if mask_shape and len(mask_shape) in [2, 3]: + past_shape[3] = mask_shape[-1] + elif input_shape and len(input_shape) == 3: + if isinstance(input_shape[1], int) and isinstance(past_shape[3], int): + past_shape[3] = input_shape[1] + past_shape[3] + else: + past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) def _infer_BiasGelu(self, node): self._propagate_shape_and_type(node) def _infer_MultiHeadAttention(self, node): # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) - # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + # Without packed KV: + # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) + # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + # With packed KV: + # Input 1 (key) has shape (batch_size, kv_sequence_length, num_heads, 2, head_size) + # Input 2 (value) is nullptr # Output 0 has shape (batch_size, sequence_length, v_hidden_size) query_shape = self._get_shape(node, 0) - value_shape = self._get_shape(node, 2) + key_shape = self._get_shape(node, 1) + if query_shape is not None and len(query_shape) == 3: - assert len(query_shape) == 3 and len(value_shape) == 3 - output_shape = query_shape - output_shape[2] = value_shape[2] + # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. 
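+            # With packed KV, key is 5D and value is omitted, so the query shape already gives the output shape.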
+ output_shape = query_shape + if key_shape and len(key_shape) == 3: + value_shape = self._get_shape(node, 2) + if value_shape and len(value_shape) == 3: + output_shape[2] = value_shape[2] - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_FastGelu(self, node): self._propagate_shape_and_type(node) @@ -2056,6 +2069,19 @@ def _infer_SkipLayerNormalization(self, node): if len(node.output) > 3: self._propagate_shape_and_type(node, 0, 3) + def _infer_GroupNorm(self, node): + self._propagate_shape_and_type(node) + + def _infer_BiasSplitGelu(self, node): + input_shape = self._get_shape(node, 0) + bias_shape = self._get_shape(node, 1) + if input_shape and bias_shape and isinstance(bias_shape[0], int): + output_shape = input_shape + output_shape[2] = int(bias_shape[0] / 2) + vi = self.known_vi_[node.output[0]] + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) + def _infer_PythonOp(self, node): output_tensor_types = get_attribute(node, "output_tensor_types") assert output_tensor_types diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 2151e6a21c5e7..0441ce494d560 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -19,11 +19,14 @@ class FusionAttentionUnet(Fusion): Fuse Attention subgraph of UNet into one Attention node. """ - def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool): + def __init__( + self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool, enable_packed_kv: bool + ): super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"]) self.hidden_size = hidden_size self.num_heads = num_heads self.is_cross_attention = is_cross_attention + self.enable_packed_kv = enable_packed_kv # Flags to show warning only once self.num_heads_warning = True @@ -103,8 +106,22 @@ def create_attention_node( is_self_attention = not self.is_cross_attention if is_self_attention: - if q_matmul.input[0] != input or k_matmul.input[0] != input or q_matmul.input[0] != input: - logger.debug("q_matmul.input[0] != input or k_matmul.input[0] != input or q_matmul.input[0] != input") + if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input: + logger.debug( + "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s", + q_matmul.input[0], + k_matmul.input[0], + v_matmul.input[0], + ) + return None + else: + if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input): + logger.debug( + "For cross attention, input hidden state for q and k/v shall be different. 
Got %s, %s, %s", + q_matmul.input[0], + k_matmul.input[0], + v_matmul.input[0], + ) return None if hidden_size > 0 and (hidden_size % num_heads) != 0: @@ -136,7 +153,7 @@ def create_attention_node( kw_in_size = kw.shape[0] vw_in_size = vw.shape[0] - assert qw_in_size == kw_in_size == vw_in_size + assert qw_in_size == kw_in_size and kw_in_size == vw_in_size if hidden_size > 0 and hidden_size != qw_in_size: raise ValueError( @@ -162,8 +179,63 @@ def create_attention_node( ) self.model.add_initializer(weight, self.this_graph_name) - else: + else: # cross attention attention_node_name = self.model.create_node_name("MultiHeadAttention") + if self.enable_packed_kv: + if kw.shape != vw.shape: + return None + + kw_in_size = kw.shape[0] + vw_in_size = vw.shape[0] + assert kw_in_size == vw_in_size + + qw_out_size = qw.shape[1] + kw_out_size = kw.shape[1] + vw_out_size = vw.shape[1] + assert qw_out_size == vw_out_size and kw_out_size == vw_out_size + + c = kw_in_size + n = num_heads + h = kw_out_size // num_heads + + # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape + kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h) + + matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") + weight = helper.make_tensor( + name=matmul_node_name + "_weight", + data_type=TensorProto.FLOAT, + dims=[kv_weight.shape[0], kv_weight.shape[1]], + vals=kv_weight.flatten().tolist(), + ) + + self.model.add_initializer(weight, self.this_graph_name) + + matmul_node = helper.make_node( + "MatMul", + inputs=[k_matmul.input[0], matmul_node_name + "_weight"], + outputs=[matmul_node_name + "_out"], + name=matmul_node_name, + ) + self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name + + shape_tensor = helper.make_tensor( + name=matmul_node_name + "_reshape_shape", + data_type=TensorProto.INT64, + dims=[5], + vals=[0, 0, n, 2, h], + ) + self.model.add_initializer(shape_tensor, self.this_graph_name) + + reshape_node = helper.make_node( + "Reshape", + inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], + outputs=[k_matmul.output[0]], + name=matmul_node_name + "_reshape", + ) + self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name + self.nodes_to_add.extend([matmul_node, reshape_node]) + self.nodes_to_remove.extend([k_matmul, v_matmul]) # No bias, use zeros qkv_bias = np.zeros([3, hidden_size], dtype=np.float32) @@ -184,12 +256,18 @@ def create_attention_node( attention_node_name + "_qkv_bias", ] else: - attention_inputs = [ - q_matmul.output[0], - k_matmul.output[0], - v_matmul.output[0], - attention_node_name + "_qkv_bias", - ] + if not self.enable_packed_kv: + attention_inputs = [ + q_matmul.output[0], + k_matmul.output[0], + v_matmul.output[0], + attention_node_name + "_qkv_bias", + ] + else: + attention_inputs = [ + q_matmul.output[0], + k_matmul.output[0], + ] attention_node = helper.make_node( "Attention" if is_self_attention else "MultiHeadAttention", @@ -200,12 +278,23 @@ def create_attention_node( attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + counter_name = ( + "Attention (self attention)" + if is_self_attention + else "MultiHeadAttention ({})".format( + "cross attention with packed kv" if self.enable_packed_kv else "cross attention" + ) + ) + self.increase_counter(counter_name) return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): - 
node_before_layernorm = self.model.match_parent( - normalize_node, "Add" if self.is_cross_attention else "Reshape", 0 - ) + node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) + + # In SD 1.5, for self attention, LayerNorm has parent Reshape + if node_before_layernorm is None and not self.is_cross_attention: + node_before_layernorm = self.model.match_parent(normalize_node, "Reshape", 0) + if node_before_layernorm is None: return @@ -241,11 +330,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0]) if qk_nodes is not None: - (softmax_qk, mul_qk, matmul_qk) = qk_nodes + (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes else: qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) if qk_nodes is not None: - (softmax_qk, add_zero, mul_qk, matmul_qk) = qk_nodes + (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes else: logger.debug("fuse_attention: failed to match qk path") return diff --git a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py new file mode 100644 index 0000000000000..106d3de25d39d --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Dict + +from fusion_base import Fusion +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionBiasSplitGelu(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "BiasSplitGelu", "Gelu") + + def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + [root] --->Add --------------------> Slice ---------------> Mul --> + | ^ ^ + | | | + +----------------------------+---Slice --> Gelu---+ + | | ^ + | |-----| + | | | + | Mul Mul + | ^ ^ + v | | + Shape ---> Gather --> Add --> Div --+ + """ + if gelu_node.output[0] not in input_name_to_nodes: + return + children = input_name_to_nodes[gelu_node.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return + mul_after_gelu = children[0] + + slice_before_gelu = self.model.match_parent(gelu_node, "Slice", 0, output_name_to_node) + if slice_before_gelu is None: + return + + if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3: + return + + add_output = slice_before_gelu.input[0] + + start_index_nodes = self.model.match_parent_path( + slice_before_gelu, + ["Div", "Add", "Gather", "Shape", "Add"], + [1, 0, 0, 0, 0], + output_name_to_node, # Mul(1) is optional + ) + if start_index_nodes is None: + start_index_nodes = self.model.match_parent_path( + slice_before_gelu, + ["Mul", "Div", "Add", "Gather", "Shape", "Add"], + [1, 0, 0, 0, 0, 0], + output_name_to_node, + ) + + if start_index_nodes is None or start_index_nodes[-2].input[0] != add_output: + return + + end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node) + + if ( + end_index_nodes is None or end_index_nodes[1] not in start_index_nodes + ): # the Div is parent of both two Mul nodes + return + + slice_before_mul = self.model.match_parent(mul_after_gelu, "Slice", 
0, output_name_to_node) + if slice_before_mul is None: + return + + if ( + slice_before_mul.input[2] != slice_before_gelu.input[1] + ): # end index of slice_before_mul is start index of slice_before_gelu + return + + subgraph_nodes = start_index_nodes + [ + end_index_nodes[0], + mul_after_gelu, + gelu_node, + slice_before_mul, + slice_before_gelu, + ] + subgraph_output = mul_after_gelu.output[0] + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node + ): + logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.") + return + + add_node = start_index_nodes[-1] + bias_index, _value = self.model.get_constant_input(add_node) + if not isinstance(bias_index, int): + return + self.nodes_to_remove.extend(subgraph_nodes) + node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu") + fused_node = helper.make_node( + "BiasSplitGelu", + inputs=[add_node.input[1 - bias_index], add_node.input[bias_index]], + outputs=[subgraph_output], + name=node_name, + ) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + self.node_name_to_graph_name[node_name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py new file mode 100644 index 0000000000000..a0a4d7c16de0b --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -0,0 +1,198 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Dict + +import numpy as np +from fusion_base import Fusion +from fusion_utils import FusionUtils +from onnx import TensorProto, helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionGroupNorm(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "GroupNorm", "Add") + + def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + Fuse Group Normalization subgraph into one node GroupNorm. + The following is the pattern with swish activation: + +----------------Shape-------------------------------+ + | | + | (0, 32, -1) v (512x1x1) (512x1x1) (optional) + [Root] --> Reshape -------> InstanceNormalization --> Reshape ---> Mul --> Add --> Mul--> [output] + Bx512xHxW (scale=ones(32), B=zeros(32)) | ^ Bx512xHxW + | | + +--->Sigmoid (optional) + The Mul and Sigmoid before output is for Swish activation. They are optional. + """ + nodes = self.model.match_parent_path( + add_node, ["Mul", "Reshape", "InstanceNormalization", "Reshape"], [0, 0, 0, 0], output_name_to_node + ) + if nodes is None: + return + + weight_mul, reshape_4d, instance_norm, reshape_3d = nodes + root = reshape_3d.input[0] + + parents = self.model.match_parent_path(reshape_4d, ["Shape"], [1], output_name_to_node) + if parents is None: + return + if parents[0].input[0] != root: + return + shape_node = parents[0] + + # Check whether it has swish activation. 
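+ # Swish (SiLU) activation is y = x * Sigmoid(x). In the graph it shows up as an optional Sigmoid
+ # branch feeding a Mul right after the bias Add, so detecting that pair below decides whether the
+ # fused GroupNorm node gets activation=1 (Swish) or activation=0 (None).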
+ swish_mul = self.model.find_first_child_by_type(add_node, "Mul") + swish_sigmoid = None + if swish_mul is not None: + sigmoid_path = self.model.match_parent_path(swish_mul, ["Sigmoid"], [None], output_name_to_node) + if sigmoid_path is not None: + swish_sigmoid = sigmoid_path[0] + + weight_input = weight_mul.input[1 - self.model.input_index(reshape_4d.output[0], weight_mul)] + if not self.model.is_constant_with_specified_dimension(weight_input, 3, "group norm weight"): + return + + bias_input = add_node.input[1 - self.model.input_index(weight_mul.output[0], add_node)] + if not self.model.is_constant_with_specified_dimension(bias_input, 3, "layernorm bias"): + return + + weight = self.model.get_constant_value(weight_input) + if weight is None: + return + + if not (len(weight.shape) == 3 and weight.shape[1] == 1 and weight.shape[2] == 1): + return + + bias = self.model.get_constant_value(bias_input) + if bias is None: + return + if not (len(bias.shape) == 3 and bias.shape[1] == 1 and bias.shape[2] == 1): + return + + weight_elements = int(np.prod(weight.shape)) + bias_elements = int(np.prod(bias.shape)) + if weight_elements != bias_elements: + return + + instance_norm_scale = self.model.get_constant_value(instance_norm.input[1]) + if instance_norm_scale is None: + return + instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) + if instance_norm_bias is None: + return + + if not ( + len(instance_norm_scale.shape) == 1 + and len(instance_norm_bias.shape) == 1 + and instance_norm_scale.shape == instance_norm_bias.shape + and instance_norm_scale.shape[0] == 32 + ): + logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0]) + return + + if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): + return + if not np.allclose(np.zeros_like(instance_norm_bias), instance_norm_bias): + return + + group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") + + if weight_elements not in [320, 640, 960, 1280, 1920, 2560] + [128, 256, 512]: + logger.info("GroupNorm channels=%d", weight_elements) + + gamma = helper.make_tensor( + name=group_norm_name + "_gamma", + data_type=TensorProto.FLOAT, + dims=[weight_elements], + vals=weight.flatten().tolist(), + ) + self.model.add_initializer(gamma, self.this_graph_name) + + beta = helper.make_tensor( + name=group_norm_name + "_beta", + data_type=TensorProto.FLOAT, + dims=[bias_elements], + vals=bias.flatten().tolist(), + ) + self.model.add_initializer(beta, self.this_graph_name) + + last_node = add_node + subgraph_nodes = [add_node, weight_mul, reshape_4d, instance_norm, reshape_3d, shape_node] + has_swish_activation = swish_mul and swish_sigmoid + if swish_mul and swish_sigmoid: + subgraph_nodes.extend([swish_mul, swish_sigmoid]) + last_node = swish_mul + + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + last_node.output, + input_name_to_nodes, + output_name_to_node, + ): + self.nodes_to_remove.extend([last_node]) + else: + self.nodes_to_remove.extend(subgraph_nodes) + + # instance_norm_scale might from Constant node. Use prune graph to clear it. + self.prune_graph = True + + # Right now GroupNorm only support float16 input. Need add a Cast in fp32 model. 
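+ # Rough sketch of the fused pattern for a float32 graph:
+ #   input (NCHW, fp32) -> Cast(fp16) -> Transpose(perm=[0,2,3,1]) -> GroupNorm (NHWC, fp16)
+ #                      -> Transpose(perm=[0,3,1,2]) -> Cast(fp32) -> output (NCHW, fp32)
+ # For a float16 graph only the two Transpose nodes are inserted around GroupNorm.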
+ utils = FusionUtils(self.model) + + input = root + output = last_node.output[0] + if weight.dtype == np.float32: + # Add a Cast node to get float16 input for GroupNorm + cast_input, _cast_node = utils.cast_input(root, "float16") + input = cast_input + + # Add a Cast node to convert back to float32 after GroupNorm + output = group_norm_name + "_out" + cast_node = helper.make_node("Cast", inputs=[group_norm_name + "_out"], outputs=[last_node.output[0]]) + cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.FLOAT))]) + self.model.add_node(cast_node) + + # NCHW to NHWC + transpose_input = helper.make_node( + "Transpose", + [input], + [input + "_NHWC"], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"), + perm=[0, 2, 3, 1], + ) + + new_node = helper.make_node( + "GroupNorm", + inputs=[input + "_NHWC", group_norm_name + "_gamma", group_norm_name + "_beta"], + outputs=[output + "_NHWC"], + name=group_norm_name, + ) + + new_node.attribute.extend(instance_norm.attribute) + new_node.attribute.extend([helper.make_attribute("groups", 32)]) + new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)]) + new_node.domain = "com.microsoft" + + # NHWC to NCHW + transpose_output = helper.make_node( + "Transpose", + [output + "_NHWC"], + [output], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"), + perm=[0, 3, 1, 2], + ) + + self.nodes_to_add.append(new_node) + self.nodes_to_add.append(transpose_input) + self.nodes_to_add.append(transpose_output) + + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name + self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 9a5359b58caa6..cdfa2c626fc57 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -6,9 +6,16 @@ class AttentionMaskFormat: + # Build 1D mask indice (sequence length). It requires right side padding! Recommended for BERT model to get best performance. MaskIndexEnd = 0 + + # For experiment only. Do not use it in production. MaskIndexEndAndStart = 1 + + # Raw attention mask with 0 means padding (or no attention) and 1 otherwise. AttentionMask = 2 + + # No attention mask NoMask = 3 @@ -36,7 +43,17 @@ def __init__(self, model_type): self.enable_shape_inference = True self.enable_gemm_fast_gelu = False - self.attention_mask_format = AttentionMaskFormat.AttentionMask + + # Set default to sequence length for BERT model to use fused attention to speed up. + # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. 
+ self.attention_mask_format = ( + AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask + ) + + # options for stable diffusion + self.enable_group_norm = model_type == "unet" + self.enable_bias_splitgelu = model_type == "unet" + self.enable_packed_kv = model_type == "unet" def use_raw_attention_mask(self, use_raw_mask=True): if use_raw_mask: @@ -74,8 +91,14 @@ def parse(args): options.enable_gemm_fast_gelu = True if args.use_mask_index: options.use_raw_attention_mask(False) + if args.use_raw_attention_mask: + options.use_raw_attention_mask(True) if args.no_attention_mask: options.disable_attention_mask() + if args.disable_group_norm: + options.enable_group_norm = False + if args.disable_packed_kv: + options.enable_packed_kv = False return options @staticmethod @@ -164,10 +187,18 @@ def add_arguments(parser: ArgumentParser): "--use_mask_index", required=False, action="store_true", - help="use mask index instead of raw attention mask in attention operator", + help="use mask index to activate fused attention to speed up. It requires right-side padding!", ) parser.set_defaults(use_mask_index=False) + parser.add_argument( + "--use_raw_attention_mask", + required=False, + action="store_true", + help="use raw attention mask. Use this option if your input is not right-side padding. This might deactivate fused attention and get worse performance.", + ) + parser.set_defaults(use_raw_attention_mask=False) + parser.add_argument( "--no_attention_mask", required=False, @@ -185,3 +216,19 @@ def add_arguments(parser: ArgumentParser): "MultiHeadAttention has only CUDA implementation so the model can only run with cuda execution provider.", ) parser.set_defaults(use_multi_head_attention=False) + + parser.add_argument( + "--disable_group_norm", + required=False, + action="store_true", + help="not fuse GroupNorm. Only works for model_type=unet", + ) + parser.set_defaults(disable_group_norm=False) + + parser.add_argument( + "--disable_packed_kv", + required=False, + action="store_true", + help="not use packed kv in cross attention. Only works for model_type=unet", + ) + parser.set_defaults(disable_packed_kv=False) diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index 865c1542c1cc9..8363f2674cd40 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -28,8 +28,8 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]: logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}") return False, input_name - def cast_input_to_int32(self, input_name: str): - cast_output = input_name + "_int32" + def cast_input(self, input_name: str, target_type="int32"): + cast_output = input_name + "_" + target_type # Avoid consequent Cast nodes. 
inputs = [input_name] @@ -40,11 +40,24 @@ def cast_input_to_int32(self, input_name: str): inputs = [parent_node.input[0]] cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output]) - cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.INT32))]) + + if target_type == "int32": + to_type = int(TensorProto.INT32) + elif target_type == "float32": + to_type = int(TensorProto.FLOAT) + elif target_type == "float16": + to_type = int(TensorProto.FLOAT16) + else: + raise ValueError("Invalid target_type: {target_type}") + + cast_node.attribute.extend([helper.make_attribute("to", to_type)]) self.model.add_node(cast_node) return cast_output, cast_node + def cast_input_to_int32(self, input_name: str): + return self.cast_input(input_name, "int32") + def remove_cast_int32(self, input_name: str): input_name_to_nodes = self.model.input_name_to_nodes() nodes = input_name_to_nodes[input_name] diff --git a/onnxruntime/python/tools/transformers/models/diffusion/__init__.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/__init__.py similarity index 100% rename from onnxruntime/python/tools/transformers/models/diffusion/__init__.py rename to onnxruntime/python/tools/transformers/models/stable_diffusion/__init__.py diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py new file mode 100755 index 0000000000000..580c5ef4c3cca --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -0,0 +1,244 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import argparse +import os +import time + +SD_MODELS = { + "1.5": "runwayml/stable-diffusion-v1-5", + "2.0": "stabilityai/stable-diffusion-2", + "2.1": "stabilityai/stable-diffusion-2-1", +} + + +def get_test_settings(): + height = 512 + width = 512 + num_inference_steps = 50 + prompts = [ + "a photo of an astronaut riding a horse on mars", + "cute grey cat with blue eyes, wearing a bowtie, acrylic painting", + "a cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital painting", + "an illustration of a house with large barn with many cute flower pots and beautiful blue sky scenery", + "one apple sitting on a table, still life, reflective, full color photograph, centered, close-up product", + "background texture of stones, masterpiece, artistic, stunning photo, award winner photo", + "new international organic style house, tropical surroundings, architecture, 8k, hdr", + "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation", + "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic", + "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k", + ] + + return height, width, num_inference_steps, prompts + + +def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_safety_checker: bool): + from diffusers import OnnxStableDiffusionPipeline + + import onnxruntime + + if directory is not None: + assert os.path.exists(directory) + session_options = onnxruntime.SessionOptions() + pipe = OnnxStableDiffusionPipeline.from_pretrained( + directory, + provider=provider, + 
sess_options=session_options, + ) + else: + pipe = OnnxStableDiffusionPipeline.from_pretrained( + model_name, + revision="onnx", + provider=provider, + use_auth_token=True, + ) + + if disable_safety_checker: + pipe.safety_checker = None + pipe.feature_extractor = None + + return pipe + + +def get_torch_pipeline(model_name: str, disable_channels_last: bool, disable_safety_checker: bool): + from diffusers import StableDiffusionPipeline + from torch import channels_last, float16 + + pipe = StableDiffusionPipeline.from_pretrained( + model_name, torch_dtype=float16, revision="fp16", use_auth_token=True + ).to("cuda") + + if not disable_channels_last: + pipe.unet.to(memory_format=channels_last) # in-place operation + + if disable_safety_checker: + pipe.safety_checker = None + pipe.feature_extractor = None + + return pipe + + +def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, disable_safety_checker: bool): + short_model_name = model_name.split("/")[-1].replace("stable-diffusion-", "sd") + return f"{engine}_{short_model_name}_b{batch_size}" + ("" if disable_safety_checker else "_safe") + + +def run_ort_pipeline(pipe, batch_size: int, image_filename_prefix: str): + from diffusers import OnnxStableDiffusionPipeline + + assert isinstance(pipe, OnnxStableDiffusionPipeline) + + height, width, num_inference_steps, prompts = get_test_settings() + + pipe("warm up", height, width, num_inference_steps=2) + + latency_list = [] + for i, prompt in enumerate(prompts): + input_prompts = [prompt] * batch_size + inference_start = time.time() + image = pipe(input_prompts, height, width, num_inference_steps).images[0] + inference_end = time.time() + + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency} seconds") + image.save(f"{image_filename_prefix}_{i}.jpg") + print("Average latency in seconds:", sum(latency_list) / len(latency_list)) + + +def run_torch_pipeline(pipe, batch_size: int, image_filename_prefix: str): + import torch + + height, width, num_inference_steps, prompts = get_test_settings() + + pipe("warm up", height, width, num_inference_steps=2) + + torch.set_grad_enabled(False) + + latency_list = [] + for i, prompt in enumerate(prompts): + input_prompts = [prompt] * batch_size + torch.cuda.synchronize() + inference_start = time.time() + image = pipe(input_prompts, height, width, num_inference_steps).images[0] + torch.cuda.synchronize() + inference_end = time.time() + + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency} seconds") + image.save(f"{image_filename_prefix}_{i}.jpg") + + print("Average latency in seconds:", sum(latency_list) / len(latency_list)) + + +def run_ort(model_name: str, directory: str, provider: str, batch_size: int, disable_safety_checker: bool): + load_start = time.time() + pipe = get_ort_pipeline(model_name, directory, provider, disable_safety_checker) + load_end = time.time() + print(f"Model loading took {load_end - load_start} seconds") + + image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, disable_safety_checker) + run_ort_pipeline(pipe, batch_size, image_filename_prefix) + + +def run_torch(model_name: str, batch_size: int, disable_channels_last: bool, disable_safety_checker: bool): + import torch + + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # torch.backends.cuda.matmul.allow_tf32 = True + + torch.set_grad_enabled(False) + + load_start = time.time() + pipe = 
get_torch_pipeline(model_name, disable_channels_last, disable_safety_checker) + load_end = time.time() + print(f"Model loading took {load_end - load_start} seconds") + + image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) + ( + "" if disable_channels_last else "_channels_last" + ) + with torch.inference_mode(): + run_torch_pipeline(pipe, batch_size, image_filename_prefix) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-e", + "--engine", + required=False, + type=str, + default="onnxruntime", + choices=["onnxruntime", "torch"], + help="Engines to benchmark. Default is onnxruntime.", + ) + + parser.add_argument( + "-v", + "--version", + required=True, + type=str, + choices=list(SD_MODELS.keys()), + help="Stable diffusion version like 1.5, 2.0 or 2.1", + ) + + parser.add_argument( + "-p", + "--pipeline", + required=False, + type=str, + default=None, + help="Directory of saved onnx pipeline. It could be output directory of optimize_pipeline.py.", + ) + + parser.add_argument( + "-c", + "--disable_channels_last", + required=False, + action="store_true", + help="Disable channels last for torch. It will be ignored for onnxruntime engine", + ) + parser.set_defaults(disable_channels_last=False) + + parser.add_argument( + "--enable_safety_checker", + required=False, + action="store_true", + help="Enable safety checker", + ) + parser.set_defaults(enable_safety_checker=False) + + parser.add_argument("-b", "--batch_size", type=int, default=1) + + args = parser.parse_args() + return args + + +def main(): + args = parse_arguments() + print(args) + + sd_model = SD_MODELS[args.version] + if args.engine == "onnxruntime": + assert args.pipeline, "--pipeline should be specified for onnxruntime engine" + + if args.batch_size > 1: + # Need remove a line https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L307 + # in diffuers to run batch_size > 1. + assert ( + args.enable_safety_checker + ), "batch_size > 1 is not compatible with safety checker due to a bug in diffuers" + + provider = "CUDAExecutionProvider" # TODO: use ["CUDAExecutionProvider", "CPUExecutionProvider"] in diffuers + run_ort(sd_model, args.pipeline, provider, args.batch_size, not args.enable_safety_checker) + else: + run_torch(sd_model, args.batch_size, args.disable_channels_last, not args.enable_safety_checker) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py similarity index 52% rename from onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py rename to onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 8e20f58dd75d4..0979f0d2ddcb5 100644 --- a/onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -6,16 +6,23 @@ # This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference. # # Before running this script, you need convert checkpoint to float32 onnx models like the following -# git clone https://github.com/huggingface/diffusers -# cd diffusers -# pip install -e . 
+# export ONNX_ROOT=./sd_onnx +# pip install -r requirements.txt # huggingface-cli login -# python3 scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5 -# +# wget https://raw.githubusercontent.com/huggingface/diffusers/v0.12.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py +# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/stable-diffusion-v1-5-fp32 +# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path stabilityai/stable-diffusion-2-1 --output_path $ONNX_ROOT/stable-diffusion-v2-1-fp32 +# Note that this script might not be compatible with older or newer version of diffusers/transformers. It is because fusion script need change accordingly when onnx graph is changed. + # Then you can use this script to convert them to float16 like the following: -# pip3 install -U onnxruntime-gpu >= 1.14 -# python3 -m onnxruntime.transformers.models.diffusion.convert_to_fp16 -i ../stable-diffusion-v1-5 -o ../stable-diffusion-v1-5-fp16 -# Note that float16 model is intended for CUDA Execution Provider. It might not run in CPU Execution Provider. +# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 +# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16 +# Or +# pip install -U onnxruntime-gpu >= 1.14 +# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 +# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16 + +# Note that float16 model is for CUDA Execution Provider. It might not run in CPU Execution Provider. import argparse import logging @@ -27,51 +34,63 @@ import coloredlogs sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from fusion_options import FusionOptions from optimizer import optimize_model # noqa: E402 logger = logging.getLogger(__name__) -def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool): - """Convert a model to float16 +def optimize_stable_diffusion_onnx_pipeline( + source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool +): + """Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16. Args: - source_dir (Path): source directory - target_dir (Path): target directory - overwrite (bool): overwrite if exists - use_external_data_format (bool): save model to two files: one for onnx graph, another for weights + source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models. + target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models. + overwrite (bool): Overwrite files if exists. 
+ use_external_data_format (bool): save onnx model to two files: one for onnx graph, another for weights + float16 (bool): use half precision Raises: RuntimeError: input onnx model does not exist RuntimeError: output onnx model path existed """ - dirs_with_onnx = ["vae_encoder", "vae_decoder", "text_encoder", "safety_checker", "unet"] + dirs_with_onnx = ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"] for name in dirs_with_onnx: onnx_model_path = source_dir / name / "model.onnx" if not os.path.exists(onnx_model_path): - raise RuntimeError(f"input onnx model does not exist: {onnx_model_path}") + message = f"input onnx model does not exist: {onnx_model_path}." + if name not in ["safety_checker", "feature_extractor"]: + raise RuntimeError(message) + continue num_heads = 0 hidden_size = 0 # Graph fusion before fp16 conversion, otherwise they cannot be fused later. # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead. + logger.info(f"optimize {onnx_model_path}...") + + fusion_options = FusionOptions("unet") + # packed kv requires compute capacity >= 7.5 (like T4, A100, RTX 2060~4090. See https://developer.nvidia.com/cuda-gpus) + # Suggest to disable it if you are using older GPU like V100, RTX 1060/1070/1080, or using float32 model. + fusion_options.enable_packed_kv = float16 + m = optimize_model( str(onnx_model_path), model_type="unet", num_heads=num_heads, hidden_size=hidden_size, opt_level=0, - optimization_options=None, + optimization_options=fusion_options, use_gpu=False, ) - # VAE-decoder in fp16 reduced quality thus we exclude it here - if name != "vae_decoder": - m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize"]) - else: - print("skip convert vae_decoder to fp16.") + if float16: + logger.info("convert %s to float16 ...", name) + m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize", "GroupNorm"]) optimized_model_path = target_dir / name / "model.onnx" output_dir = optimized_model_path.parent @@ -84,11 +103,11 @@ def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_ext output_dir.mkdir(parents=True, exist_ok=True) m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format) - print(f"{onnx_model_path} => {optimized_model_path}") + logger.info("%s => %s", onnx_model_path, optimized_model_path) -def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool): - """Copy extra directory. +def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): + """Copy extra directory that does not have onnx model Args: source_dir (Path): source directory @@ -100,10 +119,15 @@ def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool): RuntimeError: output path exists but overwrite is false. 
""" extra_dirs = ["scheduler", "tokenizer", "feature_extractor"] + for name in extra_dirs: source_path = source_dir / name + if not os.path.exists(source_path): - raise RuntimeError(f"source path does not exist: {source_path}") + message = f"source path does not exist: {source_path}" + if name not in ["safety_checker", "feature_extractor"]: + raise RuntimeError(message) + continue target_path = target_dir / name if target_path.exists(): @@ -112,7 +136,7 @@ def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool): shutil.rmtree(target_path) shutil.copytree(source_path, target_path) - print(f"{source_path} => {target_path}") + logger.info("%s => %s", source_path, target_path) extra_files = ["model_index.json"] for name in extra_files: @@ -126,7 +150,7 @@ def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool): raise RuntimeError(f"output path existed: {target_path}") os.remove(target_path) shutil.copyfile(source_path, target_path) - print(f"{source_path} => {target_path}") + logger.info("%s => %s", source_path, target_path) def parse_arguments(): @@ -150,8 +174,16 @@ def parse_arguments(): "--output", required=True, type=str, - help="Root of output directory of stable diffusion onnx pipeline with float16 models.", + help="Root of output directory of stable diffusion onnx pipeline with optimized models.", + ) + + parser.add_argument( + "--float16", + required=False, + action="store_true", + help="Output models of half or mixed precision.", ) + parser.set_defaults(float16=False) parser.add_argument( "--overwrite", @@ -166,7 +198,8 @@ def parse_arguments(): "--use_external_data_format", required=False, action="store_true", - help="Onnx model larger than 2GB need to use external data format.", + help="Onnx model larger than 2GB need to use external data format. " + "Save onnx model to two files: one for onnx graph, another for large weights.", ) parser.set_defaults(use_external_data_format=False) @@ -177,8 +210,10 @@ def parse_arguments(): def main(): coloredlogs.install(fmt="%(funcName)20s: %(message)s") args = parse_arguments() - copy_extra(Path(args.input), Path(args.output), args.overwrite) - convert_to_fp16(Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format) + copy_extra_directory(Path(args.input), Path(args.output), args.overwrite) + optimize_stable_diffusion_onnx_pipeline( + Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format, args.float16 + ) main() diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 4827facd78100..96c22b5894c60 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -977,6 +977,10 @@ def save_model_to_file(self, output_path, use_external_data_format=False, all_te logger.info("Sort graphs in topological order") self.topological_sort() + # Note: After the model is saved to another directory with external data, + # You need reload the onnx model if you want to read tensor from self.model object. + # It is because the base directory is not updated for self.model object so attempt to read tensor data + # might encounter error since external data cannot be located. 
OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file) logger.info(f"Model saved to {output_path}") diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 7872cf68e7366..feba717bd8f6f 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -7,6 +7,8 @@ from typing import Optional from fusion_attention_unet import FusionAttentionUnet +from fusion_biassplitgelu import FusionBiasSplitGelu +from fusion_group_norm import FusionGroupNorm from fusion_options import FusionOptions from onnx import ModelProto from onnx_model_bert import BertOnnxModel @@ -52,11 +54,20 @@ def optimize(self, options: Optional[FusionOptions] = None): self.fuse_reshape() + if (options is None) or options.enable_group_norm: + group_norm_fusion = FusionGroupNorm(self) + group_norm_fusion.apply() + + if (options is None) or options.enable_bias_splitgelu: + bias_split_gelu_fusion = FusionBiasSplitGelu(self) + bias_split_gelu_fusion.apply() + if (options is None) or options.enable_attention: - self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False) + self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False) self_attention_fusion.apply() - cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True) + enable_packed_kv = (options is None) or options.enable_packed_kv + cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv) cross_attention_fusion.apply() if (options is None) or options.enable_skip_layer_norm: diff --git a/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc new file mode 100644 index 0000000000000..3fac765d898da --- /dev/null +++ b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
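+// Reference used by these tests: Y = left * Gelu(right), where left/right are the two equal halves
+// of (X + bias) split along the last dimension, matching the BiasSplitGelu contrib op definition.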
+ +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +using namespace onnxruntime::test; + +namespace onnxruntime { +namespace test { +namespace bias_split_gelu_test { +std::vector ComputeGelu(const std::vector& input_data) { + std::vector output; + output.reserve(input_data.size()); + + for (size_t i = 0; i < input_data.size(); i++) { + float x = input_data[i]; + float y = x * (0.5f * (1.0f + std::erff(x / 1.41421356237f))); + output.push_back(y); + } + return output; +} + +std::vector AddBias(const std::vector& input_data, const std::vector& bias_data) { + size_t bias_length = bias_data.size(); + + std::vector output; + output.reserve(input_data.size()); + + for (size_t i = 0; i < input_data.size(); i++) { + output.push_back(input_data[i] + bias_data[i % bias_length]); + } + return output; +} + +void Split(const std::vector& input_data, + const std::vector& input_dims, + std::vector& left_half_data, std::vector& right_half_data) { + std::size_t length = input_data.size(); + left_half_data.reserve(length / 2); + right_half_data.reserve(length / 2); + + int64_t index = 0; + for (int64_t i = 0; i < input_dims[0]; i++) { + for (int64_t j = 0; j < input_dims[1]; j++) { + for (int64_t k = 0; k < input_dims[2]; k++, index++) { + if (k < input_dims[2] / 2) { + left_half_data.push_back(input_data[index]); + } else { + right_half_data.push_back(input_data[index]); + } + } + } + } +} + +std::vector GetExpectedResult(const std::vector& input_data, + const std::vector& input_dims, + const std::vector& bias_data) { + std::vector add_bias_data = AddBias(input_data, bias_data); + std::vector left_half_data; + std::vector right_half_data; + Split(add_bias_data, input_dims, left_half_data, right_half_data); + std::vector right_gelu_data = ComputeGelu(right_half_data); + + std::vector output_data; + output_data.reserve(left_half_data.size()); + for (std::size_t i = 0; i < left_half_data.size(); i++) { + output_data.push_back(left_half_data[i] * right_gelu_data[i]); + } + return output_data; +} +} // namespace bias_split_gelu_test + +#if defined(USE_CUDA) // The operator has only CUDA implementation right now + +static void RunBiasSplitGeluGpuTest(const std::vector& input_data, + const std::vector& bias_data, + const std::vector& output_data, + const std::vector& input_dims, + const std::vector& bias_dims, + const std::vector& output_dims, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 
530 : 0; + if (!HasCudaEnvironment(min_cuda_architecture)) { + return; + } + + OpTester tester("BiasSplitGelu", 1, onnxruntime::kMSDomain); + + if (use_float16) { + tester.AddInput("X", input_dims, ToFloat16(input_data)); + tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); + tester.AddOutput("Y", output_dims, ToFloat16(output_data)); + } else { + tester.AddInput("X", input_dims, input_data); + tester.AddInput("bias", bias_dims, bias_data); + tester.AddOutput("Y", output_dims, output_data); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +static void RunBiasSplitGeluTest(int64_t batch_size, int64_t sequence_length, int64_t hidden_size) { + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector bias_dims = {hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size / 2}; + + RandomValueGenerator random{}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 0.3f); + std::vector bias_data = random.Gaussian(bias_dims, 0.0f, 0.3f); + std::vector output_data = bias_split_gelu_test::GetExpectedResult(input_data, input_dims, bias_data); + + RunBiasSplitGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims); +} + +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_2560) { + constexpr int64_t batch_size = 2; + constexpr int64_t sequence_length = 5; + constexpr int64_t hidden_size = 2560; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_5120) { + constexpr int64_t batch_size = 2; + constexpr int64_t sequence_length = 1; + constexpr int64_t hidden_size = 5120; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_10240) { + constexpr int64_t batch_size = 1; + constexpr int64_t sequence_length = 2; + constexpr int64_t hidden_size = 10240; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/group_norm_op_test.cc b/onnxruntime/test/contrib_ops/group_norm_op_test.cc new file mode 100644 index 0000000000000..4af51e24159ef --- /dev/null +++ b/onnxruntime/test/contrib_ops/group_norm_op_test.cc @@ -0,0 +1,436 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
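+// Tests for the com.microsoft GroupNorm contrib op. The input dims below are NHWC (B, H, W, C) and
+// the expected values (norm_data) were precomputed from the listed input, gamma and beta data.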
+ +#include +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/provider_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace std; + +namespace onnxruntime { +namespace test { + +TEST(GroupNormTest, GroupNorm_128) { + constexpr int64_t B = 2; + constexpr int64_t C = 128; + constexpr int64_t H = 2; + constexpr int64_t W = 2; + + std::vector dims{B, H, W, C}; + std::vector input_data = { + 0.696469f, 0.719469f, 0.480932f, 0.438572f, 0.182492f, 0.634401f, 0.722443f, 0.293714f, 0.430863f, 0.426351f, + 0.623953f, 0.866309f, 0.519485f, 0.603060f, 0.417022f, 0.669314f, 0.842342f, 0.194223f, 0.627249f, 0.556785f, + 0.318766f, 0.925132f, 0.304768f, 0.355915f, 0.151127f, 0.513128f, 0.321981f, 0.854452f, 0.171082f, 0.578551f, + 0.905342f, 0.901911f, 0.806969f, 0.600699f, 0.428347f, 0.093327f, 0.457412f, 0.708697f, 0.286537f, 0.664872f, + 0.438214f, 0.582671f, 0.750717f, 0.859389f, 0.081780f, 0.562218f, 0.467988f, 0.931932f, 0.378986f, 0.032198f, + 0.542636f, 0.769397f, 0.661168f, 0.425868f, 0.181629f, 0.417291f, 0.918397f, 0.313669f, 0.238250f, 0.301947f, + 0.005545f, 0.097038f, 0.798923f, 0.715601f, 0.650750f, 0.502071f, 0.374292f, 0.300610f, 0.005943f, 0.597433f, + 0.887593f, 0.045895f, 0.710162f, 0.355958f, 0.832716f, 0.016392f, 0.225498f, 0.568103f, 0.298245f, 0.587494f, + 0.003532f, 0.052091f, 0.026611f, 0.607529f, 0.389874f, 0.937668f, 0.273842f, 0.882276f, 0.422543f, 0.145264f, + 0.564570f, 0.278024f, 0.542989f, 0.232686f, 0.820574f, 0.332580f, 0.688374f, 0.675035f, 0.944366f, 0.210653f, + 0.456271f, 0.909715f, 0.444221f, 0.947120f, 0.206132f, 0.509402f, 0.322974f, 0.655923f, 0.263610f, 0.953818f, + 0.423518f, 0.020576f, 0.523488f, 0.410266f, 0.013160f, 0.467330f, 0.652154f, 0.165560f, 0.349519f, 0.645823f, + 0.089832f, 0.051901f, 0.810513f, 0.902557f, 0.882713f, 0.212831f, 0.893865f, 0.916849f, 0.286139f, 0.423106f, + 0.392118f, 0.059678f, 0.175452f, 0.849432f, 0.322959f, 0.630976f, 0.493685f, 0.893389f, 0.115618f, 0.250455f, + 0.612895f, 0.545068f, 0.681301f, 0.585937f, 0.083195f, 0.572457f, 0.723416f, 0.158960f, 0.691970f, 0.841670f, + 0.398186f, 0.762548f, 0.398876f, 0.666625f, 0.661564f, 0.384838f, 0.829113f, 0.521533f, 0.207636f, 0.983631f, + 0.394370f, 0.865864f, 0.204543f, 0.296861f, 0.753526f, 0.839243f, 0.306470f, 0.887857f, 0.765096f, 0.814844f, + 0.574064f, 0.821504f, 0.138416f, 0.122244f, 0.807938f, 0.582175f, 0.668384f, 0.744781f, 0.066774f, 0.573774f, + 0.049097f, 0.788187f, 0.321319f, 0.989035f, 0.091296f, 0.047340f, 0.807791f, 0.980582f, 0.484909f, 0.461909f, + 0.798846f, 0.410520f, 0.865460f, 0.067449f, 0.214012f, 0.634442f, 0.365719f, 0.293152f, 0.016119f, 0.710999f, + 0.958510f, 0.929764f, 0.846055f, 0.721184f, 0.875125f, 0.225463f, 0.418627f, 0.948252f, 0.480889f, 0.406779f, + 0.920149f, 0.811953f, 0.754797f, 0.908011f, 0.206115f, 0.822304f, 0.245033f, 0.092186f, 0.191336f, 0.741760f, + 0.693985f, 0.746698f, 0.464935f, 0.953697f, 0.204304f, 0.006028f, 0.491190f, 0.421200f, 0.279802f, 0.043418f, + 0.036323f, 0.617660f, 0.165066f, 0.296902f, 0.972098f, 0.405653f, 0.271480f, 0.102880f, 0.371992f, 0.918097f, + 0.109088f, 0.661717f, 0.024148f, 0.375109f, 0.232980f, 0.612682f, 0.277424f, 0.038700f, 0.648450f, 0.294307f, + 0.131115f, 0.222157f, 0.919472f, 0.392304f, 0.496508f, 0.517623f, 0.226851f, 0.980764f, 0.343178f, 0.398044f, + 0.531551f, 0.724455f, 0.361789f, 0.092105f, 0.425830f, 0.944160f, 0.317285f, 0.483034f, 0.120629f, 0.342764f, + 
0.875457f, 0.624904f, 0.763683f, 0.095713f, 0.016129f, 0.153071f, 0.554383f, 0.357398f, 0.704959f, 0.593177f, + 0.240856f, 0.105908f, 0.846506f, 0.316788f, 0.338671f, 0.002688f, 0.292489f, 0.257542f, 0.731073f, 0.983522f, + 0.450636f, 0.927584f, 0.741862f, 0.165938f, 0.665261f, 0.696311f, 0.565642f, 0.337066f, 0.751644f, 0.909872f, + 0.399379f, 0.201400f, 0.007426f, 0.206096f, 0.029320f, 0.472913f, 0.653365f, 0.102635f, 0.792299f, 0.411569f, + 0.845533f, 0.236600f, 0.463653f, 0.241686f, 0.894978f, 0.539505f, 0.988329f, 0.963004f, 0.208248f, 0.191007f, + 0.025242f, 0.993033f, 0.105446f, 0.281235f, 0.533886f, 0.632050f, 0.126958f, 0.971046f, 0.429813f, 0.148778f, + 0.123923f, 0.007738f, 0.363576f, 0.572147f, 0.453089f, 0.556035f, 0.927455f, 0.372396f, 0.680903f, 0.335544f, + 0.369291f, 0.348797f, 0.336340f, 0.709623f, 0.117398f, 0.602932f, 0.676906f, 0.559738f, 0.912132f, 0.777769f, + 0.779767f, 0.657815f, 0.470689f, 0.087408f, 0.270176f, 0.218035f, 0.932892f, 0.707115f, 0.040683f, 0.368875f, + 0.361817f, 0.950252f, 0.987351f, 0.257348f, 0.398639f, 0.625209f, 0.868315f, 0.864480f, 0.093427f, 0.943201f, + 0.709386f, 0.542860f, 0.774580f, 0.238783f, 0.998918f, 0.760210f, 0.732601f, 0.451088f, 0.612179f, 0.000082f, + 0.415504f, 0.851548f, 0.426096f, 0.804026f, 0.551315f, 0.684830f, 0.729050f, 0.737995f, 0.531828f, 0.611024f, + 0.228263f, 0.433701f, 0.312261f, 0.501837f, 0.414826f, 0.985560f, 0.826341f, 0.304121f, 0.510422f, 0.674689f, + 0.243666f, 0.885327f, 0.594432f, 0.695530f, 0.388951f, 0.043591f, 0.995358f, 0.691702f, 0.343456f, 0.130895f, + 0.553257f, 0.354265f, 0.552370f, 0.988345f, 0.520010f, 0.564359f, 0.161069f, 0.079366f, 0.547764f, 0.569004f, + 0.048579f, 0.780998f, 0.111392f, 0.440328f, 0.084904f, 0.927577f, 0.079149f, 0.128631f, 0.424307f, 0.811644f, + 0.551593f, 0.717758f, 0.635900f, 0.121754f, 0.996086f, 0.699834f, 0.518717f, 0.481026f, 0.186904f, 0.916832f, + 0.502216f, 0.095530f, 0.043223f, 0.626309f, 0.375186f, 0.341831f, 0.443368f, 0.967494f, 0.266906f, 0.236462f, + 0.232480f, 0.362277f, 0.162016f, 0.026197f, 0.777162f, 0.871683f, 0.872879f, 0.940029f, 0.596487f, 0.084822f, + 0.539960f, 0.660952f, 0.932351f, 0.500561f, 0.198366f, 0.857153f, 0.904226f, 0.349566f, 0.242220f, 0.634638f, + 0.327100f, 0.959345f, 0.301053f, 0.364187f, 0.215505f, 0.334836f, 0.580713f, 0.200401f, 0.237478f, 0.772878f, + 0.808964f, 0.346795f, 0.360424f, 0.845753f, 0.314351f, 0.483889f, 0.332754f, 0.611977f, 0.863353f, 0.815966f, + 0.408660f, 0.082653f, 0.184886f, 0.441697f, 0.280477f, 0.276902f, 0.837466f, 0.245131f, 0.924552f, 0.858917f, + 0.134613f, 0.704779f, 0.040616f, 0.230090f, 0.678095f, 0.287103f, 0.988215f, 0.980597f, 0.744615f, 0.127612f, + 0.305646f, 0.857652f, 0.922382f, 0.441324f, 0.617186f, 0.220603f, 0.779245f, 0.616006f, 0.135673f, 0.247513f, + 0.359867f, 0.546479f, 0.510376f, 0.167482f, 0.647433f, 0.875771f, 0.795605f, 0.284549f, 0.648163f, 0.697942f, + 0.717354f, 0.792651f, 0.402787f, 0.663393f, 0.701360f, 0.468060f, 0.376677f, 0.475468f, 0.298579f, 0.981118f, + 0.880607f, 0.326968f, 0.355065f, 0.336230f, 0.098184f, 0.016991f, 0.453990f, 0.115745f, 0.207050f, 0.163338f, + 0.587616f, 0.518773f, 0.952489f, 0.803843f, 0.844077f, 0.264328f, 0.097160f, 0.338377f, 0.995782f, 0.945237f, + 0.879142f, 0.501190f, 0.668073f, 0.043569f, 0.550953f, 0.663043f, 0.278687f, 0.244660f, 0.747326f, 0.768959f, + 0.756060f, 0.915355f, 0.280000f, 0.113509f, 0.430876f, 0.755474f, 0.205838f, 0.225924f, 0.265096f, 0.899392f, + 0.330582f, 0.158679f, 0.684826f, 0.544763f, 0.387195f, 0.921920f, 0.383194f, 
0.199158f, 0.220731f, 0.083348f, + 0.267982f, 0.416437f, 0.503247f, 0.229764f, 0.751615f, 0.979886f, 0.218682f, 0.785557f, 0.596404f, 0.673936f, + 0.045040f, 0.842387f, 0.478416f, 0.619851f, 0.758625f, 0.557799f, 0.428663f, 0.350140f, 0.081201f, 0.896426f, + 0.967501f, 0.668001f, 0.700871f, 0.894878f, 0.453728f, 0.157534f, 0.986580f, 0.426672f, 0.820969f, 0.739923f, + 0.352277f, 0.581123f, 0.889559f, 0.249954f, 0.012738f, 0.481115f, 0.453346f, 0.241041f, 0.520494f, 0.326236f, + 0.639721f, 0.954378f, 0.538658f, 0.690915f, 0.081894f, 0.499936f, 0.572204f, 0.182921f, 0.706266f, 0.645798f, + 0.303381f, 0.932843f, 0.404739f, 0.322655f, 0.522892f, 0.058939f, 0.563665f, 0.524866f, 0.797733f, 0.861912f, + 0.756946f, 0.534076f, 0.037392f, 0.520718f, 0.491976f, 0.965886f, 0.858428f, 0.805397f, 0.715750f, 0.242962f, + 0.121840f, 0.549413f, 0.707581f, 0.625907f, 0.103884f, 0.967437f, 0.941807f, 0.750748f, 0.391316f, 0.179390f, + 0.954144f, 0.995861f, 0.943181f, 0.225535f, 0.365521f, 0.952603f, 0.655552f, 0.984128f, 0.967362f, 0.764658f, + 0.498658f, 0.382370f, 0.076204f, 0.943615f, 0.206783f, 0.774136f, 0.219836f, 0.290086f, 0.063939f, 0.209334f, + 0.172612f, 0.684041f, 0.813314f, 0.710075f, 0.069982f, 0.338582f, 0.209592f, 0.618762f, 0.537080f, 0.754518f, + 0.657660f, 0.775365f, 0.624964f, 0.544813f, 0.650043f, 0.851819f, 0.127388f, 0.513679f, 0.920330f, 0.419923f, + 0.486112f, 0.347025f, 0.555860f, 0.550530f, 0.693655f, 0.966579f, 0.293974f, 0.196309f, 0.675409f, 0.918160f, + 0.348893f, 0.196346f, 0.473992f, 0.668433f, 0.455520f, 0.089096f, 0.405057f, 0.970099f, 0.672699f, 0.614172f, + 0.233294f, 0.329279f, 0.718766f, 0.744805f, 0.732767f, 0.195352f, 0.845798f, 0.223270f, 0.112540f, 0.858727f, + 0.458333f, 0.753204f, 0.021647f, 0.119070f, 0.378121f, 0.015745f, 0.458821f, 0.738294f, 0.802076f, 0.364342f, + 0.452341f, 0.350539f, 0.763269f, 0.449212f, 0.404651f, 0.508437f, 0.239293f, 0.483217f, 0.315162f, 0.086802f, + 0.146036f, 0.347146f, 0.495040f, 0.036045f, 0.104871f, 0.805327f, 0.475591f, 0.858913f, 0.339811f, 0.397564f, + 0.992478f, 0.147723f, 0.033954f, 0.661169f, 0.727080f, 0.537663f, 0.627922f, 0.567574f, 0.110105f, 0.385743f, + 0.760046f, 0.035033f, 0.441879f, 0.432969f, 0.852450f, 0.733128f, 0.040908f, 0.465148f, 0.525712f, 0.027543f, + 0.959939f, 0.457182f, 0.666527f, 0.031669f, 0.908842f, 0.539977f, 0.656343f, 0.466810f, 0.461138f, 0.658768f, + 0.944778f, 0.801277f, 0.274225f, 0.808626f, 0.764664f, 0.227802f, 0.657667f, 0.106055f, 0.328335f, 0.770169f, + 0.481128f, 0.905028f, 0.271492f, 0.476027f, 0.611671f, 0.727043f, 0.733395f, 0.594644f, 0.898713f, 0.196084f, + 0.859941f, 0.294517f, 0.519280f, 0.563628f, 0.251777f, 0.501324f, 0.897753f, 0.246321f, 0.324222f, 0.585902f, + 0.554412f, 0.174032f, 0.936472f, 0.827655f, 0.987936f, 0.114385f, 0.947582f, 0.246243f, 0.324910f, 0.391096f, + 0.014144f, 0.268021f, 0.689953f, 0.063691f, 0.828527f, 0.860373f, 0.081199f, 0.311536f, 0.647020f, 0.959900f, + 0.587540f, 0.239769f, 0.393420f, 0.952011f, 0.649501f, 0.701122f, 0.654753f, 0.098328f, 0.019756f, 0.307255f, + 0.101182f, 0.903178f, 0.662636f, 0.183807f, 0.383673f, 0.268124f, 0.722163f, 0.242447f, 0.870546f, 0.520290f, + 0.535141f, 0.449352f, 0.382109f, 0.030094f, 0.014841f, 0.754523f, 0.398138f, 0.080007f, 0.994005f, 0.343086f, + 0.416415f, 0.497471f, 0.518243f, 0.594622f, 0.404539f, 0.024741f, 0.205798f, 0.463358f, 0.634085f, 0.004168f, + 0.288890f, 0.318634f, 0.649971f, 0.068623f, 0.011161f, 0.617764f, 0.595074f, 0.477778f, 0.098851f, 0.284219f, + 0.982623f, 0.378369f, 0.671127f, 0.716803f, 
0.038332f, 0.175828f, 0.817099f, 0.248624f, 0.526941f, 0.143601f, + 0.318435f, 0.884003f, 0.956312f, 0.605227f, 0.516111f, 0.434986f, 0.446248f, 0.031918f, 0.876705f, 0.222946f, + 0.192030f, 0.151730f, 0.162001f, 0.931703f, 0.647385f, 0.263281f, 0.684891f, 0.196009f, 0.621328f, 0.875460f, + 0.116971f, 0.164779f, 0.810315f, 0.589415f, 0.584904f, 0.002092f, 0.368053f, 0.440462f, 0.466850f, 0.443596f, + 0.484220f, 0.870371f, 0.847502f, 0.015016f, 0.994610f, 0.624150f, 0.620991f, 0.027341f, 0.103521f, 0.971364f, + 0.694315f, 0.886678f, 0.523881f, 0.597125f, 0.947067f, 0.385271f, 0.754392f, 0.835389f, 0.975671f, 0.904114f, + 0.223580f, 0.351703f, 0.835343f, 0.052580f, 0.841164f, 0.205350f, 0.100214f, 0.310509f, 0.847647f, 0.990239f, + 0.434309f, 0.485149f, 0.367266f, 0.977029f, 0.723466f, 0.941467f, 0.249746f, 0.492914f, 0.584139f, 0.015198f, + 0.812326f, 0.527457f, 0.871326f, 0.821721f, 0.101746f, 0.594467f, 0.365567f, 0.751121f, 0.516166f, 0.369039f, + 0.557870f, 0.081583f, 0.060740f, 0.194498f, 0.932089f, 0.673928f, 0.694386f, 0.498688f, 0.422973f, 0.039913f, + 0.051126f, 0.339099f, 0.163220f, 0.351669f, 0.727191f, 0.116125f, 0.363897f, 0.637357f, 0.432239f, 0.345904f, + 0.623269f, 0.016948f, 0.826530f, 0.308751f, 0.290656f, 0.058387f, 0.264397f, 0.294895f, 0.639992f, 0.489059f, + 0.343698f, 0.929770f, 0.390125f, 0.397707f}; + + std::vector gamma_data = { + 0.447359f, 0.873295f, 0.351357f, 0.065158f, 0.442673f, 0.998459f, 0.379773f, 0.193055f, 0.045130f, 0.170969f, + 0.324064f, 0.574278f, 0.665588f, 0.042819f, 0.936180f, 0.235638f, 0.149062f, 0.530829f, 0.677586f, 0.307253f, + 0.669441f, 0.294294f, 0.902172f, 0.880695f, 0.071194f, 0.150403f, 0.698059f, 0.000120f, 0.821814f, 0.356240f, + 0.744620f, 0.044237f, 0.209264f, 0.070805f, 0.179824f, 0.384421f, 0.491552f, 0.916091f, 0.627174f, 0.706480f, + 0.082111f, 0.286787f, 0.991732f, 0.560422f, 0.787817f, 0.032482f, 0.084076f, 0.109233f, 0.015286f, 0.921979f, + 0.253635f, 0.996569f, 0.738130f, 0.250611f, 0.991805f, 0.868534f, 0.164998f, 0.185322f, 0.680186f, 0.078280f, + 0.584525f, 0.066603f, 0.221298f, 0.948440f, 0.498572f, 0.573713f, 0.269683f, 0.440062f, 0.133002f, 0.516616f, + 0.053956f, 0.048249f, 0.679648f, 0.054982f, 0.521284f, 0.266026f, 0.187694f, 0.573319f, 0.296463f, 0.456382f, + 0.138974f, 0.126486f, 0.106529f, 0.071560f, 0.553714f, 0.756005f, 0.792367f, 0.957845f, 0.168392f, 0.135619f, + 0.469955f, 0.861008f, 0.767069f, 0.558178f, 0.156783f, 0.391263f, 0.719346f, 0.373413f, 0.039119f, 0.583884f, + 0.720135f, 0.714771f, 0.164866f, 0.335992f, 0.409172f, 0.420481f, 0.114158f, 0.385532f, 0.506632f, 0.710561f, + 0.569448f, 0.404931f, 0.927597f, 0.598084f, 0.974791f, 0.867376f, 0.673626f, 0.899313f, 0.991240f, 0.220877f, + 0.691057f, 0.918779f, 0.017400f, 0.799489f, 0.089403f, 0.916554f, 0.612013f, 0.162069f}; + + std::vector beta_data = { + 0.039410f, 0.827821f, 0.139492f, 0.939541f, 0.090865f, 0.837978f, 0.423533f, 0.872735f, 0.768574f, 0.852882f, + 0.470242f, 0.713768f, 0.318668f, 0.047173f, 0.232400f, 0.001362f, 0.363028f, 0.493829f, 0.019407f, 0.007730f, + 0.686464f, 0.100436f, 0.073846f, 0.495598f, 0.718159f, 0.977165f, 0.295397f, 0.117518f, 0.068537f, 0.207511f, + 0.100055f, 0.003384f, 0.285074f, 0.164207f, 0.018250f, 0.354632f, 0.825916f, 0.303662f, 0.710100f, 0.728735f, + 0.025556f, 0.961785f, 0.139009f, 0.717465f, 0.379443f, 0.868223f, 0.994961f, 0.193323f, 0.819456f, 0.505503f, + 0.965431f, 0.658089f, 0.593238f, 0.229523f, 0.718700f, 0.288201f, 0.845759f, 0.977264f, 0.007793f, 0.954633f, + 0.358460f, 0.488316f, 0.924086f, 
0.775958f, 0.243222f, 0.096853f, 0.841226f, 0.747060f, 0.858339f, 0.384041f, + 0.492114f, 0.465019f, 0.314722f, 0.335672f, 0.718649f, 0.753071f, 0.863854f, 0.844902f, 0.753938f, 0.332778f, + 0.710046f, 0.972624f, 0.916240f, 0.971488f, 0.036208f, 0.611599f, 0.215343f, 0.246560f, 0.844061f, 0.750192f, + 0.328802f, 0.519915f, 0.188330f, 0.003827f, 0.899958f, 0.709642f, 0.528818f, 0.054099f, 0.420840f, 0.380042f, + 0.171547f, 0.156188f, 0.173178f, 0.596836f, 0.124704f, 0.238549f, 0.946272f, 0.219462f, 0.763857f, 0.598040f, + 0.413157f, 0.595286f, 0.133620f, 0.484188f, 0.972134f, 0.427721f, 0.242881f, 0.927507f, 0.610774f, 0.727857f, + 0.543405f, 0.011202f, 0.755700f, 0.978697f, 0.716188f, 0.808757f, 0.851587f, 0.999201f}; + + std::vector norm_data = { + 0.406306f, 1.632045f, 0.095849f, 0.919355f, -0.458834f, 1.632483f, 0.876482f, 0.729815f, 0.750835f, + 0.782631f, 0.590117f, 1.476163f, 0.183714f, 0.057787f, -0.474648f, 0.143954f, 0.561618f, 0.031635f, + 0.426744f, 0.118848f, 0.054676f, 0.526575f, -0.827396f, -0.206514f, 0.631899f, 1.033381f, -0.028056f, + 0.117742f, -0.928939f, 0.254703f, 1.002641f, 0.056505f, 0.502409f, 0.186869f, -0.032152f, -0.201724f, + 0.683548f, 0.900928f, 0.126877f, 1.073324f, -0.017409f, 0.957481f, 0.710492f, 1.254686f, -0.620889f, + 0.882544f, 1.003820f, 0.385277f, 0.814893f, -0.841305f, 1.028838f, 1.664626f, 0.982238f, 0.150513f, + -0.461095f, -0.012286f, 1.094831f, 0.900296f, -0.437987f, 0.919201f, -0.604762f, 0.398245f, 1.126501f, + 1.388226f, 0.740287f, 0.352386f, 0.833504f, 0.614170f, 0.687727f, 0.626510f, 0.563813f, 0.408836f, + 0.651389f, 0.307533f, 1.158524f, 0.360064f, 0.588918f, 0.904664f, 0.418446f, 0.420879f, 0.495571f, + 0.796672f, 0.759542f, 0.996513f, -0.328335f, 1.636925f, -0.644444f, 1.350502f, 0.891792f, 0.600690f, + 0.795602f, 0.142066f, -0.015730f, -0.867947f, 1.039989f, 0.261774f, 1.182381f, 0.375100f, 0.493101f, + -0.112225f, 0.136779f, 1.225890f, 0.158450f, 1.142486f, -0.296101f, 0.228868f, 0.873088f, 0.397857f, + 0.432766f, 1.815673f, 0.353312f, -0.006854f, 0.251850f, 0.343477f, -0.497336f, 0.382225f, 0.758787f, + 0.117172f, 0.342274f, 0.892228f, -0.293386f, -1.206122f, 0.772336f, 1.964310f, 0.807267f, -0.553660f, + 1.500599f, 1.184999f, -0.397960f, 0.498094f, -0.040874f, 0.811189f, -0.472885f, 2.600490f, 0.192458f, + 1.023374f, 0.762038f, 1.098150f, -0.060817f, 0.078648f, 0.518953f, 0.044398f, 0.859423f, 0.038016f, + 0.176986f, 0.714081f, 0.648229f, -0.296623f, 1.040141f, 0.429690f, -0.494966f, 1.206059f, 0.709146f, + 1.134488f, 1.010104f, 0.117495f, 0.857719f, 0.187595f, -0.713799f, 0.068448f, 0.201653f, 0.252268f, + -0.172338f, 0.070818f, 1.228964f, 1.349056f, 0.173722f, 1.663625f, 0.077027f, 1.191751f, 0.094092f, + 1.179984f, -0.462022f, 0.831658f, 1.105588f, 0.249245f, 0.829719f, 1.360636f, 0.624319f, 1.011229f, + -0.634975f, 0.475544f, 0.034839f, 1.765260f, 0.660444f, 0.743193f, 0.795097f, 1.088295f, 0.300263f, + 0.476736f, 1.126446f, 0.453643f, 1.137416f, -0.572652f, 0.673148f, 1.159168f, 0.829472f, 0.160861f, + 0.424527f, 0.503895f, 1.131327f, 0.397239f, 1.178295f, 0.893187f, 1.147333f, 0.005007f, 0.581892f, + 1.174909f, 0.703488f, 0.937277f, 1.057870f, 1.042361f, 0.414783f, 1.554468f, -0.841804f, 1.139242f, + 0.742398f, 0.564714f, -0.081046f, 2.137630f, 0.467942f, 0.330163f, 0.807147f, 1.276604f, -0.094400f, + -0.540889f, 0.428099f, 0.338536f, -0.296175f, -0.883684f, -0.070659f, 0.765356f, -0.351806f, -0.067355f, + 1.118756f, 0.077982f, 0.446440f, -0.258010f, 0.252682f, 1.239576f, -0.979634f, 0.825275f, -0.463020f, + 0.125961f, 
-0.208515f, 1.494655f, 0.097464f, 0.432844f, 0.867343f, -0.536458f, 0.736790f, 0.328702f, + 0.819557f, 0.061518f, 0.591131f, 0.943027f, -0.514167f, 2.631821f, -0.116213f, 0.907785f, 0.237840f, + 2.037882f, 0.258945f, 0.554331f, 0.749937f, 1.132450f, 0.197422f, 0.606424f, -1.247748f, -0.002311f, + 1.839518f, 0.087527f, 0.521764f, -0.146106f, -0.980738f, -0.302773f, 0.676835f, -0.132461f, 0.596699f, + 0.617694f, 0.659876f, 0.765150f, 1.575500f, 0.117460f, -0.473908f, -0.423069f, -0.505049f, -0.037672f, + 0.447086f, 0.281287f, -0.018190f, 0.915389f, 1.207481f, -0.962214f, 1.016921f, 1.156551f, 0.019404f, + 0.709657f, 0.713726f, 1.354227f, 0.270004f, 0.840813f, 0.865947f, 0.102975f, 0.796979f, 0.520542f, + 1.122967f, -0.562413f, 1.328713f, 0.137686f, 1.895931f, -0.574054f, 0.856002f, 0.857834f, 0.983861f, + 0.978393f, 1.250703f, 0.584533f, 0.704302f, -0.218810f, -0.416660f, 1.397336f, 0.564530f, 0.582539f, + 0.895726f, 0.679485f, 0.442242f, 0.541062f, 0.109609f, 0.275143f, 0.107930f, 0.353517f, 0.707609f, + 0.915281f, 0.628682f, 0.355126f, 0.897993f, 0.923647f, 0.977992f, 0.935513f, -0.370250f, -0.000332f, + -0.462323f, 0.742310f, 0.634979f, 0.910902f, 1.059454f, 1.354346f, 1.166715f, 0.402587f, 1.013271f, + 0.793169f, 0.608214f, -0.429466f, 0.396397f, -0.096419f, 1.306138f, 0.732527f, -0.068210f, 0.480573f, + -0.084915f, 0.843406f, 1.124528f, -0.111570f, 0.667385f, 1.014872f, 1.221989f, 1.165116f, -1.026174f, + 1.364619f, 1.676924f, 0.592108f, 1.041303f, 0.342757f, 2.547432f, 0.978781f, 1.042198f, -0.103338f, + 0.761959f, -0.205143f, 0.651057f, 1.635670f, 0.429972f, 1.116617f, 0.121796f, 1.499508f, 0.477808f, + 1.004834f, 0.238391f, 1.527244f, 0.030314f, 0.851662f, 0.729685f, 0.833627f, 0.322326f, 1.746771f, + 1.284994f, -0.011233f, -0.003167f, 0.150784f, 0.258291f, 1.278590f, 0.351162f, 0.263747f, 0.240001f, + -0.496733f, 1.630098f, 0.959952f, 0.691867f, 0.781609f, 0.678993f, 0.117479f, 0.106319f, 0.737017f, + 0.054679f, 0.007170f, 0.031594f, 0.058290f, 0.042648f, 0.435231f, -0.069486f, 1.149115f, -0.284731f, + 0.478892f, -0.119480f, 1.305503f, -1.632824f, -0.186219f, 0.339929f, 0.911391f, 1.028848f, 0.301977f, + 0.828055f, -0.564568f, 1.414307f, 1.432279f, 0.605853f, 0.199995f, -0.442368f, 1.540784f, 0.876256f, + 0.771619f, -0.860229f, 1.000022f, 0.093107f, 0.450905f, 0.872359f, 2.159872f, 0.030324f, -0.212928f, + 0.691624f, 0.714844f, 0.749217f, -0.247668f, 0.546163f, 0.526861f, 0.965844f, 0.398844f, 0.808378f, + 0.411826f, 0.859227f, 1.148453f, 1.279391f, 0.239180f, 0.580433f, 1.115814f, 1.052553f, 0.938658f, + -0.629015f, 0.794400f, -0.489248f, 1.621988f, 0.789545f, 0.749079f, -0.024277f, 0.386545f, 0.105109f, + -0.943201f, 0.658228f, 0.981167f, 1.500447f, -0.074319f, 0.409342f, 1.247461f, -0.211410f, 0.188935f, + 0.095841f, 0.758850f, 0.595416f, 0.656214f, 0.905517f, -0.334851f, 0.295979f, 0.567667f, 0.073956f, + 0.349117f, 1.184909f, 0.027066f, 2.348872f, 1.470366f, -0.435509f, 1.778383f, -0.706661f, 0.577661f, + 0.928943f, -0.556358f, 0.781633f, 2.151912f, 0.761094f, -0.845767f, 0.154289f, 1.149119f, 0.798352f, + 0.738550f, 0.334614f, 0.879107f, 0.544361f, 1.241048f, -0.116848f, 0.680143f, 0.749525f, 0.903967f, + 0.521876f, 0.044571f, 0.636013f, 0.101543f, 1.160492f, -0.183245f, 0.374511f, 0.647929f, 0.272462f, + 0.221596f, 0.477392f, 0.293685f, 0.793616f, 0.434287f, 0.685965f, 0.952972f, -0.181046f, 0.117678f, + 1.207595f, -0.074850f, -0.407813f, -0.030066f, 0.048566f, 0.067439f, 0.001379f, -0.060172f, 0.280818f, + -0.846823f, 0.816884f, 0.685541f, 0.148579f, 1.228496f, 1.213856f, 
0.082768f, -0.675884f, 0.850093f, + 1.127087f, 0.347097f, 0.837999f, 0.524987f, 1.104285f, -0.759169f, 0.635460f, 0.345593f, -0.202532f, + -0.625566f, 0.947332f, 1.108524f, 0.451968f, 1.059174f, -0.345810f, 0.363634f, 0.791247f, 1.440297f, + -0.205800f, -0.385072f, 0.646567f, 1.271695f, 0.794720f, -0.208087f, 0.540560f, 0.482087f, -0.006379f, + 0.408202f, 0.465379f, 0.459783f, 0.691498f, 0.068571f, 0.526139f, 0.197028f, 0.714096f, 0.845790f, + 1.019175f, 1.102269f, -0.858393f, 1.114611f, 0.139467f, 0.453660f, 0.601039f, 0.970108f, 0.433405f, + 1.179632f, 1.010146f, 0.152422f, 0.860524f, 0.488055f, -0.402708f, 0.493756f, 0.475332f, 0.663021f, + 0.781955f, 1.281084f, 0.160415f, 0.198334f, 0.888878f, 0.253592f, 1.097622f, 0.628367f, 0.510147f, + 0.872970f, 1.314789f, 0.231785f, -1.853665f, 0.667719f, 1.129714f, -0.395025f, 0.606279f, 0.720905f, + 1.613478f, 1.226488f, 1.121590f, 1.437285f, 0.732910f, 1.502128f, 0.744201f, -0.186752f, 1.338321f, + 1.092864f, -0.237060f, 2.343516f, 0.055512f, 0.903969f, 0.112821f, -0.874084f, 0.501252f, 0.883760f, + 0.825862f, 1.112298f, 0.830550f, 0.857847f, -0.774066f, 0.048746f, 0.173347f, 0.374310f, 0.508790f, + 0.892303f, 0.267807f, -0.501987f, -0.221902f, 0.168966f, 0.814484f, 0.951160f, 0.628689f, 1.171188f, + 1.143145f, 0.117596f, -0.374746f, -0.281733f, 1.347597f, 0.084762f, 0.563618f, 0.110448f, -0.044959f, + 0.876871f, 1.021856f, 1.680601f, 1.617434f, 1.269441f, 0.006120f, 0.766407f, -1.697397f, 1.538109f, + -0.396112f, 0.895946f, 0.915749f, 0.115267f, 0.798699f, -0.323668f, 0.707952f, 1.253857f, 1.336810f, + 0.388281f, -0.952190f, -0.330132f, 0.567214f, 0.989273f, -0.186156f, 1.005235f, 0.538783f, 0.540277f, + 0.963361f, 0.639664f, 0.462482f, 0.698468f, 0.534505f, 0.759512f, 1.099447f, 0.316237f, 0.498906f, + 0.445022f, 0.377702f, 0.339785f, 1.007626f, 1.143811f, 0.735108f, 0.274017f, 0.909311f, 0.923411f, + 0.633157f, 0.829842f, 0.907585f, 1.018174f, -0.330711f, -1.004036f, -0.470611f, 1.595124f, 1.007291f, + 0.851145f, -0.009376f, 0.217997f, 0.887566f, 0.570916f, 1.051678f, 0.245246f, 1.265329f, -0.268568f, + 0.373900f, 1.000752f, 0.128213f, 0.902059f, -0.106214f, 0.149962f, 0.074347f, -0.311712f, 0.962363f, + 0.626313f, 1.394106f, 0.275462f, 0.349021f, 0.389777f, 1.786896f, 0.567943f, 0.881495f, 0.817815f, + -0.143777f, 1.279913f, 0.339589f, 0.467706f, -0.153407f, -0.046937f, 0.766692f, -0.240678f, 0.593997f, + 1.864102f, 0.830787f, 1.217034f, -0.176123f, 0.595660f, 0.827656f, 0.861351f, -0.710248f, 1.412525f, + 0.737254f, 0.893155f, 0.796258f, 0.917900f, 0.020787f, 0.528776f, 0.896313f, -0.023476f, 0.010474f, + -0.061789f, 0.504972f, 0.727948f, -1.691226f, -0.209513f, 0.783358f, -0.402073f, 1.660988f, 0.398667f, + 0.746822f, 0.756122f, 1.075280f, 0.117522f, 0.482337f, 0.121187f, -0.097000f, 0.026081f, 0.564591f, + 0.229187f, -0.092778f, 0.715658f, 1.202137f, -0.648320f, 0.964561f, -0.294534f, -0.047344f, 1.191577f, + -0.162200f, 1.455440f, -0.230969f, 0.864577f, 1.022470f, 0.269888f, 0.830973f, 0.796731f, 1.288781f, + -0.279808f, 1.461457f, 0.011112f, 0.661665f, 0.377751f, 0.597034f, 0.896032f, 0.864871f, 0.834800f, + -0.242229f, 0.489711f, 0.900796f, -0.769517f, 0.893398f, 0.656636f, 1.234794f, 0.229293f, 1.113528f, + -0.032344f, 0.465116f, 0.453282f, -0.855888f, 0.287742f, 1.001159f, 0.339036f, 1.053392f, 1.481772f, + 0.350476f, 0.045156f, 0.789485f, 1.194247f, 0.953225f, 0.902432f, -0.469070f, 1.620967f, 0.308757f, + 0.558440f, 0.995676f, 0.582246f, -0.395102f, 0.145108f, -1.011727f, 0.925334f, 1.007595f, 0.227135f, + 0.257161f, -0.217773f, 
0.446225f, -0.090537f, 1.239298f, 0.278935f, 0.210654f, 0.565323f, 0.079686f, + -0.291973f, 0.796541f, 0.646783f, 0.600274f, -0.508244f, 1.545499f, 0.378070f, 0.104429f, 0.718873f, + 1.460520f, 1.208726f, 0.296987f, -0.352711f, -0.089663f, 0.797042f, 1.431477f, -1.527740f, 0.749836f, + 0.820989f, 0.769196f, -0.563369f, -0.191057f, 1.076530f, 0.250859f, 0.857584f, -0.346350f, 0.894605f, + 0.886723f, 0.338762f, 0.656447f, 1.024669f, 0.693469f, 0.659168f, 0.905854f, 0.224581f, 0.357502f, + -0.007332f, -0.390864f, 0.307303f, 0.571301f, 0.437075f, -0.311736f, -0.249217f, 0.585570f, -0.397286f, + 1.381788f, -0.368342f, 0.647196f, 0.809376f, -0.462215f, 0.117660f, 0.453367f, -0.164130f, 0.558638f, + -0.054476f, 0.367439f, 0.244487f, -0.175144f, -0.005267f, 1.277564f, 0.465179f, 0.811168f, -0.541415f, + -0.034877f, 0.830097f, -0.216335f, 0.466843f, 0.311936f, 0.906072f, 1.086701f, 0.017932f, 0.843566f, + 0.882529f, 1.066624f, -0.810174f, -0.560630f, 0.625432f, 1.290380f, 1.393908f, 0.789381f, 0.972095f, + 1.008577f, 0.881400f, 0.765357f, 0.556296f, 1.274361f, 2.005213f, -0.179109f, -0.167324f, 1.110618f, + 0.147224f, 1.058541f, -0.114418f, 0.418016f, 0.438177f, 1.042157f, 0.420788f, 0.554656f, 0.714696f, + 0.778748f, 1.693937f, 0.954506f, 0.957155f, 0.581167f, 0.971378f, 0.951858f, 0.841796f, 0.464267f, + 0.329466f, 1.016010f, 1.023249f, 0.637742f, 0.840873f, 0.229559f, 1.614064f, 0.264498f, -0.269998f, + 0.941741f, 0.066780f, -0.447346f, -0.301152f, 0.471130f, 0.673516f, 0.764475f, 0.221142f, 0.141437f, + 0.050416f, -0.363394f, 0.133119f, 0.851959f, 0.138650f, 1.246940f, -0.408690f, 0.153658f, 0.840290f, + 0.181189f, 0.244843f, 1.995885f, -1.411448f, 1.422581f, 0.658642f, 0.243404f, 0.442854f, 0.230959f, + -0.272532f, 0.778544f, 1.461264f, 0.670758f, 2.274148f, 0.642745f, 0.948315f}; + + std::vector swish_data = { + 0.243866f, 1.365124f, 0.050220f, 0.657257f, -0.177689f, 1.365588f, 0.618877f, 0.492453f, 0.510088f, + 0.537078f, 0.379677f, 1.201586f, 0.100271f, 0.029728f, -0.182035f, 0.077149f, 0.357653f, 0.016068f, + 0.258221f, 0.062951f, 0.028085f, 0.331049f, -0.251691f, -0.092633f, 0.412580f, 0.762192f, -0.013831f, + 0.062333f, -0.263020f, 0.143483f, 0.733510f, 0.029050f, 0.313013f, 0.102139f, -0.015817f, -0.090723f, + 0.454239f, 0.640686f, 0.067457f, 0.799871f, -0.008629f, 0.691892f, 0.476392f, 0.976283f, -0.217050f, + 0.624266f, 0.734605f, 0.229296f, 0.564844f, -0.253452f, 0.757935f, 1.399714f, 0.714629f, 0.080910f, + -0.178318f, -0.006105f, 0.820346f, 0.640120f, -0.171787f, 0.657118f, -0.213635f, 0.238256f, 0.850725f, + 1.111010f, 0.501217f, 0.206920f, 0.581032f, 0.398530f, 0.457656f, 0.408295f, 0.359337f, 0.245632f, + 0.428173f, 0.177226f, 0.881711f, 0.212098f, 0.378743f, 0.644037f, 0.252370f, 0.254082f, 0.307957f, + 0.549116f, 0.517441f, 0.727826f, -0.137456f, 1.370297f, -0.221845f, 1.072585f, 0.632512f, 0.387934f, + 0.548196f, 0.076070f, -0.007803f, -0.256636f, 0.768393f, 0.147922f, 0.904965f, 0.222318f, 0.306135f, + -0.052967f, 0.073060f, 0.947734f, 0.085489f, 0.866159f, -0.126290f, 0.127472f, 0.615866f, 0.237987f, + 0.262488f, 1.561563f, 0.207543f, -0.003415f, 0.141699f, 0.200946f, -0.188076f, 0.227198f, 0.516802f, + 0.062015f, 0.200143f, 0.632902f, -0.125327f, -0.277876f, 0.528298f, 1.722697f, 0.558246f, -0.202095f, + 1.226985f, 0.907526f, -0.159901f, 0.309820f, -0.020019f, 0.561637f, -0.181556f, 2.420778f, 0.105460f, + 0.752824f, 0.519554f, 0.823517f, -0.029484f, 0.040870f, 0.325333f, 0.022692f, 0.603779f, 0.019369f, + 0.096304f, 0.479364f, 0.425634f, -0.126475f, 0.768537f, 0.260306f, 
-0.187456f, 0.928184f, 0.475279f, + 0.858428f, 0.740447f, 0.062195f, 0.602277f, 0.102570f, -0.234669f, 0.035395f, 0.110958f, 0.141960f, + -0.078762f, 0.036662f, 0.950773f, 1.071117f, 0.094387f, 1.398649f, 0.039996f, 0.914138f, 0.049258f, + 0.902623f, -0.178574f, 0.579420f, 0.830635f, 0.140073f, 0.577730f, 1.082880f, 0.406556f, 0.741494f, + -0.219945f, 0.293266f, 0.017723f, 1.507298f, 0.435470f, 0.503657f, 0.547762f, 0.814111f, 0.172503f, + 0.294135f, 0.850672f, 0.277405f, 0.861257f, -0.206513f, 0.445763f, 0.882337f, 0.577514f, 0.086886f, + 0.256654f, 0.314115f, 0.855378f, 0.237559f, 0.900973f, 0.633758f, 0.870852f, 0.002510f, 0.373285f, + 0.897667f, 0.470606f, 0.673480f, 0.785239f, 0.770623f, 0.249797f, 1.283304f, -0.253514f, 0.863022f, + 0.502989f, 0.360029f, -0.038882f, 1.912126f, 0.287736f, 0.192089f, 0.558143f, 0.998140f, -0.044974f, + -0.199037f, 0.259179f, 0.197649f, -0.126316f, -0.258402f, -0.034082f, 0.522367f, -0.145276f, -0.032544f, + 0.843271f, 0.040511f, 0.272236f, -0.112454f, 0.142219f, 0.961279f, -0.267405f, 0.573859f, -0.178851f, + 0.066942f, -0.093427f, 1.220799f, 0.051105f, 0.262543f, 0.610777f, -0.197959f, 0.498287f, 0.191122f, + 0.568889f, 0.031705f, 0.380467f, 0.678707f, -0.192410f, 2.455177f, -0.054734f, 0.646839f, 0.132996f, + 1.802949f, 0.146142f, 0.352078f, 0.509331f, 0.856461f, 0.108424f, 0.392432f, -0.278360f, -0.001154f, + 1.587304f, 0.045678f, 0.327438f, -0.067726f, -0.267492f, -0.128642f, 0.448763f, -0.061851f, 0.384811f, + 0.401312f, 0.435012f, 0.522193f, 1.305406f, 0.062175f, -0.181835f, -0.167443f, -0.190078f, -0.018481f, + 0.272698f, 0.160295f, -0.009012f, 0.653681f, 0.929582f, -0.265990f, 0.746799f, 0.879794f, 0.009796f, + 0.475701f, 0.479070f, 1.076367f, 0.153117f, 0.587422f, 0.609541f, 0.054136f, 0.549380f, 0.326523f, + 0.847322f, -0.204150f, 1.050518f, 0.073575f, 1.648380f, -0.206833f, 0.600764f, 0.602378f, 0.716126f, + 0.711085f, 0.972323f, 0.375335f, 0.471277f, -0.097483f, -0.165546f, 1.120330f, 0.359888f, 0.373787f, + 0.636029f, 0.450923f, 0.269234f, 0.341984f, 0.057805f, 0.156379f, 0.056874f, 0.207681f, 0.474008f, + 0.653584f, 0.410021f, 0.208764f, 0.638058f, 0.661132f, 0.710716f, 0.671878f, -0.151240f, -0.000166f, + -0.178658f, 0.502916f, 0.415034f, 0.649641f, 0.786735f, 1.076488f, 0.889679f, 0.241274f, 0.743397f, + 0.546106f, 0.393839f, -0.169319f, 0.236975f, -0.045887f, 1.027756f, 0.494719f, -0.032942f, 0.296938f, + -0.040656f, 0.589695f, 0.848825f, -0.052676f, 0.441086f, 0.744888f, 0.943881f, 0.888123f, -0.270732f, + 1.086931f, 1.412803f, 0.381228f, 0.769628f, 0.200465f, 2.362491f, 0.711443f, 0.770470f, -0.049002f, + 0.519488f, -0.092087f, 0.427906f, 1.368965f, 0.260506f, 0.841215f, 0.064602f, 1.225849f, 0.294918f, + 0.735547f, 0.133336f, 1.254788f, 0.015386f, 0.596943f, 0.492345f, 0.581139f, 0.186914f, 1.487454f, + 1.006534f, -0.005585f, -0.001581f, 0.081065f, 0.145732f, 1.000125f, 0.206097f, 0.149164f, 0.134332f, + -0.187918f, 1.363060f, 0.694153f, 0.461047f, 0.536204f, 0.450521f, 0.062186f, 0.055983f, 0.498477f, + 0.028087f, 0.003598f, 0.016046f, 0.029994f, 0.021779f, 0.264239f, -0.033536f, 0.872580f, -0.122234f, + 0.295709f, -0.056175f, 1.027117f, -0.266875f, -0.084465f, 0.198578f, 0.650082f, 0.757945f, 0.173615f, + 0.576280f, -0.204651f, 1.137731f, 1.156216f, 0.391983f, 0.109964f, -0.173044f, 1.268957f, 0.618676f, + 0.527688f, -0.255739f, 0.731079f, 0.048719f, 0.275437f, 0.615220f, 1.936515f, 0.015392f, -0.095172f, + 0.460849f, 0.479997f, 0.508724f, -0.108577f, 0.345855f, 0.331264f, 0.699551f, 0.238672f, 0.559207f, + 0.247724f, 0.603606f, 
0.871938f, 1.000926f, 0.133824f, 0.372154f, 0.840444f, 0.780221f, 0.674734f, + -0.218730f, 0.547163f, -0.185949f, 1.354473f, 0.542996f, 0.508608f, -0.011991f, 0.230168f, 0.055314f, + -0.264336f, 0.433681f, 0.713642f, 1.226827f, -0.035779f, 0.245986f, 0.969103f, -0.094573f, 0.103365f, + 0.050215f, 0.516856f, 0.383809f, 0.432058f, 0.644802f, -0.139653f, 0.169732f, 0.362299f, 0.038345f, + 0.204724f, 0.907438f, 0.013716f, 2.144154f, 1.195574f, -0.171073f, 1.521402f, -0.233436f, 0.370009f, + 0.665922f, -0.202732f, 0.536225f, 1.927784f, 0.518755f, -0.254002f, 0.083084f, 0.872584f, 0.550561f, + 0.499761f, 0.195041f, 0.621209f, 0.344486f, 0.962738f, -0.055014f, 0.451459f, 0.508984f, 0.643412f, + 0.327522f, 0.022782f, 0.415858f, 0.053347f, 0.883625f, -0.083251f, 0.221916f, 0.425394f, 0.154676f, + 0.123024f, 0.294614f, 0.168252f, 0.546489f, 0.263568f, 0.456214f, 0.687772f, -0.082351f, 0.062297f, + 0.929695f, -0.036025f, -0.162895f, -0.014807f, 0.024873f, 0.034856f, 0.000690f, -0.029181f, 0.159995f, + -0.254131f, 0.566570f, 0.455867f, 0.079798f, 0.950309f, 0.935859f, 0.043096f, -0.227895f, 0.595564f, + 0.851290f, 0.203369f, 0.584960f, 0.329856f, 0.829387f, -0.242043f, 0.415417f, 0.202362f, -0.091046f, + -0.218020f, 0.682626f, 0.833448f, 0.276201f, 0.786472f, -0.143303f, 0.214515f, 0.544456f, 1.164481f, + -0.092349f, -0.155917f, 0.424301f, 0.993236f, 0.547438f, -0.093257f, 0.341603f, 0.298046f, -0.003180f, + 0.245189f, 0.285878f, 0.281830f, 0.460745f, 0.035461f, 0.330722f, 0.108188f, 0.479376f, 0.591785f, + 0.748902f, 0.827456f, -0.255522f, 0.839288f, 0.074588f, 0.277417f, 0.388207f, 0.703465f, 0.262941f, + 0.902279f, 0.740486f, 0.082008f, 0.604751f, 0.302422f, -0.161350f, 0.306618f, 0.293111f, 0.437553f, + 0.536501f, 1.002620f, 0.086627f, 0.108969f, 0.629911f, 0.142787f, 0.823012f, 0.409770f, 0.318761f, + 0.615761f, 1.036466f, 0.129264f, -0.251066f, 0.441357f, 0.853822f, -0.159001f, 0.392318f, 0.485029f, + 1.345469f, 0.948325f, 0.845997f, 1.161375f, 0.495040f, 1.228578f, 0.504504f, -0.084682f, 1.060236f, + 0.818468f, -0.104546f, 2.138265f, 0.028526f, 0.643413f, 0.059589f, -0.257335f, 0.312156f, 0.625349f, + 0.574370f, 0.837068f, 0.578454f, 0.602389f, -0.244295f, 0.024967f, 0.094167f, 0.221779f, 0.317751f, + 0.632969f, 0.151727f, -0.189286f, -0.098691f, 0.091604f, 0.564490f, 0.686118f, 0.410026f, 0.894037f, + 0.866797f, 0.062251f, -0.152669f, -0.121153f, 1.069637f, 0.044176f, 0.359187f, 0.058271f, -0.021974f, + 0.619223f, 0.751405f, 1.416720f, 1.349653f, 0.990985f, 0.003070f, 0.523259f, -0.262766f, 1.266156f, + -0.159335f, 0.636225f, 0.654005f, 0.060952f, 0.550860f, -0.135870f, 0.474291f, 0.975459f, 1.058707f, + 0.231365f, -0.265132f, -0.138064f, 0.361950f, 0.721125f, -0.084439f, 0.735919f, 0.340258f, 0.341388f, + 0.697275f, 0.418773f, 0.283780f, 0.466470f, 0.337023f, 0.517416f, 0.824758f, 0.182914f, 0.310421f, + 0.271221f, 0.224097f, 0.198482f, 0.738142f, 0.867442f, 0.496878f, 0.155663f, 0.648210f, 0.660919f, + 0.413581f, 0.577837f, 0.646659f, 0.747968f, -0.138260f, -0.269231f, -0.180937f, 1.326083f, 0.737830f, + 0.596488f, -0.004666f, 0.120832f, 0.628741f, 0.364801f, 0.779395f, 0.137584f, 0.986884f, -0.116360f, + 0.221499f, 0.731756f, 0.068210f, 0.641700f, -0.050289f, 0.080593f, 0.038555f, -0.131760f, 0.696361f, + 0.408138f, 1.117023f, 0.156582f, 0.204659f, 0.232396f, 1.530559f, 0.362511f, 0.623333f, 0.567378f, + -0.066730f, 1.001448f, 0.198351f, 0.287564f, -0.070832f, -0.022918f, 0.523501f, -0.105927f, 0.382702f, + 1.613891f, 0.578661f, 0.938992f, -0.080327f, 0.384000f, 0.575932f, 0.605480f, 
-0.234058f, 1.135901f, + 0.498675f, 0.633730f, 0.548760f, 0.655944f, 0.010501f, 0.332705f, 0.636554f, -0.011600f, 0.005264f, + -0.029940f, 0.314914f, 0.490896f, -0.263180f, -0.093822f, 0.537700f, -0.161157f, 1.395845f, 0.238550f, + 0.506708f, 0.514549f, 0.801729f, 0.062210f, 0.298229f, 0.064261f, -0.046150f, 0.013210f, 0.359935f, + 0.127668f, -0.044238f, 0.480672f, 0.924329f, -0.222613f, 0.698375f, -0.125735f, -0.023112f, 0.913967f, + -0.074537f, 1.180120f, -0.102207f, 0.608330f, 0.751979f, 0.153044f, 0.578823f, 0.549166f, 1.010328f, + -0.120458f, 1.186345f, 0.005587f, 0.436457f, 0.224131f, 0.385073f, 0.636303f, 0.608590f, 0.582164f, + -0.106517f, 0.303639f, 0.640568f, -0.243616f, 0.633947f, 0.432398f, 0.956541f, 0.127733f, 0.838249f, + -0.015911f, 0.285687f, 0.277146f, -0.255225f, 0.164428f, 0.732134f, 0.197982f, 0.781012f, 1.207408f, + 0.205636f, 0.023088f, 0.542946f, 0.916584f, 0.688003f, 0.642034f, -0.180515f, 1.353391f, 0.178024f, + 0.355219f, 0.727050f, 0.373560f, -0.159025f, 0.077809f, -0.269769f, 0.662657f, 0.738113f, 0.126410f, + 0.145023f, -0.097077f, 0.272082f, -0.043221f, 0.961003f, 0.158794f, 0.116380f, 0.360497f, 0.041430f, + -0.124825f, 0.549003f, 0.424474f, 0.387609f, -0.190899f, 1.273897f, 0.224349f, 0.054939f, 0.483341f, + 1.185376f, 0.930808f, 0.170383f, -0.145573f, -0.042823f, 0.549434f, 1.155390f, -0.272434f, 0.509246f, + 0.570132f, 0.525628f, -0.204372f, -0.086430f, 0.802915f, 0.141080f, 0.602158f, -0.143482f, 0.635026f, + 0.627989f, 0.197799f, 0.432245f, 0.754035f, 0.462362f, 0.434440f, 0.645105f, 0.124847f, 0.210367f, + -0.003652f, -0.157717f, 0.177076f, 0.365097f, 0.265550f, -0.131768f, -0.109161f, 0.376140f, -0.159695f, + 1.104433f, -0.150630f, 0.424806f, 0.560069f, -0.178628f, 0.062287f, 0.277207f, -0.075345f, 0.355371f, + -0.026496f, 0.217098f, 0.137113f, -0.079923f, -0.002627f, 0.999099f, 0.285733f, 0.561619f, -0.199164f, + -0.017134f, 0.578059f, -0.096513f, 0.286938f, 0.180099f, 0.645301f, 0.812592f, 0.009046f, 0.589835f, + 0.624253f, 0.793519f, -0.249415f, -0.203734f, 0.407439f, 1.011931f, 1.116821f, 0.542856f, 0.705290f, + 0.739027f, 0.623249f, 0.522368f, 0.353579f, 0.995898f, 1.767281f, -0.081556f, -0.076679f, 0.835456f, + 0.079021f, 0.785873f, -0.053940f, 0.252067f, 0.266335f, 0.770431f, 0.254018f, 0.352326f, 0.479874f, + 0.533761f, 1.430939f, 0.689173f, 0.691594f, 0.372724f, 0.704632f, 0.686755f, 0.588283f, 0.285072f, + 0.191627f, 0.745949f, 0.752707f, 0.417238f, 0.587475f, 0.127896f, 1.346089f, 0.149638f, -0.116884f, + 0.677536f, 0.034505f, -0.174461f, -0.128073f, 0.290052f, 0.446063f, 0.521620f, 0.122747f, 0.075711f, + 0.025843f, -0.149042f, 0.070983f, 0.597205f, 0.074123f, 0.968585f, -0.163160f, 0.082720f, 0.586964f, + 0.098779f, 0.137334f, 1.757106f, -0.276652f, 1.146234f, 0.434016f, 0.136440f, 0.269671f, 0.128756f, + -0.117812f, 0.533588f, 1.186146f, 0.443822f, 2.062000f, 0.421238f, 0.683523f}; + + // Test float16, without activation + int min_cuda_architecture = 530; + if (HasCudaEnvironment(min_cuda_architecture)) { + OpTester test("GroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 32); + test.AddAttribute("activation", 0); + + test.AddInput("X", dims, ToFloat16(input_data)); + test.AddInput("gamma", {C}, gamma_data); + + test.AddInput("beta", {C}, beta_data); + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims, ToFloat16(norm_data), false, rel_error, abs_error); + + std::vector> execution_providers; + 
execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+
+  // Test float32, with activation
+  if (HasCudaEnvironment(0)) {
+    OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
+    test.AddAttribute<float>("epsilon", 1e-05f);
+    test.AddAttribute<int64_t>("groups", 32);
+    test.AddAttribute<int64_t>("activation", 1);
+
+    test.AddInput<float>("X", dims, input_data);
+    test.AddInput<float>("gamma", {C}, gamma_data);
+    test.AddInput<float>("beta", {C}, beta_data);
+
+    constexpr float rel_error = 0.0f;
+    constexpr float abs_error = 0.01f;
+    test.AddOutput<float>("Y", dims, swish_data, false, rel_error, abs_error);
+
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc
new file mode 100644
index 0000000000000..6cffaa4d57bf4
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc
@@ -0,0 +1,223 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+
+using namespace std;
+namespace onnxruntime {
+namespace test {
+
+namespace {
+
+struct NhwcConvOpAndTestAttributes {
+  string auto_pad;
+  vector<int64_t> dilations;
+  int64_t group;
+  vector<int64_t> kernel_shape;
+  vector<int64_t> pads;
+  vector<int64_t> strides;
+  std::unordered_set<std::string> excluded_providers;
+};
+
+void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes,
+                    const vector<vector<float>>& inputs,
+                    const vector<vector<int64_t>>& input_shapes,
+                    const std::initializer_list<float>& expected_output,
+                    const vector<int64_t>& expected_output_shape,
+                    bool use_float16,
+                    bool weight_is_initializer = false) {
+  int min_cuda_architecture = use_float16 ? 530 : 0;
+  bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+  if (enable_cuda) {
+    OpTester test("NhwcConv", 1, onnxruntime::kMSDomain);
+    test.AddAttribute("group", attributes.group);
+    test.AddAttribute("kernel_shape", attributes.kernel_shape);
+
+    if (!attributes.dilations.empty()) {
+      test.AddAttribute("dilations", attributes.dilations);
+    }
+
+    // Only one of pads / auto_pad can be present
+    if (!attributes.pads.empty()) {
+      test.AddAttribute("pads", attributes.pads);
+    } else {
+      test.AddAttribute("auto_pad", attributes.auto_pad);
+    }
+
+    if (!attributes.strides.empty()) {
+      test.AddAttribute("strides", attributes.strides);
+    }
+
+    ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs");
+    const char* szNames[] = {"X", "W", "B"};
+
+    if (use_float16) {
+      test.AddInput<MLFloat16>(szNames[0], input_shapes[0], ToFloat16(inputs[0]));
+      test.AddInput<MLFloat16>(szNames[1], input_shapes[1], ToFloat16(inputs[1]), weight_is_initializer);
+      if (inputs.size() == 3) {
+        test.AddInput<MLFloat16>(szNames[2], input_shapes[2], ToFloat16(inputs[2]));
+      }
+      test.AddOutput<MLFloat16>("Y", expected_output_shape, ToFloat16(expected_output));
+    } else {
+      test.AddInput<float>(szNames[0], input_shapes[0], inputs[0]);
+      test.AddInput<float>(szNames[1], input_shapes[1], inputs[1], weight_is_initializer);
+      if (inputs.size() == 3) {
+        test.AddInput<float>(szNames[2], input_shapes[2], inputs[2]);
+      }
+      test.AddOutput<float>("Y", expected_output_shape, expected_output);
+    }
+
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+void RunNhwcConv(const NhwcConvOpAndTestAttributes& attributes,
+                 const vector<vector<float>>& inputs,
+                 const vector<vector<int64_t>>& input_shapes,
+                 const std::initializer_list<float>& expected_output,
+                 const vector<int64_t>& expected_output_shape) {
+  bool use_float16 = true;
+  bool weight_is_initializer = true;
+  TestNhwcConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, use_float16, weight_is_initializer);
+
+  use_float16 = false;
+  weight_is_initializer = false;
+  TestNhwcConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, use_float16, weight_is_initializer);
+}
+
+} // namespace
+
+TEST(NhwcConvTest, Conv2D_2) {
+  NhwcConvOpAndTestAttributes attrs = {
+      "",                           // auto_pad
+      vector<int64_t>{1, 1},        // dilations
+      1,                            // group
+      vector<int64_t>{1, 1},        // kernel_shape
+      vector<int64_t>{0, 0, 0, 0},  // pads
+      vector<int64_t>{1, 1},        // strides
+      {}                            // excluded EPs
+  };
+
+  vector<float> X = {
+      0.45246148109436035f, 0.15498268604278564f, 0.11199361085891724f, -0.39421093463897705f,
+      0.2626858949661255f, 0.13414543867111206f, -0.27184486389160156f, -0.43028733134269714f,
+      -0.26825493574142456f, 0.3893144130706787f, -0.13631996512413025f, -0.009590476751327515f,
+      -0.48771554231643677f, -0.25256502628326416f, -0.2812897562980652f, 0.4043201804161072f,
+      0.07795023918151855f, 0.326981782913208f, 0.13114392757415771f, -0.4416425824165344f,
+      0.12446999549865723f, 0.36739975214004517f, 0.1698915958404541f, 0.2008744478225708f,
+      0.23339951038360596f, 0.38613730669021606f, 0.11117297410964966f, 0.3877097964286804f,
+      0.20812749862670898f, -0.34297940135002136f, -0.029246658086776733f, -0.20483523607254028f,
+      -0.19244328141212463f, -0.11104947328567505f, -0.32830488681793213f, -0.01800677180290222f,
+      0.3618946671485901f, -0.40949052572250366f, -0.18248388171195984f, -0.3349453806877136f,
+      -0.34091079235076904f, 0.006497859954833984f, 0.4537564516067505f, 0.08006560802459717f,
+      -0.14788749814033508f,
+      0.034442365169525146f, -0.33322954177856445f, 0.06049239635467529f,
+      0.42619407176971436f};
+  vector<int64_t> X_shape = {1, 7, 7, 1};
+  vector<float> W = {-0.4406261742115021f};
+  vector<int64_t> W_shape = {1, 1, 1, 1};
+  vector<int64_t> Y_shape = {1, 7, 7, 1};
+  auto expected_vals = {
+      -0.19936637580394745f, -0.06828942894935608f, -0.04934731498360634f, 0.17369966208934784f,
+      -0.11574628204107285f, -0.05910799279808998f, 0.1197819635272026f, 0.18959586322307587f,
+      0.1182001456618309f, -0.17154212296009064f, 0.06006614491343498f, 0.0042258151806890965f,
+      0.21490024030208588f, 0.11128675937652588f, 0.12394362688064575f, -0.17815405130386353f,
+      -0.034346915781497955f, -0.14407673478126526f, -0.05778544768691063f, 0.19459928572177887f,
+      -0.05484473705291748f, -0.16188594698905945f, -0.07485868036746979f, -0.08851054310798645f,
+      -0.10284193605184555f, -0.17014220356941223f, -0.04898572340607643f, -0.17083507776260376f,
+      -0.09170642495155334f, 0.1511256992816925f, 0.012886842712759972f, 0.09025576710700989f,
+      0.08479554951190948f, 0.0489313043653965f, 0.14465972781181335f, 0.007934254594147205f,
+      -0.15946026146411896f, 0.1804322451353073f, 0.08040717244148254f, 0.1475857049226761f,
+      0.15021422505378723f, -0.0028631272725760937f, -0.19993697106838226f, -0.03527900204062462f,
+      0.06516310572624207f, -0.015176207758486271f, 0.14682966470718384f, -0.02665453404188156f,
+      -0.18779225647449493f};
+  RunNhwcConv(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
+}
+
+TEST(NhwcConvTest, Conv2D_Bias_1) {
+  NhwcConvOpAndTestAttributes attrs = {
+      "",                           // auto_pad
+      vector<int64_t>{1, 1},        // dilations
+      1,                            // group
+      vector<int64_t>{2, 2},        // kernel_shape
+      vector<int64_t>{0, 0, 0, 0},  // pads
+      vector<int64_t>{1, 1},        // strides
+      {}                            // excluded EPs
+  };
+
+  vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
+  vector<int64_t> X_shape = {1, 3, 3, 1};
+  vector<float> W = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  vector<int64_t> W_shape = {2, 2, 2, 1};
+  vector<int64_t> Y_shape = {1, 2, 2, 2};
+  vector<float> B = {1.0f, -1.0f};
+  vector<int64_t> B_shape = {2};
+  auto expected_vals = {13.0f, 11.0f, 17.0f, 15.0f, 25.0f, 23.0f, 29.0f, 27.0f};
+
+  RunNhwcConv(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
+}
+
+TEST(NhwcConvTest, Conv2D_AutoPad1) {
+  NhwcConvOpAndTestAttributes attrs = {
+      "SAME_UPPER",           // auto_pad
+      vector<int64_t>{1, 1},  // dilations
+      1,                      // group
+      vector<int64_t>{3, 3},  // kernel_shape
+      {},                     // pads
+      vector<int64_t>{1, 1},  // strides
+      {}                      // excluded EPs
+  };
+
+  vector<float> X = vector<float>(25, 1.0f);
+  vector<int64_t> X_shape = {1, 5, 5, 1};
+  vector<float> W = {0.0f, 1.0f, 2.0f,
+                     3.0f, 4.0f, 5.0f,
+                     6.0f, 7.0f, 8.0f};
+
+  vector<int64_t> W_shape = {1, 3, 3, 1};
+  vector<int64_t> Y_shape = {1, 5, 5, 1};
+  auto expected_vals = {24.0f, 33.0f, 33.0f, 33.0f, 20.0f,
+                        27.0f, 36.0f, 36.0f, 36.0f, 21.0f,
+                        27.0f, 36.0f, 36.0f, 36.0f, 21.0f,
+                        27.0f, 36.0f, 36.0f, 36.0f, 21.0f,
+                        12.0f, 15.0f, 15.0f, 15.0f, 8.0f};
+  RunNhwcConv(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
+}
+
+TEST(NhwcConvTest, Conv2D_AutoPad2) {
+  NhwcConvOpAndTestAttributes attrs = {
+      "SAME_LOWER",           // auto_pad
+      vector<int64_t>{1, 1},  // dilations
+      1,                      // group
+      vector<int64_t>{3, 3},  // kernel_shape
+      {},                     // pads
+      vector<int64_t>{1, 1},  // strides
+      {}                      // excluded EPs
+  };
+
+  vector<float> X = {1.0f, 0.0f, 1.0f, 0.0f, 1.0f,
+                     1.0f, 0.0f, 1.0f, 0.0f, 1.0f,
+                     1.0f, 0.0f, 1.0f, 0.0f, 1.0f,
+                     1.0f, 0.0f, 1.0f, 0.0f, 1.0f,
+                     1.0f, 0.0f, 1.0f, 0.0f, 1.0f};
+  vector<int64_t> X_shape = {1, 5, 5, 1};
+  vector<float> W = {0.0f, 1.0f, 2.0f,
+                     3.0f, 4.0f, 5.0f,
+                     6.0f, 7.0f, 8.0f};
+
+  vector<int64_t> W_shape = {1, 3, 3, 1};
+  vector<int64_t> Y_shape = {1, 5, 5, 1};
+  auto expected_vals = {11.0f,
22.0f, 11.0f, 22.0f, 11.0f, + 12.0f, 24.0f, 12.0f, 24.0f, 12.0f, + 12.0f, 24.0f, 12.0f, 24.0f, 12.0f, + 12.0f, 24.0f, 12.0f, 24.0f, 12.0f, + 5.0f, 10.0f, 5.0f, 10.0f, 5.0f}; + RunNhwcConv(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index 74d20295a0a63..657d52cc15a31 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -40,6 +40,7 @@ def test_multi_head_attention_fusion(self): onnx.save(model, model_path) options = FusionOptions("bert") options.use_multi_head_attention = True + options.use_raw_attention_mask(True) optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) self.verify_fusion(optimized_model, "attention_mha.onnx") @@ -49,7 +50,9 @@ def test_attention_fusion(self): dir = "." model_path = os.path.join(dir, "attention.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) self.verify_fusion(optimized_model, "attention_opt.onnx") @@ -64,7 +67,9 @@ def test_attention_fusion_pruned_model(self): dir = "." model_path = os.path.join(dir, "pruned_attention.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) self.verify_fusion(optimized_model, "pruned_attention_opt.onnx") @@ -80,7 +85,9 @@ def test_attention_fusion_reverse_add_order(self): dir = "." model_path = os.path.join(dir, "bert_attention_reverse_add_order.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) # reverse add input order will get same optimized model @@ -96,7 +103,9 @@ def test_attention_fusion_for_varied_qkv_dimensions(self): dir = "." model_path = os.path.join(dir, "attention_with_varied_qkv.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) self.verify_fusion(optimized_model, "attention_with_varied_qkv_opt.onnx") @@ -113,7 +122,9 @@ def test_attention_fusion_for_varied_qkv_dimensions_with_wrong_opt_parameters(se onnx.save(model, model_path) # wrong num_heads and hidden_size - optimized_model = optimize_model(model_path, "bert", num_heads=8, hidden_size=8) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, "bert", num_heads=8, hidden_size=8, optimization_options=options) os.remove(model_path)