diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 6eb315c59bc80..b118161b1c45a 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -15,6 +15,10 @@ set(contrib_ops_excluded_files
"bert/fast_gelu_impl.h"
"bert/fast_gelu.cc"
"bert/fast_gelu.h"
+ "bert/relative_attn_bias.cc"
+ "bert/relative_attn_bias.h"
+ "bert/relative_attn_bias_impl.cu"
+ "bert/relative_attn_bias_impl.h"
"bert/skip_layer_norm.cc"
"bert/skip_layer_norm.h"
"bert/skip_layer_norm_impl.cu"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 8cd6d4c9e26f1..f01a7ab14a61e 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -30,6 +30,7 @@ Do not modify directly.*
* com.microsoft.FusedConv
* com.microsoft.FusedGemm
* com.microsoft.FusedMatMul
+ * com.microsoft.GatedRelativePositionBias
* com.microsoft.GatherND
* com.microsoft.Gelu
* com.microsoft.GemmFastGelu
@@ -152,7 +153,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
-extra_add (optional) : T
+relative_position_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
@@ -1608,6 +1609,58 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.GatedRelativePositionBias**
+
+ query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
+ gate_u, gate_r = torch.sigmoid(
+ self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False)
+ ).chunk(2, dim=-1)
+ gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0
+ rel_pos_bias = gate_u_1 * rel_pos
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- num_heads : int (required)
+- Number of attention heads
+
+
+#### Inputs
+
+
+- query_layer : T
+- tensor with shape (batch_size, seq_len, num_heads x head_size)
+- query_bias : T
+- 1-d tensor with shape (num_heads x head_size)
+- rel_pos : T
+- tensor with shape (1, num_head, seq_len, seq_len)
+- weight : T
+- gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2
+- bias : T
+- bias for the gated_ur_linear, shape (D)
+- eco_a : T
+- tensor of shape (1, num_heads, 1, 1)
+
+
+#### Outputs
+
+
+- output : T
+- output tensor with shape (batch_size, num_heads, seq_len, seq_len)
+
+
+#### Type Constraints
+
+
+- T : tensor(float), tensor(float16)
+- Constrain input and output types to float tensors.
+
+
+
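Read end to end, the operator computes one scalar gate per (batch, head, query position) from the projected query and scales the shared relative-position bias with it. A minimal PyTorch sketch of the same computation, assuming `weight` and `bias` are the `gate_ur_linear` parameters applied as `q @ weight + bias` per the declared (head_size, D) layout (illustrative reference only, not the CUDA kernel):

```python
# Minimal PyTorch sketch of the computation described above (reference only, not
# the CUDA kernel). Assumes weight/bias are the gate_ur_linear parameters and are
# applied as q @ weight + bias, matching the declared (head_size, D) layout.
import torch

def gated_relative_position_bias(query_layer, query_bias, rel_pos, weight, bias, eco_a, num_heads):
    B, S, NH = query_layer.shape
    H = NH // num_heads
    D = weight.shape[1]
    # (B, S, N*H) -> (B, N, S, H)
    q = (query_layer + query_bias).reshape(B, S, num_heads, H).transpose(1, 2)
    qw = q @ weight + bias                                   # (B, N, S, D)
    # split D into two halves, sum each half, squash through a sigmoid
    gate_u, gate_r = torch.sigmoid(qw.view(B, num_heads, S, 2, D // 2).sum(-1)).chunk(2, dim=-1)
    gate = gate_u * (gate_r * eco_a - 1.0) + 2.0             # (B, N, S, 1)
    return gate * rel_pos                                    # (B, N, S, S)

B, S, N, H, D = 2, 3, 4, 8, 6
out = gated_relative_position_bias(
    torch.randn(B, S, N * H), torch.randn(N * H), torch.randn(1, N, S, S),
    torch.randn(H, D), torch.randn(D), torch.randn(1, N, 1, 1), N)
print(out.shape)  # torch.Size([2, 4, 3, 3])
```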
### **com.microsoft.GatherND**
Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather
@@ -2222,7 +2275,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (2 - 5)
+#### Inputs (2 - 6)
- query : T
@@ -2235,6 +2288,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
- key_padding_mask (optional) : M
- Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)
+- relative_position_bias (optional) : T
+- relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
#### Outputs
@@ -3221,7 +3276,7 @@ This version of the operator has been available since version 1 of the 'com.micr
left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past
and present state are optional. Present state could appear in output even when past state is not in input.
- Current version does not support past/present, extra_add and qkv_hidden_sizes.
+ Current version does not support past/present, relative_position_bias and qkv_hidden_sizes.
TODO: Support them if needed in the future.
#### Version
@@ -3286,7 +3341,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : Q
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
-extra_add (optional) : S
+relative_position_bias (optional) : S
additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
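For the `relative_position_bias` input added to MultiHeadAttention above, the `(1, num_heads, ...)` form is simply broadcast over the batch when it is added to the QxK' scores; a small NumPy illustration with arbitrary sizes:

```python
# Illustration of the broadcast implied by "(batch_size, ...) or (1, ...)": a bias
# with a leading 1 is reused for every batch entry when added to the QxK' scores.
import numpy as np

B, N, S, T = 2, 4, 3, 3
scores = np.random.rand(B, N, S, T).astype(np.float32)       # QxK' before softmax
shared_bias = np.random.rand(1, N, S, T).astype(np.float32)  # relative_position_bias

biased = scores + shared_bias                                 # broadcasts over batch
assert biased.shape == (B, N, S, T)
assert np.allclose(biased[1] - scores[1], shared_bias[0])
```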
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 286cad61d599f..29043259064a5 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -417,7 +417,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
+|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br> **T1** = tensor(int32)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
@@ -785,7 +785,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
+|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
@@ -803,6 +803,7 @@ Do not modify directly.*
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
@@ -810,11 +811,11 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br> **T2** = tensor(int8)<br> **T3** = tensor(float), tensor(float16)<br> **T4** = tensor(int32)|
-|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* extra_add:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
+|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* relative_position_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
|QOrderedLongformerAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* weight:**Q**<br> *in* scale_weight:**S**<br> *in* bias:**S**<br> *in* scale_bias:**S**<br> *in* scale_qkv_gemm:**S**<br> *in* mask:**F**<br> *in* global_weight:**Q**<br> *in* scale_global_weight:**S**<br> *in* global_bias:**S**<br> *in* scale_global_gemm:**S**<br> *in* global:**G**<br> *in* scale_output:**S**<br> *out* output:**Q**|1+|**F** = tensor(float16)<br> **G** = tensor(int32)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
@@ -1159,7 +1160,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
+|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 47db3fe558ce8..6aa0e726afe1b 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -198,7 +198,7 @@ Status Attention::Compute(OpKernelContext* context) const {
  const Tensor* mask_index = context->Input<Tensor>(3);
  const Tensor* past = context->Input<Tensor>(4);
-  const Tensor* extra_add_qk = context->Input<Tensor>(5);
+  const Tensor* relative_position_bias = context->Input<Tensor>(5);
const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_);
@@ -208,7 +208,7 @@ Status Attention::Compute(OpKernelContext* context) const {
bias->Shape(),
mask_index,
past,
- extra_add_qk,
+ relative_position_bias,
                                  &parameters));
const int batch_size = parameters.batch_size;
@@ -331,7 +331,7 @@ Status Attention::Compute(OpKernelContext* context) const {
return ApplyAttention(Q, K, V, mask_index, past, output,
batch_size, sequence_length,
parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
- extra_add_qk, context);
+ relative_position_bias, context);
}
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
index affe7cab1d858..e75f68ea53c7c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -12,7 +12,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
const TensorShape& bias_shape,
const Tensor*& mask_index,
const Tensor* past,
- const Tensor* extra_add_qk,
+ const Tensor* relative_position_bias,
void* parameters,
const Tensor* past_seq_len) const {
// Abbreviation and Meanings:
@@ -37,7 +37,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
// bias (Q/K/V) : (D + D + D_v)
// mask_index : see below
// past (K/V) : (2, B, N, P, H) or NULL
- // extra_add_qk : (B, N, S, T) or NULL
+ // relative_position_bias : (B, N, S, T) or NULL
// For mask_index, the following shapes are supported:
// NULL, (B, 1), (1, 1)
@@ -49,9 +49,9 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
// When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger
// than hidden dimension of Q, K and V.
- if (past != nullptr && extra_add_qk != nullptr) {
- // past is used on GPT-2 model with past state, we don't have a case for extra add qk yet
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and extra_add_qk");
+ if (past != nullptr && relative_position_bias != nullptr) {
+ // past is used on GPT-2 model with past state, we don't have a case for relative position bias yet
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and relative_position_bias");
}
const auto& dims = input_shape.GetDims();
@@ -191,34 +191,34 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
}
}
- if (extra_add_qk != nullptr) {
- const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims();
+ if (relative_position_bias != nullptr) {
+ const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
- if (extra_add_qk_dims.size() != 4) {
+ if (relative_position_bias_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'extra_add_qk' is expected to have 4 dimensions, got ",
- extra_add_qk_dims.size());
+ "Input 'relative_position_bias' is expected to have 4 dimensions, got ",
+ relative_position_bias_dims.size());
}
- if (extra_add_qk_dims[0] != batch_size) {
+ if (relative_position_bias_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ",
- extra_add_qk_dims[0]);
+ "Input 'relative_position_bias' dimension 0 should be same as batch_size, got ",
+ relative_position_bias_dims[0]);
}
- if (extra_add_qk_dims[1] != num_heads_) {
+ if (relative_position_bias_dims[1] != num_heads_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ",
- extra_add_qk_dims[1]);
+ "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ",
+ relative_position_bias_dims[1]);
}
- if (extra_add_qk_dims[2] != sequence_length) {
+ if (relative_position_bias_dims[2] != sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ",
- extra_add_qk_dims[2]);
+ "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ",
+ relative_position_bias_dims[2]);
}
- if (extra_add_qk_dims[3] != total_sequence_length) {
+ if (relative_position_bias_dims[3] != total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'extra_add_qk' dimension 3 should be same as total_sequence_length, got ",
- extra_add_qk_dims[3]);
+ "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ",
+ relative_position_bias_dims[3]);
}
}
@@ -320,7 +320,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
const TensorShape& bias_shape,
const Tensor*& mask_index,
const Tensor* past,
- const Tensor* extra_add_qk,
+ const Tensor* relative_position_bias,
void* parameters,
const int max_threads_per_block,
const Tensor* past_seq_len) const {
@@ -328,7 +328,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
- return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, extra_add_qk, parameters, past_seq_len);
+ return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, relative_position_bias, parameters, past_seq_len);
}
Tensor* AttentionBase::GetPresent(OpKernelContext* context,
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index 2c49f196d52d8..2e077da2853d0 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -18,7 +18,7 @@ class AttentionBase {
const TensorShape& bias_shape,
const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr.
const Tensor* past,
- const Tensor* extra_add_qk,
+ const Tensor* relative_position_bias,
void* parameters,
const int max_threads_per_block, // for CUDA
const Tensor* past_seq_len = nullptr) const;
@@ -61,7 +61,7 @@ class AttentionBase {
const TensorShape& bias_shape,
const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr.
const Tensor* past,
- const Tensor* extra_add_qk,
+ const Tensor* relative_position_bias,
void* parameters,
const Tensor* past_seq_len = nullptr) const;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index 0185fa9ea09a0..70d71ffb6ee40 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -19,18 +19,18 @@ class AttentionCPUBase : public AttentionBase {
: AttentionBase(info, require_same_hidden_size) {}
  template <typename T>
- Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
- const T* K, // K data with shape BxNxSxH
- const T* V, // V value with size BxNxSxH_v
- const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
- const Tensor* past, // past state
- Tensor* output, // output tensor
- int batch_size, // batch size (B)
- int sequence_length, // sequence length (S)
- int qk_head_size, // head size of Q or K (H)
- int v_head_size, // head size of V (H_v)
- int v_hidden_size, // hidden size of V (D_v)
- const Tensor* extra_add_qk, // extra add in QK. Its size is BxNxSxT
+ Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
+ const T* K, // K data with shape BxNxSxH
+ const T* V, // V value with size BxNxSxH_v
+ const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
+ const Tensor* past, // past state
+ Tensor* output, // output tensor
+ int batch_size, // batch size (B)
+ int sequence_length, // sequence length (S)
+ int qk_head_size, // head size of Q or K (H)
+ int v_head_size, // head size of V (H_v)
+ int v_hidden_size, // hidden size of V (D_v)
+ const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT
OpKernelContext* context) const {
const int kv_sequence_length = sequence_length;
@@ -67,16 +67,16 @@ class AttentionCPUBase : public AttentionBase {
const T* past_data = past != nullptr ? past->Data() : nullptr;
T* present_data = present != nullptr ? present->MutableData() : nullptr;
-    const T* extra_add_qk_data = nullptr;
-    if (extra_add_qk != nullptr) {
-      extra_add_qk_data = extra_add_qk->Data<T>();
+    const T* relative_position_bias_data = nullptr;
+    if (relative_position_bias != nullptr) {
+      relative_position_bias_data = relative_position_bias->Data<T>();
}
ComputeAttentionProbs(static_cast(attention_probs), Q, K,
mask_index_data, mask_index_dims, static_cast(mask_data), has_unidirectional,
batch_size, sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size,
- past_data, present_data, tp, extra_add_qk_data);
+ past_data, present_data, tp, relative_position_bias_data);
// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
@@ -112,7 +112,7 @@ class AttentionCPUBase : public AttentionBase {
const T* past, // past state
T* present, // present state
ThreadPool* tp, // thread pool
- const T* extra_add_qk_data // extra add matrix with shape BxNxSxT
+ const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT
) const {
const int total_sequence_length = past_sequence_length + sequence_length; // T = P + L
const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H
@@ -175,9 +175,9 @@ class AttentionCPUBase : public AttentionBase {
}
}
- if (extra_add_qk_data != nullptr) {
+ if (relative_position_bias_data != nullptr) {
for (int j = 0; j < sequence_length * total_sequence_length; j++) {
- output[j] += extra_add_qk_data[output_offset + j];
+ output[j] += relative_position_bias_data[output_offset + j];
}
}
}
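The loop at the end of `ComputeAttentionProbs` applies the bias block by block, where `output_offset` is the start of the current (batch, head) block in the flattened BxNxSxT buffer; roughly the following NumPy (illustrative, not the actual kernel):

```python
# Rough NumPy picture of the loop above (not the actual kernel): for each
# (batch, head) block, the S*T slice of relative_position_bias starting at
# output_offset is added to the same slice of the raw attention scores.
import numpy as np

B, N, S, T = 2, 2, 3, 4
scores = np.zeros((B, N, S, T), dtype=np.float32)    # Q*K' scores for this example
rel_bias = np.random.rand(B, N, S, T).astype(np.float32)

flat_scores, flat_bias = scores.reshape(-1), rel_bias.reshape(-1)
for b in range(B):
    for n in range(N):
        output_offset = (b * N + n) * S * T          # start of this (b, n) block
        flat_scores[output_offset:output_offset + S * T] += flat_bias[output_offset:output_offset + S * T]

assert np.allclose(scores, rel_bias)                 # equivalent to scores += rel_bias
```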
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index 8c3af05972c95..ee1720b9f43bb 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -17,6 +17,7 @@ Status CheckInputs(const T* query,
const T* value,
const T* bias,
const T* key_padding_mask,
+ const T* relative_position_bias,
void* parameters,
int num_heads,
float mask_filter_value,
@@ -26,6 +27,7 @@ Status CheckInputs(const T* query,
// value (V) : (B, L, D_v)
// bias (Q/K/V) : (D + D + D_v)
// key_padding_mask (K/V) : (B) or (B, L) or None
+  //     relative_position_bias : (B or 1, N, S, L)
// When packed kv is used:
// key (K) : (B, L, N, 2, H)
// value (V) : None
@@ -120,6 +122,36 @@ Status CheckInputs(const T* query,
v_hidden_size = static_cast(value_dims[2]);
}
+ if (relative_position_bias != nullptr) {
+ const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
+
+ if (relative_position_bias_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' is expected to have 4 dimensions, got ",
+ relative_position_bias_dims.size());
+ }
+ if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ",
+ relative_position_bias_dims[0]);
+ }
+ if (relative_position_bias_dims[1] != num_heads) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ",
+ relative_position_bias_dims[1]);
+ }
+ if (relative_position_bias_dims[2] != sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ",
+ relative_position_bias_dims[2]);
+ }
+ if (relative_position_bias_dims[3] != kv_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ",
+ relative_position_bias_dims[3]);
+ }
+ }
+
if (parameters != nullptr) {
AttentionParameters* output_parameters = reinterpret_cast(parameters);
output_parameters->batch_size = batch_size;
diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
index 64c17b7767e4f..e7df84c1b0066 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
@@ -160,7 +160,7 @@ Status QAttention::Compute(OpKernelContext* context) const {
bias->Shape(),
mask_index,
past_tensor,
- nullptr, // extra_add_qk
+ nullptr, // relative_position_bias
nullptr // parameters
));
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index 4a6d2dc137139..1ab89b525eae5 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -59,7 +59,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
const Tensor* bias = context->Input(2);
const Tensor* mask_index = context->Input(3);
const Tensor* past = context->Input(kPastInputIndex);
-  const Tensor* extra_add_qk = context->Input<Tensor>(5);
+  const Tensor* relative_position_bias = context->Input<Tensor>(5);
const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex);
auto& device_prop = GetDeviceProp();
@@ -69,7 +69,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
bias->Shape(),
mask_index,
past,
- extra_add_qk,
+ relative_position_bias,
                                  &parameters,
device_prop.maxThreadsPerBlock,
past_seq_len));
@@ -105,7 +105,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
bool use_causal_fused_runner = !disable_fused_runner_ &&
(nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
- nullptr == extra_add_qk &&
+ nullptr == relative_position_bias &&
parameters.past_sequence_length == 0 &&
parameters.hidden_size == parameters.v_hidden_size &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
@@ -125,7 +125,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
(nullptr == mask_index || is_mask_1d_seq_len) &&
nullptr == past &&
nullptr == present &&
- nullptr == extra_add_qk &&
+ nullptr == relative_position_bias &&
parameters.hidden_size == parameters.v_hidden_size &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
@@ -151,7 +151,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
nullptr == mask_index && // TODO: support 1D mask
nullptr == past &&
nullptr == present &&
- nullptr == extra_add_qk &&
+ nullptr == relative_position_bias &&
(sizeof(T) == 2 || // sequence length threshold is 0 in FP16
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
@@ -203,7 +203,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data();
data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims();
data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data());
-  data.extra_add_qk = (nullptr == extra_add_qk) ? nullptr : reinterpret_cast<const CudaT*>(extra_add_qk->Data<T>());
+  data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
data.workspace = reinterpret_cast(work_space.get());
data.output = reinterpret_cast(output->MutableData());
data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData());
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 8c7ef9f919519..fcf86637350b6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -665,7 +665,7 @@ Status QkvToContext(
T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax.
ORT_RETURN_IF_ERROR(
ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads,
- mask_index, nullptr, data.extra_add_qk, scratch1, scratch2,
+ mask_index, nullptr, data.relative_position_bias, scratch1, scratch2,
parameters.is_unidirectional, scale, mask_dimension,
parameters.max_sequence_length, use_persistent_softmax,
persistent_softmax_workspace, mask_filter_value));
@@ -675,10 +675,10 @@ Status QkvToContext(
const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr;
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
- mask_index, mask_start, data.extra_add_qk, scratch1, scratch2, parameters.is_unidirectional));
+ mask_index, mask_start, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional));
} else { // no mask
ORT_RETURN_IF_ERROR(
- ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.extra_add_qk,
+ ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias,
scratch1, scratch2, parameters.is_unidirectional));
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index d98a0380c479b..2ecda71479c52 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -41,7 +41,7 @@ struct AttentionData {
const int* mask_index;
gsl::span mask_index_dims;
const T* past;
- const T* extra_add_qk;
+ const T* relative_position_bias;
T* workspace;
T* output;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
index 16b3cf053b586..92851c446d48f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
@@ -377,11 +377,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
float thread_data = -CUDART_INF_F;
if (threadIdx.x < all_sequence_length) {
- if (add_before_softmax == nullptr) {
- thread_data = float(input[index]) * rsqrt_head_size;
- } else {
- thread_data = float(input[index] + add_before_softmax[index]) * rsqrt_head_size;
- }
+ thread_data = float(input[index]) * rsqrt_head_size;
const int sequence_index = blockIdx.x % sequence_length;
if (is_unidirectional) {
@@ -412,6 +408,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
thread_data = -CUDART_INF_F;
}
}
+
+ if (add_before_softmax != nullptr) {
+ thread_data += float(add_before_softmax[index]);
+ }
}
if (skip_softmax) {
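The rearrangement in `SoftmaxWithRawMaskSmall` above changes how `add_before_softmax` enters the score: it is no longer multiplied by `rsqrt_head_size`, and it is applied after the raw-mask handling rather than before it. A schematic comparison (names and the additive-mask form are illustrative):

```python
# Schematic comparison of the change (illustrative only): the bias used to be
# folded in before the 1/sqrt(head_size) scaling; now the score is scaled and
# masked first and the bias is added unscaled afterwards.
import math

def old_score(x, bias, head_size, mask_add=0.0):
    return (x + bias) / math.sqrt(head_size) + mask_add

def new_score(x, bias, head_size, mask_add=0.0):
    return x / math.sqrt(head_size) + mask_add + bias

x, bias, head_size = 1.0, 0.5, 64
print(old_score(x, bias, head_size))   # 0.1875 -> bias was scaled by rsqrt(head_size)
print(new_score(x, bias, head_size))   # 0.625  -> bias is applied unscaled
```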
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index 93e5e59ed00ae..57a3a310a0dd6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -62,6 +62,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
const Tensor* value = context->Input(2);
const Tensor* bias = context->Input(3);
const Tensor* key_padding_mask = context->Input(4);
+  const Tensor* relative_position_bias = context->Input<Tensor>(5);
auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
@@ -70,6 +71,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
value,
bias,
key_padding_mask,
+ relative_position_bias,
                                  &parameters,
num_heads_,
mask_filter_value_,
@@ -94,6 +96,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
bool use_fused_cross_attention = !disable_fused_cross_attention_ &&
nullptr == key_padding_mask &&
+ nullptr == relative_position_bias &&
(value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV
parameters.hidden_size == parameters.v_hidden_size &&
has_fused_cross_attention_kernel(sm, parameters.head_size,
@@ -112,6 +115,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
bool use_fused_runner = !disable_fused_runner_ &&
fused_cross_attention_kernel == nullptr &&
+ nullptr == relative_position_bias &&
value != nullptr && // fused runner requires packed qkv instead of packed kv
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
@@ -143,6 +147,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
!disable_memory_efficient_attention_ &&
is_long_sequence &&
nullptr == key_padding_mask && // TODO: support 1D mask
+ nullptr == relative_position_bias &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
#else
constexpr bool use_memory_efficient_attention = false;
@@ -171,7 +176,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data();
data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims();
data.past = nullptr;
- data.extra_add_qk = nullptr;
+  data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
data.workspace = reinterpret_cast(work_space.get());
data.output = reinterpret_cast(output->MutableData());
data.present = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
index af13efe0e2fbc..111fed04639e7 100644
--- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
@@ -3,7 +3,15 @@
#include "core/providers/cuda/cuda_common.h"
#include "relative_attn_bias.h"
+#include "core/common/safeint.h"
#include "relative_attn_bias_impl.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "contrib_ops/cuda/bert/add_bias_transpose.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+
namespace onnxruntime {
namespace contrib {
@@ -20,7 +28,16 @@ namespace cuda {
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.InputMemoryType(OrtMemTypeCPUInput, 2) \
.TypeConstraint("T", DataTypeImpl::GetTensorType()), \
- RelPosAttnBias);
+ RelPosAttnBias); \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ GatedRelativePositionBias, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()),                \
+ GatedRelativePositionBias);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
@@ -69,6 +86,108 @@ Status RelPosAttnBias::ComputeInternal(OpKernelContext* context) const {
device_prop.maxThreadsPerBlock);
}
+template <typename T>
+GatedRelativePositionBias<T>::GatedRelativePositionBias(const OpKernelInfo& info) : CudaKernel(info) {
+  int64_t num_heads = 0;
+  ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+  num_heads_ = SafeInt<int>(num_heads);
+}
+
+template <typename T>
+Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor& query_tensor = *context->Input<Tensor>(0);
+  const Tensor& query_bias_tensor = *context->Input<Tensor>(1);
+  const Tensor& rel_pos_tensor = *context->Input<Tensor>(2);
+  const Tensor& weight_tensor = *context->Input<Tensor>(3);
+  const Tensor& bias_tensor = *context->Input<Tensor>(4);
+  const Tensor& eco_a_tensor = *context->Input<Tensor>(5);
+
+ const auto& query_dims = query_tensor.Shape().GetDims();
+ ORT_ENFORCE(query_dims.size() == 3);
+ ORT_ENFORCE(query_dims[2] > 0);
+ ORT_ENFORCE(query_dims[2] % num_heads_ == 0);
+  const auto batch_size = SafeInt<int>(query_dims[0]);
+  const auto seq_len = SafeInt<int>(query_dims[1]);
+  const auto head_size = SafeInt<int>(query_dims[2] / num_heads_);
+
+ ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1);
+ ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]);
+
+ const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims();
+ ORT_ENFORCE(rel_pos_dims.size() == 4);
+ ORT_ENFORCE(rel_pos_dims[0] == 1);
+ ORT_ENFORCE(rel_pos_dims[1] == num_heads_);
+ ORT_ENFORCE(rel_pos_dims[2] == seq_len);
+ ORT_ENFORCE(rel_pos_dims[3] == seq_len);
+
+ const auto& weight_dims = weight_tensor.Shape().GetDims();
+ ORT_ENFORCE(weight_dims.size() == 2);
+ ORT_ENFORCE(weight_dims[0] == head_size);
+ ORT_ENFORCE((weight_dims[1] > 0) && (weight_dims[1] % 2 == 0));
+
+ ORT_ENFORCE(bias_tensor.Shape().NumDimensions() == 1);
+ ORT_ENFORCE(bias_tensor.Shape()[0] == weight_dims[1]);
+
+  const auto D = SafeInt<int>(weight_dims[1]);
+
+ const auto& eco_a_dims = eco_a_tensor.Shape().GetDims();
+ ORT_ENFORCE(eco_a_dims.size() == 4);
+ ORT_ENFORCE(eco_a_dims[0] == 1);
+ ORT_ENFORCE(eco_a_dims[1] == num_heads_);
+ ORT_ENFORCE(eco_a_dims[2] == 1);
+ ORT_ENFORCE(eco_a_dims[3] == 1);
+
+ Tensor* output = context->Output(0, {batch_size, num_heads_, seq_len, seq_len});
+
+ auto& device_prop = GetDeviceProp();
+ cublasHandle_t cublas = GetCublasHandle(context);
+
+  typedef typename ToCudaType<T>::MappedType CudaT;
+  const auto BNS = batch_size * num_heads_ * seq_len;
+  const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
+  const size_t elements_after_gemm = (size_t)BNS * (size_t)D;
+  size_t workspace_size = sizeof(T) * (elements_in_query + ((seq_len < D) ? elements_after_gemm : (size_t)0));
+  auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());
+
+ // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
+ constexpr int format = 1;
+  constexpr int total_matrix = 1;
+  constexpr int num_matrix_to_transpose = 1;
+  LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock,
+                         batch_size, seq_len, num_heads_, head_size,
+                         reinterpret_cast<const CudaT*>(query_tensor.template Data<T>()),
+                         reinterpret_cast<const CudaT*>(query_bias_tensor.template Data<T>()),
+                         reinterpret_cast<CudaT*>(workspace.get()),
+                         false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_matrix);
+
+ // reuse output if possible
+  CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query)
+                                     : reinterpret_cast<CudaT*>(output->template MutableData<T>());
+ int ld_gemm_output = max(seq_len, D);
+
+  const CudaT one = ToCudaType<T>::FromFloat(1.0f);
+  const CudaT zero = ToCudaType<T>::FromFloat(0.0f);
+
+ // ([b*n*s, h] * [h, D]), CUDA assumes col-major
+ CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
+ cublas, CUBLAS_OP_N, CUBLAS_OP_N,
+ D, BNS, head_size, &one,
+      reinterpret_cast<const CudaT*>(weight_tensor.template Data<T>()), (int)D,
+      reinterpret_cast<const CudaT*>(workspace.get()), (int)head_size,
+ &zero, gemm_output, ld_gemm_output, device_prop));
+
+ auto status = LaunchGatedRelativePositionBiasKernel(
+ device_prop, Stream(context),
+      reinterpret_cast<CudaT*>(output->template MutableData<T>()),
+      reinterpret_cast<const CudaT*>(rel_pos_tensor.template Data<T>()),
+      reinterpret_cast<const CudaT*>(gemm_output),
+      reinterpret_cast<const CudaT*>(bias_tensor.template Data<T>()),
+      reinterpret_cast<const CudaT*>(eco_a_tensor.template Data<T>()),
+ batch_size, num_heads_, seq_len, D, ld_gemm_output);
+
+ return status;
+}
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
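The `cublasGemmHelper` call above relies on the usual row-major/column-major duality noted in its comment; a NumPy illustration with tight leading dimensions (the real call pads output rows to `ld_gemm_output = max(seq_len, D)` so it can write straight into the output tensor when `seq_len >= D`):

```python
# NumPy illustration of the "CUDA assumes col-major" trick used by the GEMM above,
# with tight leading dimensions for simplicity.
import numpy as np

BNS, H, D = 5, 8, 6
q = np.random.rand(BNS, H).astype(np.float32)   # row-major [B*N*S, H] workspace
w = np.random.rand(H, D).astype(np.float32)     # row-major [H, D] weight

# The same raw buffers, as a column-major BLAS interprets them:
w_col = w.reshape(-1).reshape((D, H), order='F')     # A with lda = D, i.e. w.T
q_col = q.reshape(-1).reshape((H, BNS), order='F')   # B with ldb = H, i.e. q.T
c_col = w_col @ q_col                                # m=D, n=BNS, k=H as in the cublas call

# Reading the column-major result buffer back row-major yields q @ w.
c_row = c_col.reshape(-1, order='F').reshape(BNS, D)
assert np.allclose(c_row, q @ w)
```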
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h
index b9674f6f35091..3bf4e730e29f9 100644
--- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h
@@ -22,6 +22,18 @@ class RelPosAttnBias final : public CudaKernel {
bool is_bidirectional_;
};
+template <typename T>
+class GatedRelativePositionBias final : public CudaKernel {
+ public:
+ GatedRelativePositionBias(const OpKernelInfo& op_kernel_info);
+
+ Status ComputeInternal(OpKernelContext* ctx) const override;
+
+ private:
+ int num_heads_;
+};
+
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
index e333152cb5bcf..938496b058025 100644
--- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
@@ -36,7 +36,7 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias,
const bool is_bidirectional,
const int max_distance) {
const int head_id = blockIdx.x;
- for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) {
+ for (int seq_id = blockDim.x * blockIdx.y + threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) {
int row_id = seq_id / seq_len;
int col_id = seq_id % seq_len;
@@ -149,6 +149,122 @@ template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream,
const bool is_bidirectional,
const int max_threads_per_block);
+template <typename T>
+__global__ void GatedRelativePositionBiasKernelSmallD(
+ T* output, // (batch_size, num_heads, seq_len, seq_len)
+ const T* rel_pos, // (1, num_heads, seq_len, seq_len)
+ const T* qw, // (batch_size, num_heads, seq_len, D)
+ const T* bias, // (D)
+ const T* eco_a, // (1, num_heads, 1, 1)
+ const int D,
+ const int ldqw) {
+ __shared__ float gate[1];
+
+ const int seq_len = gridDim.x;
+ const int num_heads = gridDim.y;
+ const int s = blockIdx.x;
+ const int n = blockIdx.y;
+ const int b = blockIdx.z;
+
+ rel_pos += ((int64_t)n * seq_len + s) * seq_len;
+ output += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * seq_len;
+ qw += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * ldqw;
+
+ float val = 0.0f;
+ if (threadIdx.x < D) {
+ val = (float)qw[threadIdx.x] + (bias ? (float)bias[threadIdx.x] : 0.0f);
+ }
+
+ float u = (threadIdx.x < D / 2) ? val : 0.0f;
+#pragma unroll
+ for (int offset = 16; offset > 0; offset /= 2) {
+ u += __shfl_down_sync(0xffffffff, u, offset);
+ }
+
+ float r = (threadIdx.x >= D / 2) ? val : 0.0f;
+#pragma unroll
+ for (int offset = 16; offset > 0; offset /= 2) {
+ r += __shfl_down_sync(0xffffffff, r, offset);
+ }
+
+ if (threadIdx.x == 0) {
+ u = 1.0f / (1.0f + expf(-u));
+ r = 1.0f / (1.0f + expf(-r));
+ gate[0] = u * (r * (float)eco_a[n] - 1.0f) + 2.0f;
+ }
+ __syncthreads();
+
+ for (int idx = threadIdx.x; idx < seq_len; idx += blockDim.x) {
+ output[idx] = (T)(gate[0] * (float)rel_pos[idx]);
+ }
+}
+
+template <typename T>
+Status LaunchGatedRelativePositionBiasKernel(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ T* output,
+ const T* rel_pos,
+ const T* qw, // query * weight
+ const T* bias,
+ const T* eco_a,
+ const int batch_size,
+ const int num_heads,
+ const int seq_len,
+ const int D,
+ const int ldqw) {
+ ORT_ENFORCE(D <= 32 && D > 0 && (D % 2 == 0));
+ ORT_ENFORCE(ldqw == seq_len || ldqw == D);
+
+ int tpb = std::max(32, std::max(D, seq_len));
+ tpb = std::min(tpb, device_prop.maxThreadsPerBlock);
+
+ // round up tpb to power of 2
+ --tpb;
+ tpb |= (tpb >> 1);
+ tpb |= (tpb >> 2);
+ tpb |= (tpb >> 4);
+ tpb |= (tpb >> 8);
+ tpb |= (tpb >> 16);
+ tpb++;
+
+ dim3 block(tpb);
+ dim3 grid(seq_len, num_heads, batch_size);
+
+  GatedRelativePositionBiasKernelSmallD<<<grid, block, 0, stream>>>(
+ output, rel_pos, qw, bias, eco_a, D, ldqw);
+
+ return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchGatedRelativePositionBiasKernel(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ float* output,
+ const float* rel_pos,
+ const float* qw,
+ const float* bias,
+ const float* eco_a,
+ const int batch_size,
+ const int num_heads,
+ const int seq_len,
+ const int D,
+ const int ldqw);
+
+template Status LaunchGatedRelativePositionBiasKernel(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ half* output,
+ const half* rel_pos,
+ const half* qw,
+ const half* bias,
+ const half* eco_a,
+ const int batch_size,
+ const int num_heads,
+ const int seq_len,
+ const int D,
+ const int ldqw);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
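Per (batch, head, query position) block, the kernel reduces the two halves of `qw + bias` with warp shuffles, which is why the launcher enforces `D <= 32`; in scalar form the computation is roughly the following (illustrative Python, not the CUDA code):

```python
# Per-(batch, head, query position) view of GatedRelativePositionBiasKernelSmallD.
# The two warp-shuffle reductions in the kernel sum the first and second halves of
# qw + bias, which only works within a single warp, hence the D <= 32 requirement.
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gated_row(rel_pos_row, qw_row, bias, eco_a_head):
    D = len(qw_row)
    vals = [qw_row[i] + bias[i] for i in range(D)]
    u = sigmoid(sum(vals[:D // 2]))           # first half  -> gate_u
    r = sigmoid(sum(vals[D // 2:]))           # second half -> gate_r
    gate = u * (r * eco_a_head - 1.0) + 2.0
    return [gate * rp for rp in rel_pos_row]  # one row of the (B, N, S, S) output

print(gated_row([0.1, 0.2, 0.3], [0.5, -0.5, 1.0, 0.0], [0.0] * 4, 0.7))
```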
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
index 5a1a229ab6077..5c7c98f55f3f5 100644
--- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h
@@ -22,6 +22,21 @@ Status LaunchRelPosAttnBiasKernel(
const int max_threads_per_block
);
+template <typename T>
+Status LaunchGatedRelativePositionBiasKernel(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ T* output,
+ const T* rel_pos,
+ const T* qw, // from query * weight
+ const T* bias,
+ const T* eco_a,
+ const int batch_size,
+ const int num_heads,
+ const int seq_len,
+ const int D,
+ const int ldqw);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index a239e528af148..1254ccd7e1e17 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -32,6 +32,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding);
@@ -162,6 +164,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
index e5ea47a6a2a5b..7cd717efc9fba 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
@@ -52,7 +52,7 @@ Status QAttention::CheckInputs(const Tensor* input,
auto& device_prop = GetDeviceProp();
ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(),
mask_index, past_tensor,
- nullptr, // extra_add_qk
+ nullptr, // relative_position_bias
parameters,
device_prop.maxThreadsPerBlock));
@@ -198,7 +198,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const {
data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data();
data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims();
data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data());
- data.extra_add_qk = nullptr; // add_qk is not supported in quantized attention
+ data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention
data.workspace = reinterpret_cast(work_space.get());
data.output = reinterpret_cast(output->MutableData());
data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData());
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
index 204c786cc2c5d..8122b2de5916b 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
@@ -212,7 +212,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), merged_weights_shape, merged_bias_shape,
mask_index,
nullptr, // past
- nullptr, // extra_add_qk
+ nullptr, // relative_position_bias
nullptr, // parameters
device_prop.maxThreadsPerBlock));
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h
index 5fe62ef127800..5fb31be5fe86f 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h
@@ -17,4 +17,4 @@ DefineQOrderedAttentionInput(Scale_QK_Softmax, scale_QKT_softmax, 15),
DefineQOrderedAttentionInput(Scale_Values_Gemm, scale_values_gemm, 16),
DefineQOrderedAttentionInput(Mask_Index, mask_index, 17),
DefineQOrderedAttentionInput(Past, past, 18),
-DefineQOrderedAttentionInput(Extra_Add, extra_add, 19)
+DefineQOrderedAttentionInput(Relative_Position_Bias, relative_position_bias, 19)
diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc
index 756919834aef8..afc9fd9237ed7 100644
--- a/onnxruntime/contrib_ops/rocm/bert/attention.cc
+++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc
@@ -39,7 +39,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
const Tensor* bias = context->Input(2);
const Tensor* mask_index = context->Input(3);
const Tensor* past = context->Input(4);
- const Tensor* extra_add_qk = context->Input(5);
+ const Tensor* relative_position_bias = context->Input(5);
auto& device_prop = GetDeviceProp();
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
@@ -47,7 +47,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
bias->Shape(),
mask_index,
past,
- extra_add_qk,
+ relative_position_bias,
nullptr,
device_prop.maxThreadsPerBlock));
@@ -129,7 +129,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(),
mask_filter_value_,
nullptr == past ? nullptr : past->Data(),
-      nullptr == extra_add_qk ? nullptr : extra_add_qk->Data<T>(),
+      nullptr == relative_position_bias ? nullptr : relative_position_bias->Data<T>(),
work_space.get(),
output->MutableData(),
nullptr == present ? nullptr : present->MutableData());
diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
index 954a129be1c65..fa6cce6a64132 100644
--- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
@@ -89,7 +89,7 @@ Status QkvToContext(
bool is_unidirectional,
int past_sequence_length,
const T* past,
- const T* extra_add_qk,
+ const T* relative_position_bias,
T* present,
bool use_persistent_softmax) {
const int all_sequence_length = past_sequence_length + sequence_length;
@@ -158,7 +158,7 @@ Status QkvToContext(
T* persistent_softmax_workspace = scratch1; // replace Q*K' in place if persistent softmax is selected.
ORT_RETURN_IF_ERROR(
ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads,
- mask_index, nullptr, extra_add_qk, scratch1, scratch2,
+ mask_index, nullptr, relative_position_bias, scratch1, scratch2,
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
use_persistent_softmax, persistent_softmax_workspace, mask_filter_value));
} else if (nullptr != mask_index) { // 1d mask index
@@ -166,10 +166,10 @@ Status QkvToContext(
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions.
const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr;
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads,
- mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional));
+ mask_index, mask_start, relative_position_bias, scratch1, scratch2, is_unidirectional));
} else { // no mask
ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads,
- extra_add_qk, scratch1, scratch2, is_unidirectional));
+ relative_position_bias, scratch1, scratch2, is_unidirectional));
}
// compute P*V (as V*P), and store in scratch3: BxNxSxH
@@ -206,7 +206,7 @@ Status LaunchAttentionKernel(
gsl::span mask_index_dims,
const float mask_filter_value,
const void* past,
- const void* extra_add_qk,
+ const void* relative_position_bias,
void* workspace,
void* output,
void* present) {
@@ -225,7 +225,7 @@ Status LaunchAttentionKernel(
is_unidirectional,
past_sequence_length,
reinterpret_cast(past),
-        reinterpret_cast<const __half*>(extra_add_qk),
+        reinterpret_cast<const __half*>(relative_position_bias),
reinterpret_cast<__half*>(present),
use_persistent_softmax);
} else {
@@ -240,7 +240,7 @@ Status LaunchAttentionKernel(
is_unidirectional,
past_sequence_length,
reinterpret_cast(past),
-        reinterpret_cast<const float*>(extra_add_qk),
+        reinterpret_cast<const float*>(relative_position_bias),
reinterpret_cast(present),
use_persistent_softmax);
}
diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h
index 7db692083f5e5..fdc46ce2e7729 100644
--- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h
@@ -42,7 +42,7 @@ Status LaunchAttentionKernel(
gsl::span mask_index_dims, // Mask index shape
const float mask_filter_value, // Mask value for filtered out positions
const void* past, // Past state input
- const void* extra_add_qk, // Additional Add
+ const void* relative_position_bias, // Additional Add
void* workspace, // Temporary buffer
void* output, // Output tensor
void* present // Present state output
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 68e3985651123..6b00ac94bc10f 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -243,7 +243,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"T",
OpSchema::Optional)
.Input(5,
- "extra_add",
+ "relative_position_bias",
"additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)",
"T",
OpSchema::Optional)
@@ -313,6 +313,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)",
"M",
OpSchema::Optional)
+ .Input(5,
+ "relative_position_bias",
+ "relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)"
+ " or (1, num_heads, sequence_length, total_sequence_length)",
+ "T",
+ OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",
@@ -668,5 +674,41 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
RestorePaddingTypeAndShapeInference(ctx);
}));
+constexpr const char* GatedRelativePositionBias_ver1_doc = R"DOC(
+ query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
+ gate_u, gate_r = torch.sigmoid(
+ self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False)
+ ).chunk(2, dim=-1)
+ gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0
+ rel_pos_bias = gate_u_1 * rel_pos
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ GatedRelativePositionBias, 1,
+ OpSchema()
+ .SetDoc(GatedRelativePositionBias_ver1_doc)
+ .Attr("num_heads", "Number of attention heads", AttributeProto::INT)
+ .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T")
+ .Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T")
+ .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T")
+ .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T")
+ .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T")
+ .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T")
+ .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T")
+ .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ int64_t num_heads = getAttribute(ctx, "num_heads", -1L);
+ if (hasInputShape(ctx, 0)) {
+ auto& query_layer_shape = getInputShape(ctx, 0);
+ TensorShapeProto output_shape;
+ *output_shape.add_dim() = query_layer_shape.dim(0);
+ output_shape.add_dim()->set_dim_value(num_heads);
+ *output_shape.add_dim() = query_layer_shape.dim(1);
+ *output_shape.add_dim() = query_layer_shape.dim(1);
+ updateOutputShape(ctx, 0, output_shape);
+ }
+ }));
+
} // namespace contrib
} // namespace onnxruntime
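
The TypeAndShapeInferenceFunction above uses only the first two dimensions of query_layer plus the num_heads attribute; the head_size factor in the last input dimension never reaches the output shape. A small Python sketch of the equivalent logic, with an illustrative function name:

    def infer_gated_rel_pos_bias_output_shape(query_layer_shape, num_heads):
        # query_layer_shape = (batch_size, seq_len, num_heads * head_size)
        batch_size, seq_len = query_layer_shape[0], query_layer_shape[1]
        # matches the schema output: (batch_size, num_heads, seq_len, seq_len)
        return (batch_size, num_heads, seq_len, seq_len)

For example, a query_layer of shape (2, 128, 1024) with num_heads = 16 infers an output shape of (2, 16, 128, 128).
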
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index a511d01fe1624..bd8469909fe7f 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -81,6 +81,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBias);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
@@ -171,6 +172,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias)>());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBias)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 6111afbd5d817..91e4f5d8ff81a 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -1140,7 +1140,7 @@ where value of each element is the end position, or valid length of actual seque
left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
the inclusive start positions. When unidirectional is 1, each token only attends to previous tokens. For GPT-2, both past
and present state are optional. Present state could appear in output even when past state is not in input.
-Current version does not support past/present, extra_add and qkv_hidden_sizes.
+Current version does not support past/present, relative_position_bias and qkv_hidden_sizes.
TODO: Support them if needed in the future.
)DOC";
@@ -1202,7 +1202,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(18, "past",
"past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).",
"Q", OpSchema::Optional)
- .Input(19, "extra_add",
+ .Input(19, "relative_position_bias",
"additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "S",
OpSchema::Optional)
.Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "Q")
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
index b4a92019992b5..c0a75fc50b07e 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -198,12 +198,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
const TensorShape& bias_shape,
const Tensor*& mask_index,
const Tensor* past,
- const Tensor* extra_add_qk,
+ const Tensor* relative_position_bias,
void* parameters,
const int max_threads_per_block,
const Tensor* past_seq_len) override {
return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past,
- extra_add_qk,
+ relative_position_bias,
parameters,
max_threads_per_block,
past_seq_len);
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
index 2490789dd31a2..f12e080adf30a 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -145,7 +145,7 @@ struct ProviderHostCPU {
const TensorShape& bias_shape,
const Tensor*& mask_index,
const Tensor* past,
- const Tensor* extra_add_qk,
+ const Tensor* relative_position_bias,
void* parameters,
const int max_threads_per_block,
const Tensor* past_seq_len) = 0;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp
index 63bae80c51a67..af93808248032 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp
@@ -401,15 +401,15 @@ class DmlOperatorAttention : public DmlOperator
void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
{
*isSupported = false;
- // Fall back to CPU if input 'past' and 'extra_add' is present because there is no current use case for this.
+ // Fall back to CPU if input 'past' or 'relative_position_bias' is present because there is no current use case for this.
// and it will make the implementation more complex.
// Also fall back to CPU if output 'present' is present for same reason as above.
if (context->GetInputCount() > 4 || context->GetOutputCount() > 1)
{
return;
}
- // Checking input count alone is not sufficient to fallback to CPU if input 'past' and 'extra_add' is present
- // because input 'mask_index', 'past', and 'extra_add' all are optional.
+ // Checking the input count alone is not sufficient to fall back to CPU when input 'past' or 'relative_position_bias' is present,
+ // because the inputs 'mask_index', 'past', and 'relative_position_bias' are all optional.
if (context->IsInputValid(4) || context->IsInputValid(5))
{
return;
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index b772fe95d6ecc..30be10ea7e15f 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -563,12 +563,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
const TensorShape& bias_shape,
const Tensor*& mask_index,
const Tensor* past,
- const Tensor* extra_add_qk,
+ const Tensor* relative_position_bias,
void* parameters,
const int max_threads_per_block,
const Tensor* past_seq_len) const {
return g_host_cpu.AttentionBase__CheckInputs(this, input_shape, weights_shape, bias_shape,
- mask_index, past, extra_add_qk, parameters,
+ mask_index, past, relative_position_bias, parameters,
max_threads_per_block, past_seq_len);
}
Tensor* AttentionBase::GetPresent(OpKernelContext* context, const Tensor* past, int batch_size, int head_size,
diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py
index 245ea9322ad61..342d43306e699 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention.py
@@ -337,7 +337,7 @@ def create_attention_node(
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
if self.use_multi_head_attention:
if add_qk_str is not None:
- logger.debug("MultiHeadAttention does not support extra_add_qk: cannot fuse the attention.")
+ logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
return None
attention_inputs = [
diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py
index 96c22b5894c60..d81a062f1060e 100644
--- a/onnxruntime/python/tools/transformers/onnx_model.py
+++ b/onnxruntime/python/tools/transformers/onnx_model.py
@@ -319,7 +319,7 @@ def match_parent_path(
self,
node,
parent_op_types,
- parent_input_index,
+ parent_input_index=None,
output_name_to_node=None,
return_indice=None,
):
@@ -339,7 +339,8 @@ def match_parent_path(
Returns:
parents: a list of matched parent nodes.
"""
- assert len(parent_input_index) == len(parent_op_types)
+ if parent_input_index is not None:
+ assert len(parent_input_index) == len(parent_op_types)
if output_name_to_node is None:
output_name_to_node = self.output_name_to_node()
@@ -350,16 +351,19 @@ def match_parent_path(
matched_parent = self.match_parent(
current_node,
op_type,
- parent_input_index[i],
+ parent_input_index[i] if parent_input_index is not None else None,
output_name_to_node,
exclude=[],
return_indice=return_indice,
)
if matched_parent is None:
- logger.debug(
- f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}",
- stack_info=True,
- )
+ if parent_input_index is not None:
+ logger.debug(
+ f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}",
+ stack_info=True,
+ )
+ else:
+ logger.debug(f"Failed to match index={i} op_type={op_type}", stack_info=True)
return None
matched_parents.append(matched_parent)
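
With parent_input_index now defaulting to None, match_parent_path can walk a chain of op types without pinning which parent input each hop goes through. A hedged usage sketch (model and node stand for an OnnxModel instance and a start node; both names are placeholders):

    # Pinned form: parent 1 of `node` must be an Add, and parent 0 of that Add a MatMul.
    path = model.match_parent_path(node, ["Add", "MatMul"], [1, 0])

    # Relaxed form: accept whichever parent input matches the expected op type at each hop.
    path = model.match_parent_path(node, ["Add", "MatMul"])
    if path is not None:
        add_node, matmul_node = path
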
diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py
index dc8f6810914a7..85e510a828990 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py
@@ -172,8 +172,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
add = k_nodes[-2]
matmul = k_nodes[-1]
- extra_add_qk_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
- if extra_add_qk_nodes is None:
+ relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0])
+ if relative_position_bias_nodes is None:
return
if matmul.input[0] == root_input:
@@ -189,7 +189,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
self.hidden_size,
root_input,
attention_last_node.output[0],
- extra_add_qk_nodes[0].input[0],
+ relative_position_bias_nodes[0].input[0],
)
if new_node is None:
return
diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc
index fb1d8fcfe451a..0ea85dfdaba4f 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -59,7 +59,7 @@ static void RunAttentionTest(
const bool disable_cuda = false,
const bool disable_rocm = false,
std::vector qkv_sizes = {},
- const std::vector<float>& extra_add_data = {},
+ const std::vector<float>& relative_position_bias_data = {},
int kv_sequence_length = 0,
bool past_present_share_buffer = false,
bool use_scale = false) {
@@ -199,12 +199,12 @@ static void RunAttentionTest(
}
}
- std::vector<int64_t> extra_add_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length};
- if (extra_add_data.size() > 0) {
+ std::vector<int64_t> relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length};
+ if (relative_position_bias_data.size() > 0) {
if (use_float16) {
- tester.AddInput("extra_add_qk", extra_add_data_dims, ToFloat16(extra_add_data));
+ tester.AddInput("relative_position_bias", relative_position_bias_data_dims, ToFloat16(relative_position_bias_data));
} else {
- tester.AddInput("extra_add_qk", extra_add_data_dims, extra_add_data);
+ tester.AddInput("relative_position_bias", relative_position_bias_data_dims, relative_position_bias_data);
}
} else {
if (use_float16) {
@@ -264,7 +264,7 @@ static void RunAttentionTest(
const bool disable_cuda = false,
const bool disable_rocm = false,
const std::vector qkv_sizes = {},
- const std::vector<float>& extra_add_data = {},
+ const std::vector<float>& relative_position_bias_data = {},
int kv_sequence_length = 0,
bool past_present_share_buffer = false,
bool use_scale = false) {
@@ -272,13 +272,13 @@ static void RunAttentionTest(
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
past_data, present_data, mask_type, input_hidden_size, max_sequence_length,
- disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data,
+ disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias_data,
kv_sequence_length, past_present_share_buffer, use_scale);
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
past_data, present_data, mask_type, input_hidden_size, max_sequence_length,
- disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data,
+ disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias_data,
kv_sequence_length, past_present_share_buffer, use_scale);
}
@@ -390,7 +390,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) {
0, false, false, disable_rocm, qkv_sizes);
}
-TEST(AttentionTest, AttentionBatch1ExtraAdd) {
+TEST(AttentionTest, AttentionBatch1RelativePositionBias) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;
@@ -414,7 +414,7 @@ TEST(AttentionTest, AttentionBatch1ExtraAdd) {
std::vector<int64_t> mask_index_data = {2L};
- std::vector<float> extra_add_qk = {
+ std::vector<float> relative_position_bias = {
0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f};
std::vector<float> output_data = {
@@ -427,10 +427,10 @@ TEST(AttentionTest, AttentionBatch1ExtraAdd) {
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0,
- 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_qk);
+ 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias);
}
-TEST(AttentionTest, AttentionBatch2ExtraAdd) {
+TEST(AttentionTest, AttentionBatch2RelativePositionBias) {
int batch_size = 2;
int sequence_length = 2;
int hidden_size = 4;
@@ -456,7 +456,7 @@ TEST(AttentionTest, AttentionBatch2ExtraAdd) {
std::vector<int64_t> mask_index_data = {2L, 2L};
- std::vector<float> extra_add_qk = {
+ std::vector<float> relative_position_bias = {
0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f,
0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f};
@@ -472,7 +472,7 @@ TEST(AttentionTest, AttentionBatch2ExtraAdd) {
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0,
- 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_qk);
+ 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias);
}
TEST(AttentionTest, AttentionBatch1_Float16) {
@@ -1709,7 +1709,7 @@ TEST(AttentionTest, AttentionWithNormFactor) {
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data,
AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/,
false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, {} /*qkv_sizes*/,
- {} /*extra_add_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/,
+ {} /*relative_position_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/,
true /*use_scale*/);
}
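
At the graph level, the renamed tensor is still the optional sixth input of the com.microsoft Attention operator, which the OpTester-based tests above populate in input order. A hedged sketch of wiring it up with onnx.helper (tensor names and num_heads are illustrative):

    from onnx import helper

    # Empty strings skip the optional mask_index and past inputs; the sixth slot carries
    # the (batch_size, num_heads, sequence_length, total_sequence_length) bias tensor.
    attention = helper.make_node(
        "Attention",
        inputs=["input", "weight", "bias", "", "", "relative_position_bias"],
        outputs=["output"],
        domain="com.microsoft",
        num_heads=2,
    )
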
diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc
index cf14a7f9918c6..5257fbdb08809 100644
--- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc
+++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc
@@ -278,7 +278,7 @@ TEST(QOrderedTest, Attention_WithData_ROW_ORDER) {
test_qorder.AddInput("scale_values_gemm", {}, {attn_out_scale}, true);
test_qorder.AddInput("mask_index", {batch_size, sequence_len}, input_mask.data(), input_mask.size());
test_qorder.AddOptionalInputEdge(); // past
- test_qorder.AddOptionalInputEdge(); // extra_add
+ test_qorder.AddOptionalInputEdge(); // relative_position_bias
test_qorder.AddOutput("output", {batch_size, sequence_len, hidden_size}, attn_out_q8.data(), attn_out_q8.size());
diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
index 7722291bee653..ba0299e4f3808 100644
--- a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
+++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc
@@ -10,9 +10,9 @@ namespace onnxruntime {
namespace test {
static void RunRelativePositionBiasTest(
- const std::vector<float>& bias_table, // Shape = [num_buckets, num_heads]
- const std::vector<int64_t>& sequence_length, // Shape = [1]
- const std::vector<float>& output_data, // Shape = [1, num_heads, sequence_length, sequence_length]
+ const std::vector<float>& bias_table, // Shape = [num_buckets, num_heads]
+ const std::vector<int64_t>& sequence_length, // Shape = [1]
+ const std::vector<float>& output_data, // Shape = [1, num_heads, sequence_length, sequence_length]
int max_distance,
int num_buckets,
int num_heads,
@@ -155,5 +155,264 @@ TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16_No_Bidirectional) {
true);
}
+/*************** The following script is used to generate test data, for your reference *************
+import torch
+
+batch_size = 2
+num_heads = 2
+seq_len = 3
+head_size = 4
+D = 8
+
+def dim_string_of(tensor):
+ return "{" + ", ".join([str(d) for d in tensor.shape]) + "}"
+
+def value_string_of(tensor):
+ arr = tensor.flatten().numpy()
+ lines = ["f, ".join([str(v) for v in arr[i : min(i+8, arr.size)]]) for i in range(0, arr.size, 8)]
+ return "{\n " + "f,\n ".join(lines) + "f}"
+
+def print_tensor(name, tensor):
+ print(f"const std::vector {name}_dim = {dim_string_of(tensor)};")
+ print(f"const std::vector {name} = {value_string_of(tensor)};")
+
+torch.manual_seed(0)
+query_layer = torch.rand(batch_size, seq_len, num_heads * head_size)
+query_bias = torch.rand(num_heads * head_size)
+rel_pos = torch.rand(1, num_heads, seq_len, seq_len)
+weight = torch.rand(head_size, D)
+bias = torch.rand(D)
+eco_a = torch.rand(1, num_heads, 1, 1)
+
+qw = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
+gate_u,gate_r = torch.sigmoid(
+ (torch.matmul(qw, weight) + bias).view(batch_size, num_heads, seq_len,2, D//2).sum(-1, keepdim=False)
+ ).chunk(2, dim=-1)
+gate_u_1 = gate_u * (gate_r * eco_a - 1.0) + 2.0
+output = gate_u_1 * rel_pos
+
+# output for test case
+print(f"const int batch_size = {batch_size};")
+print(f"const int num_heads = {num_heads};")
+print(f"const int seq_len = {seq_len};")
+print(f"const int head_size = {head_size};")
+print(f"const int D = {D};")
+
+print_tensor("query_layer", query_layer)
+print_tensor("query_bias", query_bias)
+print_tensor("rel_pos", rel_pos)
+print_tensor("weight", weight)
+print_tensor("bias", bias)
+print_tensor("eco_a", eco_a)
+print_tensor("output", output)
+****************/
+
+// .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T")
+// .Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T")
+// .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T")
+// .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T")
+// .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T")
+// .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T")
+// .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T")
+static void RunGatedRelativePositionBiasTest(
+ const std::vector<float>& query_layer,
+ const std::vector<float>& query_bias,
+ const std::vector<float>& rel_pos,
+ const std::vector<float>& weight,
+ const std::vector<float>& bias,
+ const std::vector<float>& eco_a,
+ const std::vector<float>& output,
+ int batch_size,
+ int seq_len,
+ int num_heads,
+ int head_size,
+ int D,
+ bool use_float16 = false) {
+ int min_cuda_architecture = use_float16 ? 530 : 0;
+
+ bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+ if (enable_cuda) {
+ OpTester tester("GatedRelativePositionBias", 1, onnxruntime::kMSDomain);
+ tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
+
+ std::vector<int64_t> query_layer_dims = {batch_size, seq_len, num_heads * head_size};
+ std::vector<int64_t> query_bias_dims = {num_heads * head_size};
+ std::vector<int64_t> rel_pos_dims = {1, num_heads, seq_len, seq_len};
+ std::vector<int64_t> weight_dims = {head_size, D};
+ std::vector<int64_t> bias_dims = {D};
+ std::vector<int64_t> eco_a_dims = {1, num_heads, 1, 1};
+ std::vector<int64_t> output_dims = {batch_size, num_heads, seq_len, seq_len};
+
+ if (use_float16) {
+ tester.AddInput("query_layer", query_layer_dims, ToFloat16(query_layer));
+ tester.AddInput("query_bias", query_bias_dims, ToFloat16(query_bias));
+ tester.AddInput("rel_pos", rel_pos_dims, ToFloat16(rel_pos));
+ tester.AddInput("weight", weight_dims, ToFloat16(weight));
+ tester.AddInput("bias", bias_dims, ToFloat16(bias));
+ tester.AddInput("eco_a", eco_a_dims, ToFloat16(eco_a));
+ tester.AddOutput("output", output_dims, ToFloat16(output));
+ } else {
+ tester.AddInput("query_layer", query_layer_dims, query_layer);
+ tester.AddInput("query_bias", query_bias_dims, query_bias);
+ tester.AddInput("rel_pos", rel_pos_dims, rel_pos);
+ tester.AddInput("weight", weight_dims, weight);
+ tester.AddInput("bias", bias_dims, bias);
+ tester.AddInput("eco_a", eco_a_dims, eco_a);
+ tester.AddOutput("output", output_dims, output);
+ }
+
+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+ execution_providers.push_back(DefaultCudaExecutionProvider());
+ tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ }
+}
+
+TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) {
+ constexpr int batch_size = 1;
+ constexpr int num_heads = 2;
+ constexpr int seq_len = 3;
+ constexpr int head_size = 4;
+ constexpr int D = 8;
+ const std::vector<int64_t> query_layer_dim = {1, 3, 8};
+ const std::vector<float> query_layer = {
+ 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f,
+ 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f,
+ 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f};
+ const std::vector<int64_t> query_bias_dim = {8};
+ const std::vector<float> query_bias = {
+ 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f};
+ const std::vector<int64_t> rel_pos_dim = {1, 2, 3, 3};
+ const std::vector<float> rel_pos = {
+ 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f,
+ 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f,
+ 0.27825248f, 0.4819588f};
+ const std::vector<int64_t> weight_dim = {4, 8};
+ const std::vector<float> weight = {
+ 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f, 0.593172f, 0.112347245f,
+ 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, 0.774486f, 0.43689132f,
+ 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, 0.69650495f, 0.9142747f,
+ 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f, 0.034022927f, 0.94424623f};
+ const std::vector<int64_t> bias_dim = {8};
+ const std::vector<float> bias = {
+ 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f, 0.6922781f, 0.20384824f};
+ const std::vector<int64_t> eco_a_dim = {1, 2, 1, 1};
+ const std::vector<float> eco_a = {
+ 0.68329567f, 0.75285405f};
+ const std::vector<int64_t> output_dim = {1, 2, 3, 3};
+ const std::vector<float> output = {
+ 0.29608122f, 0.45416728f, 0.25361493f, 0.053390637f, 0.3503264f, 1.5650483f, 1.2171557f, 1.2495192f,
+ 0.88587445f, 0.42708054f, 1.0246648f, 0.05810945f, 0.2430356f, 0.4244021f, 1.428723f, 1.3902748f,
+ 0.48772895f, 0.84479123f};
+
+ RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output,
+ batch_size, seq_len, num_heads, head_size, D, true);
+}
+
+TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) {
+ constexpr int batch_size = 2;
+ constexpr int num_heads = 2;
+ constexpr int seq_len = 3;
+ constexpr int head_size = 4;
+ constexpr int D = 8;
+ const std::vector<int64_t> query_layer_dim = {2, 3, 8};
+ const std::vector<float> query_layer = {
+ 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f,
+ 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f,
+ 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f,
+ 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f,
+ 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f,
+ 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f};
+ const std::vector<int64_t> query_bias_dim = {8};
+ const std::vector