diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 6eb315c59bc80..b118161b1c45a 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -15,6 +15,10 @@ set(contrib_ops_excluded_files "bert/fast_gelu_impl.h" "bert/fast_gelu.cc" "bert/fast_gelu.h" + "bert/relative_attn_bias.cc" + "bert/relative_attn_bias.h" + "bert/relative_attn_bias_impl.cu" + "bert/relative_attn_bias_impl.h" "bert/skip_layer_norm.cc" "bert/skip_layer_norm.h" "bert/skip_layer_norm_impl.cu" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 8cd6d4c9e26f1..f01a7ab14a61e 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -30,6 +30,7 @@ Do not modify directly.* * com.microsoft.FusedConv * com.microsoft.FusedGemm * com.microsoft.FusedMatMul + * com.microsoft.GatedRelativePositionBias * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu @@ -152,7 +153,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
-
extra_add (optional) : T
+
relative_position_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
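For orientation, here is a minimal sketch of how the renamed optional input is wired when building a `com.microsoft.Attention` node with the ONNX Python helper. The tensor names and `num_heads=12` are illustrative placeholders rather than values from this change, and skipped optional inputs are passed as empty strings.

```python
from onnx import helper

# relative_position_bias (formerly extra_add) is input index 5 of com.microsoft.Attention;
# its expected shape is (batch_size, num_heads, sequence_length, total_sequence_length).
attention_node = helper.make_node(
    "Attention",
    inputs=["input", "weights", "bias", "", "", "rel_pos_bias"],  # mask_index and past omitted
    outputs=["output"],
    domain="com.microsoft",
    num_heads=12,
)
```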
@@ -1608,6 +1609,58 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GatedRelativePositionBias** + + query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) + gate_u, gate_r = torch.sigmoid( + self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0 + rel_pos_bias = gate_u_1 * rel_pos + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
num_heads : int (required)
+
Number of attention heads
+
+ +#### Inputs + +
+
query_layer : T
+
tensor with shape (batch_size, seq_len, num_heads x head_size)
+
query_bias : T
+
1-d tensor with shape (num_heads x head_size)
+
rel_pos : T
+
tensor with shape (1, num_head, seq_len, seq_len)
+
weight : T
+
gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2
+
bias : T
+
bias for the gated_ur_linear, shape (D)
+
eco_a : T
+
tensor of shape (1, num_heads, 1, 1)
+
+ +#### Outputs + +
+
output : T
+
output tensor with shape (batch_size, num_heads, seq_len, seq_len)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
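As a reference for the computation specified above, a small PyTorch sketch that mirrors the documented pseudocode (the standalone function and the explicit `weight`/`bias` arguments standing in for `gate_ur_linear` are assumptions for illustration, not the CUDA implementation added later in this diff):

```python
import torch

def gated_relative_position_bias(query_layer, query_bias, rel_pos, weight, bias, eco_a, num_heads):
    # query_layer: (batch_size, seq_len, num_heads * head_size)   query_bias: (num_heads * head_size)
    # rel_pos:     (1, num_heads, seq_len, seq_len)               weight:     (head_size, D), D even
    # bias:        (D)                                            eco_a:      (1, num_heads, 1, 1)
    batch_size, seq_len, hidden = query_layer.shape
    head_size = hidden // num_heads
    D = weight.shape[1]

    # (batch_size, num_heads, seq_len, head_size)
    q = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)

    # gate_ur_linear(q): (batch_size, num_heads, seq_len, D)
    qw = torch.matmul(q, weight) + bias

    # sum each half of D, apply sigmoid, then split into the u and r gates
    gate_u, gate_r = torch.sigmoid(
        qw.view(batch_size, num_heads, seq_len, 2, D // 2).sum(-1, keepdim=False)
    ).chunk(2, dim=-1)

    gate_u_1 = gate_u * (gate_r * eco_a - 1.0) + 2.0  # (batch_size, num_heads, seq_len, 1)
    return gate_u_1 * rel_pos                         # broadcasts to (batch_size, num_heads, seq_len, seq_len)
```

In the CUDA version added below, the bias-add/transpose of `query_layer` and the `q @ weight` GEMM are handled by `LaunchAddBiasTranspose` and cuBLAS, while the gating and the multiplication with `rel_pos` are fused into `GatedRelativePositionBiasKernelSmallD`.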
+ + ### **com.microsoft.GatherND** Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather @@ -2222,7 +2275,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (2 - 5) +#### Inputs (2 - 6)
query : T
@@ -2235,6 +2288,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)
+
relative_position_bias (optional) : T
+
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
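To make the semantics of this new optional input concrete, a small NumPy sketch of where it enters the computation (an illustration of the documented behavior, not the fused CUDA path): the bias is added to the scaled QxK' scores before the softmax, and a leading dimension of 1 broadcasts over the batch.

```python
import numpy as np

def attention_probs(q, k, relative_position_bias=None):
    # q: (batch_size, num_heads, sequence_length, head_size)
    # k: (batch_size, num_heads, total_sequence_length, head_size)
    # relative_position_bias: (batch_size or 1, num_heads, sequence_length, total_sequence_length)
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(q.shape[-1])
    if relative_position_bias is not None:
        scores = scores + relative_position_bias  # broadcasts over batch when the first dim is 1
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    return probs / probs.sum(axis=-1, keepdims=True)
```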
#### Outputs @@ -3221,7 +3276,7 @@ This version of the operator has been available since version 1 of the 'com.micr left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. - Current version does not support past/present, extra_add and qkv_hidden_sizes. + Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. #### Version @@ -3286,7 +3341,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : Q
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
-
extra_add (optional) : S
+
relative_position_bias (optional) : S
additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 286cad61d599f..29043259064a5 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -417,7 +417,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| @@ -785,7 +785,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| @@ -803,6 +803,7 @@ Do not modify directly.* |FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| @@ -810,11 +811,11 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* extra_add:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| +|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLongformerAttention|*in* input:**Q**
*in* scale_input:**S**
*in* weight:**Q**
*in* scale_weight:**S**
*in* bias:**S**
*in* scale_bias:**S**
*in* scale_qkv_gemm:**S**
*in* mask:**F**
*in* global_weight:**Q**
*in* scale_global_weight:**S**
*in* global_bias:**S**
*in* scale_global_gemm:**S**
*in* global:**G**
*in* scale_output:**S**
*out* output:**Q**|1+|**F** = tensor(float16)
**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| @@ -1159,7 +1160,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 47db3fe558ce8..6aa0e726afe1b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -198,7 +198,7 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_); @@ -208,7 +208,7 @@ Status Attention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, ¶meters)); const int batch_size = parameters.batch_size; @@ -331,7 +331,7 @@ Status Attention::Compute(OpKernelContext* context) const { return ApplyAttention(Q, K, V, mask_index, past, output, batch_size, sequence_length, parameters.head_size, parameters.v_head_size, parameters.v_hidden_size, - extra_add_qk, context); + relative_position_bias, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index affe7cab1d858..e75f68ea53c7c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -12,7 +12,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const Tensor* past_seq_len) const { // Abbreviation and Meanings: @@ -37,7 +37,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // extra_add_qk : (B, N, S, T) or NULL + // relative_position_bias : (B, N, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) @@ -49,9 +49,9 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger // than hidden dimension of Q, K and V. 
- if (past != nullptr && extra_add_qk != nullptr) { - // past is used on GPT-2 model with past state, we don't have a case for extra add qk yet - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and extra_add_qk"); + if (past != nullptr && relative_position_bias != nullptr) { + // past is used on GPT-2 model with past state, we don't have a case for relative position bias yet + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and relative_position_bias"); } const auto& dims = input_shape.GetDims(); @@ -191,34 +191,34 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } } - if (extra_add_qk != nullptr) { - const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims(); + if (relative_position_bias != nullptr) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - if (extra_add_qk_dims.size() != 4) { + if (relative_position_bias_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' is expected to have 4 dimensions, got ", - extra_add_qk_dims.size()); + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); } - if (extra_add_qk_dims[0] != batch_size) { + if (relative_position_bias_dims[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ", - extra_add_qk_dims[0]); + "Input 'relative_position_bias' dimension 0 should be same as batch_size, got ", + relative_position_bias_dims[0]); } - if (extra_add_qk_dims[1] != num_heads_) { + if (relative_position_bias_dims[1] != num_heads_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ", - extra_add_qk_dims[1]); + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); } - if (extra_add_qk_dims[2] != sequence_length) { + if (relative_position_bias_dims[2] != sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ", - extra_add_qk_dims[2]); + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); } - if (extra_add_qk_dims[3] != total_sequence_length) { + if (relative_position_bias_dims[3] != total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 3 should be same as total_sequence_length, got ", - extra_add_qk_dims[3]); + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); } } @@ -320,7 +320,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { @@ -328,7 +328,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, extra_add_qk, parameters, past_seq_len); + return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, 
relative_position_bias, parameters, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 2c49f196d52d8..2e077da2853d0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -18,7 +18,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, // for CUDA const Tensor* past_seq_len = nullptr) const; @@ -61,7 +61,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const Tensor* past_seq_len = nullptr) const; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 0185fa9ea09a0..70d71ffb6ee40 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -19,18 +19,18 @@ class AttentionCPUBase : public AttentionBase { : AttentionBase(info, require_same_hidden_size) {} template - Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH - const T* K, // K data with shape BxNxSxH - const T* V, // V value with size BxNxSxH_v - const Tensor* mask_index, // mask index. nullptr if no mask or its size is B - const Tensor* past, // past state - Tensor* output, // output tensor - int batch_size, // batch size (B) - int sequence_length, // sequence length (S) - int qk_head_size, // head size of Q or K (H) - int v_head_size, // head size of V (H_v) - int v_hidden_size, // hidden size of V (D_v) - const Tensor* extra_add_qk, // extra add in QK. Its size is BxNxSxT + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxSxH + const T* V, // V value with size BxNxSxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past, // past state + Tensor* output, // output tensor + int batch_size, // batch size (B) + int sequence_length, // sequence length (S) + int qk_head_size, // head size of Q or K (H) + int v_head_size, // head size of V (H_v) + int v_hidden_size, // hidden size of V (D_v) + const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT OpKernelContext* context) const { const int kv_sequence_length = sequence_length; @@ -67,16 +67,16 @@ class AttentionCPUBase : public AttentionBase { const T* past_data = past != nullptr ? past->Data() : nullptr; T* present_data = present != nullptr ? present->MutableData() : nullptr; - const T* extra_add_qk_data = nullptr; - if (extra_add_qk != nullptr) { - extra_add_qk_data = extra_add_qk->Data(); + const T* relative_position_bias_data = nullptr; + if (relative_position_bias != nullptr) { + relative_position_bias_data = relative_position_bias->Data(); } ComputeAttentionProbs(static_cast(attention_probs), Q, K, mask_index_data, mask_index_dims, static_cast(mask_data), has_unidirectional, batch_size, sequence_length, past_sequence_length, qk_head_size == 0 ? 
v_head_size : qk_head_size, - past_data, present_data, tp, extra_add_qk_data); + past_data, present_data, tp, relative_position_bias_data); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -112,7 +112,7 @@ class AttentionCPUBase : public AttentionBase { const T* past, // past state T* present, // present state ThreadPool* tp, // thread pool - const T* extra_add_qk_data // extra add matrix with shape BxNxSxT + const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT ) const { const int total_sequence_length = past_sequence_length + sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H @@ -175,9 +175,9 @@ class AttentionCPUBase : public AttentionBase { } } - if (extra_add_qk_data != nullptr) { + if (relative_position_bias_data != nullptr) { for (int j = 0; j < sequence_length * total_sequence_length; j++) { - output[j] += extra_add_qk_data[output_offset + j]; + output[j] += relative_position_bias_data[output_offset + j]; } } } diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 8c3af05972c95..ee1720b9f43bb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -17,6 +17,7 @@ Status CheckInputs(const T* query, const T* value, const T* bias, const T* key_padding_mask, + const T* relative_position_bias, void* parameters, int num_heads, float mask_filter_value, @@ -26,6 +27,7 @@ Status CheckInputs(const T* query, // value (V) : (B, L, D_v) // bias (Q/K/V) : (D + D + D_v) // key_padding_mask (K/V) : (B) or (B, L) or None + // relative_position_bias : (B, 1, S, L) // When packed kv is used: // key (K) : (B, L, N, 2, H) // value (V) : None @@ -120,6 +122,36 @@ Status CheckInputs(const T* query, v_hidden_size = static_cast(value_dims[2]); } + if (relative_position_bias != nullptr) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + + if (relative_position_bias_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); + } + if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", + relative_position_bias_dims[0]); + } + if (relative_position_bias_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); + } + if (relative_position_bias_dims[2] != sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); + } + if (relative_position_bias_dims[3] != kv_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); + } + } + if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); output_parameters->batch_size = batch_size; diff --git 
a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 64c17b7767e4f..e7df84c1b0066 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -160,7 +160,7 @@ Status QAttention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past_tensor, - nullptr, // extra_add_qk + nullptr, // relative_position_bias nullptr // parameters )); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 4a6d2dc137139..1ab89b525eae5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -59,7 +59,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(kPastInputIndex); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); @@ -69,7 +69,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, ¶meters, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -105,7 +105,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; bool use_causal_fused_runner = !disable_fused_runner_ && (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == extra_add_qk && + nullptr == relative_position_bias && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, @@ -125,7 +125,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { (nullptr == mask_index || is_mask_1d_seq_len) && nullptr == past && nullptr == present && - nullptr == extra_add_qk && + nullptr == relative_position_bias && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, enable_trt_flash_attention_, false); @@ -151,7 +151,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == mask_index && // TODO: support 1D mask nullptr == past && nullptr == present && - nullptr == extra_add_qk && + nullptr == relative_position_bias && (sizeof(T) == 2 || // sequence length threshold is 0 in FP16 parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && has_memory_efficient_attention(sm, sizeof(T) == 2); @@ -203,7 +203,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); - data.extra_add_qk = (nullptr == extra_add_qk) ? nullptr : reinterpret_cast(extra_add_qk->Data()); + data.relative_position_bias = (nullptr == relative_position_bias) ? 
nullptr : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 8c7ef9f919519..fcf86637350b6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -665,7 +665,7 @@ Status QkvToContext( T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, data.extra_add_qk, scratch1, scratch2, + mask_index, nullptr, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); @@ -675,10 +675,10 @@ Status QkvToContext( const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, data.extra_add_qk, scratch1, scratch2, parameters.is_unidirectional)); + mask_index, mask_start, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( - ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.extra_add_qk, + ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional)); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d98a0380c479b..2ecda71479c52 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -41,7 +41,7 @@ struct AttentionData { const int* mask_index; gsl::span mask_index_dims; const T* past; - const T* extra_add_qk; + const T* relative_position_bias; T* workspace; T* output; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h index 16b3cf053b586..92851c446d48f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -377,11 +377,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, float thread_data = -CUDART_INF_F; if (threadIdx.x < all_sequence_length) { - if (add_before_softmax == nullptr) { - thread_data = float(input[index]) * rsqrt_head_size; - } else { - thread_data = float(input[index] + add_before_softmax[index]) * rsqrt_head_size; - } + thread_data = float(input[index]) * rsqrt_head_size; const int sequence_index = blockIdx.x % sequence_length; if (is_unidirectional) { @@ -412,6 +408,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, thread_data = -CUDART_INF_F; } } + + if (add_before_softmax != nullptr) { + thread_data += float(add_before_softmax[index]); + } } if (skip_softmax) { diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 93e5e59ed00ae..57a3a310a0dd6 100644 --- 
a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -62,6 +62,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* bias = context->Input(3); const Tensor* key_padding_mask = context->Input(4); + const Tensor* relative_position_bias = context->Input(5); auto& device_prop = GetDeviceProp(); AttentionParameters parameters; @@ -70,6 +71,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { value, bias, key_padding_mask, + relative_position_bias, ¶meters, num_heads_, mask_filter_value_, @@ -94,6 +96,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && + nullptr == relative_position_bias && (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV parameters.hidden_size == parameters.v_hidden_size && has_fused_cross_attention_kernel(sm, parameters.head_size, @@ -112,6 +115,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_runner = !disable_fused_runner_ && fused_cross_attention_kernel == nullptr && + nullptr == relative_position_bias && value != nullptr && // fused runner requires packed qkv instead of packed kv (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && @@ -143,6 +147,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !disable_memory_efficient_attention_ && is_long_sequence && nullptr == key_padding_mask && // TODO: support 1D mask + nullptr == relative_position_bias && has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; @@ -171,7 +176,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past = nullptr; - data.extra_add_qk = nullptr; + data.relative_position_bias = (nullptr == relative_position_bias) ? 
nullptr : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index af13efe0e2fbc..111fed04639e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -3,7 +3,15 @@ #include "core/providers/cuda/cuda_common.h" #include "relative_attn_bias.h" +#include "core/common/safeint.h" #include "relative_attn_bias_impl.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + namespace onnxruntime { namespace contrib { @@ -20,7 +28,16 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 1) \ .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - RelPosAttnBias); + RelPosAttnBias); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GatedRelativePositionBias, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + GatedRelativePositionBias); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) @@ -69,6 +86,108 @@ Status RelPosAttnBias::ComputeInternal(OpKernelContext* context) const { device_prop.maxThreadsPerBlock); } +template +GatedRelativePositionBias::GatedRelativePositionBias(const OpKernelInfo& info) : CudaKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = SafeInt(num_heads); +} + +template +Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) const { + const Tensor& query_tensor = *context->Input(0); + const Tensor& query_bias_tensor = *context->Input(1); + const Tensor& rel_pos_tensor = *context->Input(2); + const Tensor& weight_tensor = *context->Input(3); + const Tensor& bias_tensor = *context->Input(4); + const Tensor& eco_a_tensor = *context->Input(5); + + const auto& query_dims = query_tensor.Shape().GetDims(); + ORT_ENFORCE(query_dims.size() == 3); + ORT_ENFORCE(query_dims[2] > 0); + ORT_ENFORCE(query_dims[2] % num_heads_ == 0); + const auto batch_size = SafeInt(query_dims[0]); + const auto seq_len = SafeInt(query_dims[1]); + const auto head_size = SafeInt(query_dims[2] / num_heads_); + + ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1); + ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]); + + const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims(); + ORT_ENFORCE(rel_pos_dims.size() == 4); + ORT_ENFORCE(rel_pos_dims[0] == 1); + ORT_ENFORCE(rel_pos_dims[1] == num_heads_); + ORT_ENFORCE(rel_pos_dims[2] == seq_len); + ORT_ENFORCE(rel_pos_dims[3] == seq_len); + + const auto& weight_dims = weight_tensor.Shape().GetDims(); + ORT_ENFORCE(weight_dims.size() == 2); + ORT_ENFORCE(weight_dims[0] == head_size); + ORT_ENFORCE((weight_dims[1] > 0) && (weight_dims[1] % 2 == 0)); + + ORT_ENFORCE(bias_tensor.Shape().NumDimensions() == 1); + ORT_ENFORCE(bias_tensor.Shape()[0] == weight_dims[1]); + + const auto D = SafeInt(weight_dims[1]); + + const auto& eco_a_dims = eco_a_tensor.Shape().GetDims(); + ORT_ENFORCE(eco_a_dims.size() == 4); + ORT_ENFORCE(eco_a_dims[0] == 1); + ORT_ENFORCE(eco_a_dims[1] == num_heads_); + ORT_ENFORCE(eco_a_dims[2] == 
1); + ORT_ENFORCE(eco_a_dims[3] == 1); + + Tensor* output = context->Output(0, {batch_size, num_heads_, seq_len, seq_len}); + + auto& device_prop = GetDeviceProp(); + cublasHandle_t cublas = GetCublasHandle(context); + + typedef typename ToCudaType::MappedType CudaT; + const auto BNS = batch_size * num_heads_ * seq_len; + const size_t elements_in_query = (size_t)BNS * (size_t)head_size; + const size_t elements_after_gemm = (size_t)BNS *(size_t)D; + size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0); + auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream()); + + // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH) + constexpr int format = 1; + constexpr int total_maxtrix = 1; + constexpr int num_matrix_to_transpose = 1; + LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock, + batch_size, seq_len, num_heads_, head_size, + reinterpret_cast(query_tensor.template Data()), + reinterpret_cast(query_bias_tensor.template Data()), + reinterpret_cast(workspace.get()), + false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix); + + // reuse output if possible + CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast(workspace.get()) + elements_in_query) + : reinterpret_cast(output->template MutableData()); + int ld_gemm_output = max(seq_len, D); + + const CudaT one = ToCudaType::FromFloat(1.0f); + const CudaT zero = ToCudaType::FromFloat(0.0f); + + // ([b*n*s, h] * [h, D]), CUDA assumes col-major + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + D, BNS, head_size, &one, + reinterpret_cast(weight_tensor.template Data()), (int)D, + reinterpret_cast(workspace.get()), (int)head_size, + &zero, gemm_output, ld_gemm_output, device_prop)); + + auto status = LaunchGatedRelativePositionBiasKernel( + device_prop, Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(rel_pos_tensor.template Data()), + reinterpret_cast(gemm_output), + reinterpret_cast(bias_tensor.template Data()), + reinterpret_cast(eco_a_tensor.template Data()), + batch_size, num_heads_, seq_len, D, ld_gemm_output); + + return status; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h index b9674f6f35091..3bf4e730e29f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h @@ -22,6 +22,18 @@ class RelPosAttnBias final : public CudaKernel { bool is_bidirectional_; }; +template +class GatedRelativePositionBias final : public CudaKernel { + public: + GatedRelativePositionBias(const OpKernelInfo& op_kernel_info); + + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int num_heads_; +}; + + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index e333152cb5bcf..938496b058025 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -36,7 +36,7 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, const bool is_bidirectional, const int max_distance) { const int head_id = blockIdx.x; - for (int seq_id = threadIdx.x; seq_id < 
seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { + for (int seq_id = blockDim.x * blockIdx.y + threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { int row_id = seq_id / seq_len; int col_id = seq_id % seq_len; @@ -149,6 +149,122 @@ template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream, const bool is_bidirectional, const int max_threads_per_block); +template +__global__ void GatedRelativePositionBiasKernelSmallD( + T* output, // (batch_size, num_heads, seq_len, seq_len) + const T* rel_pos, // (1, num_heads, seq_len, seq_len) + const T* qw, // (batch_size, num_heads, seq_len, D) + const T* bias, // (D) + const T* eco_a, // (1, num_heads, 1, 1) + const int D, + const int ldqw) { + __shared__ float gate[1]; + + const int seq_len = gridDim.x; + const int num_heads = gridDim.y; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + rel_pos += ((int64_t)n * seq_len + s) * seq_len; + output += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * seq_len; + qw += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * ldqw; + + float val = 0.0f; + if (threadIdx.x < D) { + val = (float)qw[threadIdx.x] + (bias ? (float)bias[threadIdx.x] : 0.0f); + } + + float u = (threadIdx.x < D / 2) ? val : 0.0f; +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + u += __shfl_down_sync(0xffffffff, u, offset); + } + + float r = (threadIdx.x >= D / 2) ? val : 0.0f; +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + r += __shfl_down_sync(0xffffffff, r, offset); + } + + if (threadIdx.x == 0) { + u = 1.0f / (1.0f + expf(-u)); + r = 1.0f / (1.0f + expf(-r)); + gate[0] = u * (r * (float)eco_a[n] - 1.0f) + 2.0f; + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < seq_len; idx += blockDim.x) { + output[idx] = (T)(gate[0] * (float)rel_pos[idx]); + } +} + +template +Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + T* output, + const T* rel_pos, + const T* qw, // query * weight + const T* bias, + const T* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw) { + ORT_ENFORCE(D <= 32 && D > 0 && (D % 2 == 0)); + ORT_ENFORCE(ldqw == seq_len || ldqw == D); + + int tpb = std::max(32, std::max(D, seq_len)); + tpb = std::min(tpb, device_prop.maxThreadsPerBlock); + + // round up tpb to power of 2 + --tpb; + tpb |= (tpb >> 1); + tpb |= (tpb >> 2); + tpb |= (tpb >> 4); + tpb |= (tpb >> 8); + tpb |= (tpb >> 16); + tpb++; + + dim3 block(tpb); + dim3 grid(seq_len, num_heads, batch_size); + + GatedRelativePositionBiasKernelSmallD<<>>( + output, rel_pos, qw, bias, eco_a, D, ldqw); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + float* output, + const float* rel_pos, + const float* qw, + const float* bias, + const float* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + +template Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + half* output, + const half* rel_pos, + const half* qw, + const half* bias, + const half* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h 
b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h index 5a1a229ab6077..5c7c98f55f3f5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h @@ -22,6 +22,21 @@ Status LaunchRelPosAttnBiasKernel( const int max_threads_per_block ); +template +Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + T* output, + const T* rel_pos, + const T* qw, // from query * weight + const T* bias, + const T* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index a239e528af148..1254ccd7e1e17 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -32,6 +32,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding); @@ -162,6 +164,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index e5ea47a6a2a5b..7cd717efc9fba 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -52,7 +52,7 @@ Status QAttention::CheckInputs(const Tensor* input, auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor, - nullptr, // extra_add_qk + nullptr, // relative_position_bias parameters, device_prop.maxThreadsPerBlock)); @@ -198,7 +198,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past_tensor) ? 
nullptr : reinterpret_cast(past_tensor->Data()); - data.extra_add_qk = nullptr; // add_qk is not supported in quantized attention + data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 204c786cc2c5d..8122b2de5916b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -212,7 +212,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), merged_weights_shape, merged_bias_shape, mask_index, nullptr, // past - nullptr, // extra_add_qk + nullptr, // relative_position_bias nullptr, // parameters device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h index 5fe62ef127800..5fb31be5fe86f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h @@ -17,4 +17,4 @@ DefineQOrderedAttentionInput(Scale_QK_Softmax, scale_QKT_softmax, 15), DefineQOrderedAttentionInput(Scale_Values_Gemm, scale_values_gemm, 16), DefineQOrderedAttentionInput(Mask_Index, mask_index, 17), DefineQOrderedAttentionInput(Past, past, 18), -DefineQOrderedAttentionInput(Extra_Add, extra_add, 19) +DefineQOrderedAttentionInput(relative_position_bias, relative_position_bias, 19) diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc index 756919834aef8..afc9fd9237ed7 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cc +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc @@ -39,7 +39,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), @@ -47,7 +47,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, nullptr, device_prop.maxThreadsPerBlock)); @@ -129,7 +129,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), mask_filter_value_, nullptr == past ? nullptr : past->Data(), - nullptr == extra_add_qk ? nullptr : extra_add_qk->Data(), + nullptr == relative_position_bias ? nullptr : relative_position_bias->Data(), work_space.get(), output->MutableData(), nullptr == present ? 
nullptr : present->MutableData()); diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 954a129be1c65..fa6cce6a64132 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -89,7 +89,7 @@ Status QkvToContext( bool is_unidirectional, int past_sequence_length, const T* past, - const T* extra_add_qk, + const T* relative_position_bias, T* present, bool use_persistent_softmax) { const int all_sequence_length = past_sequence_length + sequence_length; @@ -158,7 +158,7 @@ Status QkvToContext( T* persistent_softmax_workspace = scratch1; // replace Q*K' in place if persistent softmax is selected. ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, extra_add_qk, scratch1, scratch2, + mask_index, nullptr, relative_position_bias, scratch1, scratch2, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index @@ -166,10 +166,10 @@ Status QkvToContext( // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)); + mask_index, mask_start, relative_position_bias, scratch1, scratch2, is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, - extra_add_qk, scratch1, scratch2, is_unidirectional)); + relative_position_bias, scratch1, scratch2, is_unidirectional)); } // compute P*V (as V*P), and store in scratch3: BxNxSxH @@ -206,7 +206,7 @@ Status LaunchAttentionKernel( gsl::span mask_index_dims, const float mask_filter_value, const void* past, - const void* extra_add_qk, + const void* relative_position_bias, void* workspace, void* output, void* present) { @@ -225,7 +225,7 @@ Status LaunchAttentionKernel( is_unidirectional, past_sequence_length, reinterpret_cast(past), - reinterpret_cast(extra_add_qk), + reinterpret_cast(relative_position_bias), reinterpret_cast<__half*>(present), use_persistent_softmax); } else { @@ -240,7 +240,7 @@ Status LaunchAttentionKernel( is_unidirectional, past_sequence_length, reinterpret_cast(past), - reinterpret_cast(extra_add_qk), + reinterpret_cast(relative_position_bias), reinterpret_cast(present), use_persistent_softmax); } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 7db692083f5e5..fdc46ce2e7729 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -42,7 +42,7 @@ Status LaunchAttentionKernel( gsl::span mask_index_dims, // Mask index shape const float mask_filter_value, // Mask value for filtered out positions const void* past, // Past state input - const void* extra_add_qk, // Additional Add + const void* relative_position_bias, // Additional Add void* workspace, // Temporary buffer void* output, // Output tensor void* present // Present state output diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc 
b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 68e3985651123..6b00ac94bc10f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -243,7 +243,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .Input(5, - "extra_add", + "relative_position_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", "T", OpSchema::Optional) @@ -313,6 +313,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)", "M", OpSchema::Optional) + .Input(5, + "relative_position_bias", + "relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)" + " or (1, num_heads, sequence_length, total_sequence_length)", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, v_hidden_size)", @@ -668,5 +674,41 @@ ONNX_MS_OPERATOR_SET_SCHEMA( RestorePaddingTypeAndShapeInference(ctx); })); +constexpr const char* GatedRelativePositionBias_ver1_doc = R"DOC( + query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) + gate_u, gate_r = torch.sigmoid( + self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0 + rel_pos_bias = gate_u_1 * rel_pos +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + GatedRelativePositionBias, 1, + OpSchema() + .SetDoc(GatedRelativePositionBias_ver1_doc) + .Attr("num_heads", "Number of attention heads", AttributeProto::INT) + .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T") + .Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T") + .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T") + .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T") + .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T") + .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T") + .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + int64_t num_heads = getAttribute(ctx, "num_heads", -1L); + if (hasInputShape(ctx, 0)) { + auto& query_layer_shape = getInputShape(ctx, 0); + TensorShapeProto output_shape; + *output_shape.add_dim() = query_layer_shape.dim(0); + output_shape.add_dim()->set_dim_value(num_heads); + *output_shape.add_dim() = query_layer_shape.dim(1); + *output_shape.add_dim() = query_layer_shape.dim(1); + updateOutputShape(ctx, 0, output_shape); + } + })); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index a511d01fe1624..bd8469909fe7f 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -81,6 +81,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, 
RelativePositionBias); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBias); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); @@ -171,6 +172,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 6111afbd5d817..91e4f5d8ff81a 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -1140,7 +1140,7 @@ where value of each element is the end position, or valid length of actual seque left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. -Current version does not support past/present, extra_add and qkv_hidden_sizes. +Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. )DOC"; @@ -1202,7 +1202,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(18, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "Q", OpSchema::Optional) - .Input(19, "extra_add", + .Input(19, "relative_position_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "S", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "Q") diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index b4a92019992b5..c0a75fc50b07e 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -198,12 +198,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, - extra_add_qk, + relative_position_bias, parameters, max_threads_per_block, past_seq_len); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 2490789dd31a2..f12e080adf30a 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -145,7 +145,7 @@ struct ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) = 0; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index 63bae80c51a67..af93808248032 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -401,15 +401,15 @@ class DmlOperatorAttention : public DmlOperator void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) { *isSupported = false; - // Fall back to CPU if input 'past' and 'extra_add' is present because there is no current use case for this. + // Fall back to CPU if input 'past' and 'relative_position_bias' is present because there is no current use case for this. // and it will make the implementation more complex. // Also fall back to CPU if output 'present' is present for same reason as above. if (context->GetInputCount() > 4 || context->GetOutputCount() > 1) { return; } - // Checking input count alone is not sufficient to fallback to CPU if input 'past' and 'extra_add' is present - // because input 'mask_index', 'past', and 'extra_add' all are optional. + // Checking input count alone is not sufficient to fallback to CPU if input 'past' and 'relative_position_bias' is present + // because input 'mask_index', 'past', and 'relative_position_bias' all are optional. if (context->IsInputValid(4) || context->IsInputValid(5)) { return; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index b772fe95d6ecc..30be10ea7e15f 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -563,12 +563,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { return g_host_cpu.AttentionBase__CheckInputs(this, input_shape, weights_shape, bias_shape, - mask_index, past, extra_add_qk, parameters, + mask_index, past, relative_position_bias, parameters, max_threads_per_block, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, const Tensor* past, int batch_size, int head_size, diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 245ea9322ad61..342d43306e699 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -337,7 +337,7 @@ def create_attention_node( # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights. 
if self.use_multi_head_attention: if add_qk_str is not None: - logger.debug("MultiHeadAttention does not support extra_add_qk: cannot fuse the attention.") + logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.") return None attention_inputs = [ diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 96c22b5894c60..d81a062f1060e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -319,7 +319,7 @@ def match_parent_path( self, node, parent_op_types, - parent_input_index, + parent_input_index=None, output_name_to_node=None, return_indice=None, ): @@ -339,7 +339,8 @@ def match_parent_path( Returns: parents: a list of matched parent node. """ - assert len(parent_input_index) == len(parent_op_types) + if parent_input_index is not None: + assert len(parent_input_index) == len(parent_op_types) if output_name_to_node is None: output_name_to_node = self.output_name_to_node() @@ -350,16 +351,19 @@ def match_parent_path( matched_parent = self.match_parent( current_node, op_type, - parent_input_index[i], + parent_input_index[i] if parent_input_index is not None else None, output_name_to_node, exclude=[], return_indice=return_indice, ) if matched_parent is None: - logger.debug( - f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", - stack_info=True, - ) + if parent_input_index is not None: + logger.debug( + f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", + stack_info=True, + ) + else: + logger.debug(f"Failed to match index={i} op_type={op_type}", stack_info=True) return None matched_parents.append(matched_parent) diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py index dc8f6810914a7..85e510a828990 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py @@ -172,8 +172,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add = k_nodes[-2] matmul = k_nodes[-1] - extra_add_qk_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0]) - if extra_add_qk_nodes is None: + relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0]) + if relative_position_bias_nodes is None: return if matmul.input[0] == root_input: @@ -189,7 +189,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.hidden_size, root_input, attention_last_node.output[0], - extra_add_qk_nodes[0].input[0], + relative_position_bias_nodes[0].input[0], ) if new_node is None: return diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index fb1d8fcfe451a..0ea85dfdaba4f 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -59,7 +59,7 @@ static void RunAttentionTest( const bool disable_cuda = false, const bool disable_rocm = false, std::vector qkv_sizes = {}, - const std::vector& extra_add_data = {}, + const std::vector& relative_position_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false) { @@ -199,12 +199,12 @@ static void RunAttentionTest( } } - std::vector extra_add_data_dims = {batch_size, number_of_heads, sequence_length, 
sequence_length}; - if (extra_add_data.size() > 0) { + std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + if (relative_position_bias_data.size() > 0) { if (use_float16) { - tester.AddInput("extra_add_qk", extra_add_data_dims, ToFloat16(extra_add_data)); + tester.AddInput("relative_position_bias", relative_position_bias_data_dims, ToFloat16(relative_position_bias_data)); } else { - tester.AddInput("extra_add_qk", extra_add_data_dims, extra_add_data); + tester.AddInput("relative_position_bias", relative_position_bias_data_dims, relative_position_bias_data); } } else { if (use_float16) { @@ -264,7 +264,7 @@ static void RunAttentionTest( const bool disable_cuda = false, const bool disable_rocm = false, const std::vector qkv_sizes = {}, - const std::vector& extra_add_data = {}, + const std::vector& relative_position_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false) { @@ -272,13 +272,13 @@ static void RunAttentionTest( batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data, + disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias_data, kv_sequence_length, past_present_share_buffer, use_scale); RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data, + disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias_data, kv_sequence_length, past_present_share_buffer, use_scale); } @@ -390,7 +390,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { 0, false, false, disable_rocm, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1ExtraAdd) { +TEST(AttentionTest, AttentionBatch1RelativePositionBias) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -414,7 +414,7 @@ TEST(AttentionTest, AttentionBatch1ExtraAdd) { std::vector mask_index_data = {2L}; - std::vector extra_add_qk = { + std::vector relative_position_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; std::vector output_data = { @@ -427,10 +427,10 @@ TEST(AttentionTest, AttentionBatch1ExtraAdd) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_qk); + 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias); } -TEST(AttentionTest, AttentionBatch2ExtraAdd) { +TEST(AttentionTest, AttentionBatch2RelativePositionBias) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -456,7 +456,7 @@ TEST(AttentionTest, AttentionBatch2ExtraAdd) { std::vector mask_index_data = {2L, 2L}; - std::vector extra_add_qk = { + std::vector relative_position_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f, 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; @@ -472,7 +472,7 @@ TEST(AttentionTest, AttentionBatch2ExtraAdd) { RunAttentionTest(input_data, weight_data, bias_data, 
mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_qk); + 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias); } TEST(AttentionTest, AttentionBatch1_Float16) { @@ -1709,7 +1709,7 @@ TEST(AttentionTest, AttentionWithNormFactor) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, {} /*qkv_sizes*/, - {} /*extra_add_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, + {} /*relative_position_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/); } diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index cf14a7f9918c6..5257fbdb08809 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -278,7 +278,7 @@ TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { test_qorder.AddInput("scale_values_gemm", {}, {attn_out_scale}, true); test_qorder.AddInput("mask_index", {batch_size, sequence_len}, input_mask.data(), input_mask.size()); test_qorder.AddOptionalInputEdge(); // past - test_qorder.AddOptionalInputEdge(); // extra_add + test_qorder.AddOptionalInputEdge(); // relative_position_bias test_qorder.AddOutput("output", {batch_size, sequence_len, hidden_size}, attn_out_q8.data(), attn_out_q8.size()); diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc index 7722291bee653..ba0299e4f3808 100644 --- a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc +++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc @@ -10,9 +10,9 @@ namespace onnxruntime { namespace test { static void RunRelativePositionBiasTest( - const std::vector& bias_table, // Shape = [num_buckets, num_heads] - const std::vector& sequence_length, // Shape = [1] - const std::vector& output_data, // Shape = [1, num_heads, sequence_length, sequence_length] + const std::vector& bias_table, // Shape = [num_buckets, num_heads] + const std::vector& sequence_length, // Shape = [1] + const std::vector& output_data, // Shape = [1, num_heads, sequence_length, sequence_length] int max_distance, int num_buckets, int num_heads, @@ -155,5 +155,264 @@ TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16_No_Bidirectional) { true); } +/***************Following scripts is used to generate test data, for your reference************* +import torch + +batch_size = 2 +num_heads = 2 +seq_len = 3 +head_size = 4 +D = 8 + +def dim_string_of(tensor): + return "{" + ", ".join([str(d) for d in tensor.shape]) + "}" + +def value_string_of(tensor): + arr = tensor.flatten().numpy() + lines = ["f, ".join([str(v) for v in arr[i : min(i+8, arr.size)]]) for i in range(0, arr.size, 8)] + return "{\n " + "f,\n ".join(lines) + "f}" + +def print_tensor(name, tensor): + print(f"const std::vector {name}_dim = {dim_string_of(tensor)};") + print(f"const std::vector {name} = {value_string_of(tensor)};") + +torch.manual_seed(0) +query_layer = torch.rand(batch_size, seq_len, num_heads * head_size) +query_bias = torch.rand(num_heads * head_size) +rel_pos 
= torch.rand(1, num_heads, seq_len, seq_len) +weight = torch.rand(head_size, D) +bias = torch.rand(D) +eco_a = torch.rand(1, num_heads, 1, 1) + +qw = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) +gate_u,gate_r = torch.sigmoid( + (torch.matmul(qw, weight) + bias).view(batch_size, num_heads, seq_len,2, D//2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) +gate_u_1 = gate_u * (gate_r * eco_a - 1.0) + 2.0 +output = gate_u_1 * rel_pos + +# output for test case +print(f"const int batch_size = {batch_size};") +print(f"const int num_heads = {num_heads};") +print(f"const int seq_len = {seq_len};") +print(f"const int head_size = {head_size};") +print(f"const int D = {D};") + +print_tensor("query_layer", query_layer) +print_tensor("query_bias", query_bias) +print_tensor("rel_pos", rel_pos) +print_tensor("weight", weight) +print_tensor("bias", bias) +print_tensor("eco_a", eco_a) +print_tensor("output", output) +****************/ + +// .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T") +// .Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T") +// .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T") +// .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T") +// .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T") +// .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T") +// .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T") +static void RunGatedRelativePositionBiasTest( + const std::vector& query_layer, + const std::vector& query_bias, + const std::vector& rel_pos, + const std::vector& weight, + const std::vector& bias, + const std::vector& eco_a, + const std::vector& output, + int batch_size, + int seq_len, + int num_heads, + int head_size, + int D, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 
530 : 0; + + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + if (enable_cuda) { + OpTester tester("GatedRelativePositionBias", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + + std::vector query_layer_dims = {batch_size, seq_len, num_heads * head_size}; + std::vector query_bias_dims = {num_heads * head_size}; + std::vector rel_pos_dims = {1, num_heads, seq_len, seq_len}; + std::vector weight_dims = {head_size, D}; + std::vector bias_dims = {D}; + std::vector eco_a_dims = {1, num_heads, 1, 1}; + std::vector output_dims = {batch_size, num_heads, seq_len, seq_len}; + + if (use_float16) { + tester.AddInput("query_layer", query_layer_dims, ToFloat16(query_layer)); + tester.AddInput("query_bias", query_bias_dims, ToFloat16(query_bias)); + tester.AddInput("rel_pos", rel_pos_dims, ToFloat16(rel_pos)); + tester.AddInput("weight", weight_dims, ToFloat16(weight)); + tester.AddInput("bias", bias_dims, ToFloat16(bias)); + tester.AddInput("eco_a", eco_a_dims, ToFloat16(eco_a)); + tester.AddOutput("output", output_dims, ToFloat16(output)); + } else { + tester.AddInput("query_layer", query_layer_dims, query_layer); + tester.AddInput("query_bias", query_bias_dims, query_bias); + tester.AddInput("rel_pos", rel_pos_dims, rel_pos); + tester.AddInput("weight", weight_dims, weight); + tester.AddInput("bias", bias_dims, bias); + tester.AddInput("eco_a", eco_a_dims, eco_a); + tester.AddOutput("output", output_dims, output); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) { + constexpr int batch_size = 1; + constexpr int num_heads = 2; + constexpr int seq_len = 3; + constexpr int head_size = 4; + constexpr int D = 8; + const std::vector query_layer_dim = {1, 3, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f}; + const std::vector rel_pos_dim = {1, 2, 3, 3}; + const std::vector rel_pos = { + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f, + 0.27825248f, 0.4819588f}; + const std::vector weight_dim = {4, 8}; + const std::vector weight = { + 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f, 0.593172f, 0.112347245f, + 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, 0.774486f, 0.43689132f, + 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, 0.69650495f, 0.9142747f, + 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f, 0.034022927f, 0.94424623f}; + const std::vector bias_dim = {8}; + const std::vector bias = { + 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f, 0.6922781f, 0.20384824f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.68329567f, 0.75285405f}; + const 
std::vector output_dim = {1, 2, 3, 3}; + const std::vector output = { + 0.29608122f, 0.45416728f, 0.25361493f, 0.053390637f, 0.3503264f, 1.5650483f, 1.2171557f, 1.2495192f, + 0.88587445f, 0.42708054f, 1.0246648f, 0.05810945f, 0.2430356f, 0.4244021f, 1.428723f, 1.3902748f, + 0.48772895f, 0.84479123f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, true); +} + +TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) { + constexpr int batch_size = 2; + constexpr int num_heads = 2; + constexpr int seq_len = 3; + constexpr int head_size = 4; + constexpr int D = 8; + const std::vector query_layer_dim = {2, 3, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f, + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f, + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.27825248f, 0.4819588f, 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f}; + const std::vector rel_pos_dim = {1, 2, 3, 3}; + const std::vector rel_pos = { + 0.593172f, 0.112347245f, 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, + 0.774486f, 0.43689132f, 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, + 0.69650495f, 0.9142747f}; + const std::vector weight_dim = {4, 8}; + const std::vector weight = { + 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f, 0.034022927f, 0.94424623f, + 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f, 0.6922781f, 0.20384824f, + 0.68329567f, 0.75285405f, 0.8579358f, 0.6869556f, 0.005132377f, 0.17565155f, 0.7496575f, 0.6046507f, + 0.10995799f, 0.21209025f, 0.97037464f, 0.83690894f, 0.28198743f, 0.3741576f, 0.023700953f, 0.49101293f}; + const std::vector bias_dim = {8}; + const std::vector bias = { + 0.123470545f, 0.11432165f, 0.4724502f, 0.5750725f, 0.29523486f, 0.7966888f, 0.19573045f, 0.95368505f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.84264994f, 0.07835853f}; + const std::vector output_dim = {2, 2, 3, 3}; + const std::vector output = { + 1.0928818f, 0.20699267f, 0.28273466f, 0.44534987f, 1.3380982f, 1.2917475f, 0.3755537f, 1.1995932f, + 1.4270226f, 0.47112367f, 0.5597638f, 0.6641071f, 0.87368786f, 1.0569134f, 0.12367705f, 0.34158573f, + 0.75108063f, 0.98591405f, 1.0929474f, 0.2070051f, 0.28275162f, 0.4451845f, 1.3376014f, 1.2912678f, + 0.37552574f, 1.1995038f, 1.4269164f, 0.47112313f, 0.5597632f, 0.6641063f, 0.87367094f, 1.056893f, + 0.12367466f, 0.34158388f, 0.7510766f, 0.98590875f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, false); +} + +TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) { + constexpr int batch_size = 2; + constexpr int num_heads = 2; + constexpr int seq_len = 5; + constexpr int head_size = 4; + constexpr int D = 4; + const 
std::vector query_layer_dim = {2, 5, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f, + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f, + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f, + 0.27825248f, 0.4819588f, 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f, + 0.593172f, 0.112347245f, 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, + 0.774486f, 0.43689132f, 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, + 0.69650495f, 0.9142747f, 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.034022927f, 0.94424623f, 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f}; + const std::vector rel_pos_dim = {1, 2, 5, 5}; + const std::vector rel_pos = { + 0.6922781f, 0.20384824f, 0.68329567f, 0.75285405f, 0.8579358f, 0.6869556f, 0.005132377f, 0.17565155f, + 0.7496575f, 0.6046507f, 0.10995799f, 0.21209025f, 0.97037464f, 0.83690894f, 0.28198743f, 0.3741576f, + 0.023700953f, 0.49101293f, 0.123470545f, 0.11432165f, 0.4724502f, 0.5750725f, 0.29523486f, 0.7966888f, + 0.19573045f, 0.95368505f, 0.84264994f, 0.07835853f, 0.37555784f, 0.5225613f, 0.57295054f, 0.61858714f, + 0.69621414f, 0.5299501f, 0.25603563f, 0.7365945f, 0.02037555f, 0.20364666f, 0.37483507f, 0.25644332f, + 0.32508332f, 0.09018916f, 0.39364243f, 0.6068782f, 0.17426711f, 0.47434032f, 0.8579254f, 0.44859987f, + 0.5138961f, 0.45686555f}; + const std::vector weight_dim = {4, 4}; + const std::vector weight = { + 0.6011907f, 0.81791973f, 0.9736231f, 0.81752795f, 0.97470677f, 0.46383917f, 0.050839245f, 0.2629614f, + 0.8404526f, 0.49675876f, 0.25147682f, 0.11684412f, 0.032073975f, 0.0779959f, 0.39858162f, 0.774203f}; + const std::vector bias_dim = {4}; + const std::vector bias = { + 0.77032053f, 0.017784059f, 0.811891f, 0.10874528f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.39429486f, 0.29726368f}; + const std::vector output_dim = {2, 2, 5, 5}; + const std::vector output = { + 0.9534052f, 0.28073975f, 0.9410346f, 1.0368304f, 1.181549f, 0.94923383f, 0.0070919087f, 0.24271497f, + 1.0358753f, 0.8355051f, 0.15224966f, 0.29366368f, 1.3435968f, 1.158798f, 0.3904445f, 0.5147038f, + 0.03260383f, 0.67545396f, 0.16985025f, 0.15726471f, 0.64280313f, 0.7824283f, 0.40168867f, 1.0839535f, + 0.26630563f, 1.2391479f, 1.0948771f, 0.101813294f, 0.48797214f, 0.6789776f, 0.7492329f, 0.8089107f, + 0.91042155f, 0.6930023f, 0.3348113f, 0.95611423f, 0.026447866f, 0.2643374f, 0.48654333f, 0.3328685f, + 0.4239932f, 0.117630124f, 0.5134121f, 0.7915271f, 0.22728965f, 0.61497897f, 1.1122944f, 0.5816067f, + 0.6662628f, 0.59232306f, 0.95294285f, 0.2806036f, 0.9405782f, 1.0363276f, 1.1809759f, 0.95289487f, + 0.007119261f, 0.24365108f, 1.0398705f, 0.83872753f, 0.15201466f, 0.29321042f, 1.3415229f, 1.1570094f, + 0.38984182f, 0.51978874f, 0.032925934f, 0.682127f, 0.17152825f, 0.15881838f, 0.6571103f, 0.79984313f, + 0.4106292f, 1.1080796f, 
0.2722329f, 1.2398669f, 1.0955123f, 0.101872355f, 0.4882552f, 0.6793715f, + 0.7427765f, 0.8019401f, 0.9025762f, 0.6870305f, 0.33192614f, 0.9568577f, 0.026468432f, 0.26454294f, + 0.48692167f, 0.33312735f, 0.4217717f, 0.117013805f, 0.5107221f, 0.78737986f, 0.22609876f, 0.6166911f, + 1.1153911f, 0.5832259f, 0.6681177f, 0.59397215f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, false); +} + } // namespace test } // namespace onnxruntime