From 8b2e392d43e561b8992b99e0190b614c532bad43 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 29 Jan 2023 02:59:48 +0000 Subject: [PATCH 01/16] init --- .../python/tools/transformers/onnx_model.py | 18 ++- .../tools/transformers/onnx_model_tulr.py | 125 ++++++++++++++++++ .../python/tools/transformers/optimizer.py | 2 + 3 files changed, 138 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/onnx_model_tulr.py diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 4827facd78100..2759b72c159a0 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -319,7 +319,7 @@ def match_parent_path( self, node, parent_op_types, - parent_input_index, + parent_input_index=None, output_name_to_node=None, return_indice=None, ): @@ -339,7 +339,8 @@ def match_parent_path( Returns: parents: a list of matched parent node. """ - assert len(parent_input_index) == len(parent_op_types) + if parent_input_index is not None: + assert len(parent_input_index) == len(parent_op_types) if output_name_to_node is None: output_name_to_node = self.output_name_to_node() @@ -350,16 +351,19 @@ def match_parent_path( matched_parent = self.match_parent( current_node, op_type, - parent_input_index[i], + parent_input_index[i] if parent_input_index is not None else None, output_name_to_node, exclude=[], return_indice=return_indice, ) if matched_parent is None: - logger.debug( - f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", - stack_info=True, - ) + if parent_input_index is not None: + logger.debug( + f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", + stack_info=True, + ) + else: + logger.debug(f"Failed to match index={i} op_type={op_type}", stack_info=True) return None matched_parents.append(matched_parent) diff --git a/onnxruntime/python/tools/transformers/onnx_model_tulr.py b/onnxruntime/python/tools/transformers/onnx_model_tulr.py new file mode 100644 index 0000000000000..93b6c9cd56976 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_tulr.py @@ -0,0 +1,125 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import logging +from typing import Union + +from fusion_attention import AttentionMask, FusionAttention +from fusion_utils import NumpyHelper +from onnx import NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel +from onnx_model_bert import BertOnnxModel +from fusion_base import Fusion +import numpy as np + +logger = logging.getLogger(__name__) + +#python optimizer.py --input /home/wy/Turing/tulr/model.onnx --output /home/wy/Turing/tulr/opt/model.onnx --model_type tulr --num_heads 16 --hidden_size 1024 --use_external_data_format + +class FusionTulrAttention(FusionAttention): + """ + Turing + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def create_attention_node( + self, + mask_index: str, + matmul: NodeProto, + add: NodeProto, + num_heads: int, + hidden_size: int, + input: str, + output: str, + add_qk_str: str, + ) -> Union[NodeProto, None]: + return + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + return + +# Attr("max_distance", "Max distance", AttributeProto::INT) +# Attr("is_bidirectional", "Default value is 0.", AttributeProto::INT, static_cast(0)) +# Input(0, "bias_table", "2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)", "T") +# Input(1, "query_length", "The length of query. Self Attention requires query_length = key_length", "U") +# Input(2, "key_length", "The length of key.", "U") +# Output(0, "output", "4D output tensor with shape (1, num_heads, sequence_length, sequence_length)", "T") +class FusionRelativePositionBiasBlock(Fusion): + def __init__(self, model: OnnxModel, max_distance: int, is_bidirectional: bool): + super().__init__(model, "RelativePositionBias", "GatherElements") + self.max_distance = 128 + self.is_bidirectional = 1 + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + stem_nodes = self.model.match_parent_path( + node, + ["Expand", "Where", "Equal", "Concat", "Unsqueeze", "Gather", "Shape", "Sub", "Unsqueeze", "Expand", "Unsqueeze", "Range"], + ) + if stem_nodes is None: + return + + expand = stem_nodes[0] + range = stem_nodes[-1] + rpb_nodes = self.model.match_parent_path( + expand, + ["Unsqueeze", "Unsqueeze", "Gemm"] + ) + if rpb_nodes is None: + return + + gemm = rpb_nodes[-1] + + self.nodes_to_remove.extend(stem_nodes) + self.nodes_to_remove.extend(rpb_nodes) + + table_weight = self.model.get_initializer(gemm.input[0]) + table_weight_np = NumpyHelper.to_array(table_weight) + table_weight_np_t = table_weight_np.transpose() + print(np.shape(table_weight_np_t)[0]) + bias_table = helper.make_tensor( + name="bias_table_weight", + data_type=TensorProto.FLOAT, + dims=[np.shape(table_weight_np_t)[0], np.shape(table_weight_np_t)[1]], + vals=table_weight_np_t.flatten().tolist(), + ) + self.model.add_initializer(bias_table, self.this_graph_name) + inputs = [bias_table.name, range.input[1], range.input[1]] + outputs = [node.output[0]] + rpb_node = helper.make_node( + "RelativePositionBias", + inputs=inputs, + outputs=outputs, + name=self.model.create_node_name("RelativePositionBias", name_prefix="RPB"), + ) + rpb_node.domain = "com.microsoft" + rpb_node.attribute.extend([helper.make_attribute("max_distance", self.max_distance)]) + rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", self.is_bidirectional)]) + + 
self.nodes_to_add.append(rpb_node) + self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name + + + +class TulrOnnxModel(BertOnnxModel): + def __init__(self, model, num_heads, hidden_size): + super().__init__(model, num_heads, hidden_size) + self.attention_mask = AttentionMask(self) + self.attention_fusion = FusionTulrAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + self.rpb_fusion = FusionRelativePositionBiasBlock(self, 32, True) + + def fuse_attention(self): + self.attention_fusion.apply() + + def postprocess(self): + self.rpb_fusion.apply() + self.clean_graph() + self.prune_graph() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 56076eedda78a..3c3df6d251971 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -31,6 +31,7 @@ from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_tnlr import TnlrOnnxModel +from onnx_model_tulr import TulrOnnxModel from onnx_model_unet import UnetOnnxModel logger = logging.getLogger(__name__) @@ -48,6 +49,7 @@ 0, ), # might add a class for GPT2OnnxModel for TF later. "tnlr": (TnlrOnnxModel, "pytorch", 1), + "tulr": (TulrOnnxModel, "pytorch", 1), "unet": (UnetOnnxModel, "pytorch", 1), } From c867e3909736da053c0c4c6ebf8d14fcc7077ae3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 29 Jan 2023 08:39:35 +0000 Subject: [PATCH 02/16] a --- .../cuda/bert/relative_attn_bias_impl.cu | 34 ++++++++++--------- .../tools/transformers/onnx_model_tulr.py | 6 ++-- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index e333152cb5bcf..5a1373e3dc1fb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -36,7 +36,7 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, const bool is_bidirectional, const int max_distance) { const int head_id = blockIdx.x; - for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { + for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x) { int row_id = seq_id / seq_len; int col_id = seq_id % seq_len; @@ -101,22 +101,24 @@ Status LaunchRelPosAttnBiasKernel( is_bidirectional, max_distance); return CUDA_CALL(cudaGetLastError()); - } else if (seq_len >= 128 && seq_len <= 384) { - dim3 grid(num_heads, seq_len); - dim3 block(seq_len); - buildRelativeAttentionBias<<>>(output, - bias_table, - num_heads, - seq_len, - num_bucket, - is_bidirectional, - max_distance); - return CUDA_CALL(cudaGetLastError()); + // } else if (seq_len >= 128 && seq_len <= 384) { + // dim3 grid(num_heads, seq_len); + // dim3 block(seq_len); + // buildRelativeAttentionBias<<>>(output, + // bias_table, + // num_heads, + // seq_len, + // num_bucket, + // is_bidirectional, + // max_distance); + // return CUDA_CALL(cudaGetLastError()); } else { - int blockSize = max_threads_per_block; - const int grid_y_Size = (squared_sq_len + blockSize - 1) / blockSize; - dim3 grid(num_heads, grid_y_Size); - dim3 block(blockSize); + // int blockSize = max_threads_per_block; + // const int grid_y_Size = (squared_sq_len + blockSize - 1) / blockSize; + // dim3 grid(num_heads, grid_y_Size); + // dim3 block(blockSize); + dim3 grid(num_heads); + dim3 block(256); 
buildRelativeAttentionBias<<>>(output, bias_table, num_heads, diff --git a/onnxruntime/python/tools/transformers/onnx_model_tulr.py b/onnxruntime/python/tools/transformers/onnx_model_tulr.py index 93b6c9cd56976..5ad1e0513d029 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tulr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tulr.py @@ -83,13 +83,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): table_weight = self.model.get_initializer(gemm.input[0]) table_weight_np = NumpyHelper.to_array(table_weight) - table_weight_np_t = table_weight_np.transpose() - print(np.shape(table_weight_np_t)[0]) bias_table = helper.make_tensor( name="bias_table_weight", data_type=TensorProto.FLOAT, - dims=[np.shape(table_weight_np_t)[0], np.shape(table_weight_np_t)[1]], - vals=table_weight_np_t.flatten().tolist(), + dims=[np.shape(table_weight_np)[1], np.shape(table_weight_np)[0]], + vals=table_weight_np.flatten().tolist(), ) self.model.add_initializer(bias_table, self.this_graph_name) inputs = [bias_table.name, range.input[1], range.input[1]] From de7221250314ff046024c99c09373df3c66cecf2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 30 Jan 2023 20:24:48 +0000 Subject: [PATCH 03/16] b --- .../cuda/bert/relative_attn_bias_impl.cu | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index 5a1373e3dc1fb..5cf31bcbf1141 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -36,7 +36,7 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, const bool is_bidirectional, const int max_distance) { const int head_id = blockIdx.x; - for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x) { + for (int seq_id = gridDim.y * blockIdx.y + threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { int row_id = seq_id / seq_len; int col_id = seq_id % seq_len; @@ -101,24 +101,22 @@ Status LaunchRelPosAttnBiasKernel( is_bidirectional, max_distance); return CUDA_CALL(cudaGetLastError()); - // } else if (seq_len >= 128 && seq_len <= 384) { - // dim3 grid(num_heads, seq_len); - // dim3 block(seq_len); - // buildRelativeAttentionBias<<>>(output, - // bias_table, - // num_heads, - // seq_len, - // num_bucket, - // is_bidirectional, - // max_distance); - // return CUDA_CALL(cudaGetLastError()); + } else if (seq_len >= 128 && seq_len <= 384) { + dim3 grid(num_heads, seq_len); + dim3 block(seq_len); + buildRelativeAttentionBias<<>>(output, + bias_table, + num_heads, + seq_len, + num_bucket, + is_bidirectional, + max_distance); + return CUDA_CALL(cudaGetLastError()); } else { - // int blockSize = max_threads_per_block; - // const int grid_y_Size = (squared_sq_len + blockSize - 1) / blockSize; - // dim3 grid(num_heads, grid_y_Size); - // dim3 block(blockSize); - dim3 grid(num_heads); - dim3 block(256); + int blockSize = max_threads_per_block; + const int grid_y_Size = (squared_sq_len + blockSize - 1) / blockSize; + dim3 grid(num_heads, grid_y_Size); + dim3 block(blockSize); buildRelativeAttentionBias<<>>(output, bias_table, num_heads, From 31729bd66d1706ec1b2cc6dc0751ad150bc44dca Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 30 Jan 2023 23:05:32 +0000 Subject: [PATCH 04/16] multihead attn fusion --- .../tools/transformers/onnx_model_tulr.py | 259 +++++++++++++++++- 1 file changed, 254 
insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/transformers/onnx_model_tulr.py b/onnxruntime/python/tools/transformers/onnx_model_tulr.py index 5ad1e0513d029..3f4a075de8cee 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tulr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tulr.py @@ -19,7 +19,7 @@ class FusionTulrAttention(FusionAttention): """ - Turing + Fuse TULR Attention subgraph into one Attention node. """ def __init__( @@ -34,18 +34,267 @@ def __init__( def create_attention_node( self, mask_index: str, - matmul: NodeProto, - add: NodeProto, + q_matmul: NodeProto, + k_matmul: NodeProto, + v_matmul: NodeProto, + q_add: NodeProto, + k_add: NodeProto, + v_add: NodeProto, num_heads: int, hidden_size: int, input: str, output: str, add_qk_str: str, ) -> Union[NodeProto, None]: - return + """Create an Attention node. + + Args: + mask_index (str): mask input + q_matmul (NodeProto): MatMul node in fully connection for Q + k_matmul (NodeProto): MatMul node in fully connection for K + v_matmul (NodeProto): MatMul node in fully connection for V + q_add (NodeProto): Add bias node in fully connection for Q + k_add (NodeProto): Add bias node in fully connection for K + v_add (NodeProto): Add bias node in fully connection for V + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. + input (str): input name + output (str): output name + + Returns: + Union[NodeProto, None]: the node created or None if failed. + """ + assert num_heads > 0 + + if hidden_size > 0 and (hidden_size % num_heads) != 0: + logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + return None + + q_weight = self.model.get_initializer(q_matmul.input[1]) + k_weight = self.model.get_initializer(k_matmul.input[1]) + v_weight = self.model.get_initializer(v_matmul.input[1]) + q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0]) + #k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0]) + v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0]) + + if q_weight is None: + print( + f"{q_matmul.input[1]} is not an initializer. " + "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion" + ) + return None + if not (k_weight and v_weight and q_bias): + return None + + qw = NumpyHelper.to_array(q_weight) + kw = NumpyHelper.to_array(k_weight) + vw = NumpyHelper.to_array(v_weight) + + # assert q and k have same shape as expected + assert qw.shape == kw.shape + + qw_in_size = qw.shape[0] + kw_in_size = kw.shape[0] + vw_in_size = vw.shape[0] + + assert qw_in_size == kw_in_size == vw_in_size + + if hidden_size > 0 and hidden_size != qw_in_size: + logger.warning( + f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). " + "Please provide a correct input hidden size or pass in 0" + ) + + is_qkv_diff_dims = False + if qw.shape != vw.shape: + is_qkv_diff_dims = True + + # All the matrices can have the same shape or q, k matrics can have the same shape with v being different + # For 2d weights, the shapes would be [in_size, out_size]. 
+ # For 3d weights, shape would be [in_size, a, b] where a*b = out_size + qw_out_size = np.prod(qw.shape[1:]) + kw_out_size = np.prod(kw.shape[1:]) + vw_out_size = np.prod(vw.shape[1:]) + + qkv_weight_dim = 0 + if is_qkv_diff_dims: + qkv_weight = np.concatenate((qw, kw, vw), axis=1) + qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size + else: + qkv_weight = np.stack((qw, kw, vw), axis=1) + qkv_weight_dim = 3 * qw_out_size + + qb = NumpyHelper.to_array(q_bias) + #kb = NumpyHelper.to_array(k_bias) + kb = np.zeros_like(qb) + vb = NumpyHelper.to_array(v_bias) + + q_bias_shape = np.prod(qb.shape) + k_bias_shape = q_bias_shape + v_bias_shape = np.prod(vb.shape) + + assert q_bias_shape == k_bias_shape == qw_out_size + assert v_bias_shape == vw_out_size + + qkv_bias_dim = 0 + if is_qkv_diff_dims: + qkv_bias = np.concatenate((qb, kb, vb), axis=0) + qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape + else: + qkv_bias = np.stack((qb, kb, vb), axis=0) + qkv_bias_dim = 3 * q_bias_shape + + attention_node_name = self.model.create_node_name("Attention") + + bias = helper.make_tensor( + name=attention_node_name + "_qkv_bias", + data_type=TensorProto.FLOAT, + dims=[qkv_bias_dim], + vals=qkv_bias.flatten().tolist(), + ) + if q_bias.data_type == 10: + bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) + self.model.add_initializer(bias, self.this_graph_name) + + attention_inputs = [ + q_matmul.output[0], + k_matmul.output[0], + v_matmul.output[0], + attention_node_name + "_qkv_bias", + ] + if mask_index is not None: + attention_inputs.append(mask_index) + if add_qk_str is not None: + attention_inputs.append(add_qk_str) + + attention_node = helper.make_node( + "MultiHeadAttention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) + + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + if is_qkv_diff_dims: + attention_node.attribute.extend( + [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])] + ) + + if self.mask_filter_value is not None: + attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))]) + + return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): - return + # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm + # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern + start_node = normalize_node + if normalize_node.op_type != "SkipLayerNormalization": + return + + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. 
+ qkv_nodes = self.model.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + ) + if qkv_nodes is not None: + (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes + else: + return + + other_inputs = [] + for i, input in enumerate(start_node.input): + if input not in output_name_to_node: + continue + + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + return + + root_input = other_inputs[0] + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Add", "MatMul"], + ) + if v_nodes is None: + return + (_, _, add_v, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Add", "Div", "MatMul"]) + if qk_nodes is None: + return + (_, add_qk, _, _, matmul_qk) = qk_nodes + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 1], + ) + if q_nodes is None: + return + add_q = q_nodes[-2] + matmul_q = q_nodes[-1] + + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + if k_nodes is None: + return + + add_k = None + matmul_k = k_nodes[-1] + + extra_add_qk_nodes = self.model.match_parent_path(add_qk, ["Mul"]) + if extra_add_qk_nodes is None: + return + + mask_nodes = self.model.match_parent_path( + add_qk, + ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], + ) + if len(mask_nodes) > 1 and mask_nodes[1].op_type == "Mul": + _, mul_val = self.model.get_constant_input(mask_nodes[0]) + if mul_val != -10000: + self.mask_filter_value = mul_val + + if matmul_q.input[0] == root_input: + mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately + new_node = self.create_attention_node( + mask_index, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + self.num_heads, + self.hidden_size, + root_input, + reshape_qkv.output[0], + add_qk.input[1], + ) + if new_node is None: + return + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + self.nodes_to_remove.extend(k_nodes[:-1]) + self.nodes_to_remove.extend(v_nodes[:-1]) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. 
+ self.nodes_to_remove.extend(mask_nodes) + self.prune_graph = True # Attr("max_distance", "Max distance", AttributeProto::INT) # Attr("is_bidirectional", "Default value is 0.", AttributeProto::INT, static_cast(0)) From ac4afbfd293b781bb096014e34c72978a6e50e31 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 30 Jan 2023 23:31:46 +0000 Subject: [PATCH 05/16] c --- onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index 5cf31bcbf1141..308ef770fd857 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -36,7 +36,7 @@ __global__ void buildRelativeAttentionBias(T* relative_attention_bias, const bool is_bidirectional, const int max_distance) { const int head_id = blockIdx.x; - for (int seq_id = gridDim.y * blockIdx.y + threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { + for (int seq_id = blockDim.x * blockIdx.y + threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x * gridDim.y) { int row_id = seq_id / seq_len; int col_id = seq_id % seq_len; From 41692d8b36db665c3a0b6d03f722a3eb50522abb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 31 Jan 2023 04:06:06 +0000 Subject: [PATCH 06/16] d --- .../cpu/bert/multihead_attention_helper.h | 33 +++++++++ .../contrib_ops/cuda/bert/attention_impl.cu | 1 + .../contrib_ops/cuda/bert/attention_softmax.h | 10 +-- .../cuda/bert/multihead_attention.cc | 7 +- .../core/graph/contrib_ops/bert_defs.cc | 6 ++ .../tools/transformers/onnx_model_tulr.py | 72 ++++++++++++++----- 6 files changed, 104 insertions(+), 25 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index ce109a83720b9..e0b84849f9bcc 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -17,6 +17,7 @@ Status CheckInputs(const T* query, const T* value, const T* bias, const T* key_padding_mask, + const T* relative_position_bias, void* parameters, int num_heads, float mask_filter_value, @@ -26,6 +27,7 @@ Status CheckInputs(const T* query, // value (V) : (B, L, D_v) // bias (Q/K/V) : (D + D + D_v) // key_padding_mask (K/V) : (B, L) or (L) + // relative_position_bias : (B, 1, S, L) or (1, 1, S, L) const auto& query_dims = query->Shape().GetDims(); if (query_dims.size() != 3) { @@ -89,6 +91,37 @@ Status CheckInputs(const T* query, } v_hidden_size = value_dims[2]; + if (relative_position_bias != nullptr) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + + if (relative_position_bias_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); + } + std::cout << "relative_position_bias_dims" << relative_position_bias_dims[0] << " " << relative_position_bias_dims[1] << " " << relative_position_bias_dims[2] << " " << relative_position_bias_dims[3] << std::endl; + if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", + relative_position_bias_dims[0]); + } + if 
(relative_position_bias_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); + } + if (relative_position_bias_dims[2] != sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); + } + if (relative_position_bias_dims[3] != kv_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); + } + } + if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); output_parameters->batch_size = static_cast(batch_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 187f1bb37edc5..15da6db4de583 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -637,6 +637,7 @@ Status QkvToContext( bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. + DUMP_ATTENTION_D("extra_add_qk", data.extra_add_qk, batch_size, num_heads, sequence_length, total_sequence_length); ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, data.extra_add_qk, scratch1, scratch2, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h index 953a45e15b32e..35175e5ebe037 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -377,11 +377,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, float thread_data = -CUDART_INF_F; if (threadIdx.x < all_sequence_length) { - if (add_before_softmax == nullptr) { - thread_data = float(input[index]) * rsqrt_head_size; - } else { - thread_data = float(input[index] + add_before_softmax[index]) * rsqrt_head_size; - } + thread_data = float(input[index]) * rsqrt_head_size; const int sequence_index = blockIdx.x % sequence_length; if (is_unidirectional) { @@ -412,6 +408,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, thread_data = -CUDART_INF_F; } } + + if (add_before_softmax != nullptr) { + thread_data += float(add_before_softmax[index]); + } } if (skip_softmax) { diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index c7e5d34e1691b..66aac47d9e3f6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -62,6 +62,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* bias = context->Input(3); const Tensor* key_padding_mask = context->Input(4); + const Tensor* relative_position_bias = context->Input(5); auto& device_prop = GetDeviceProp(); AttentionParameters parameters; @@ -70,6 +71,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { value, bias, key_padding_mask, + relative_position_bias, 
¶meters, num_heads_, mask_filter_value_, @@ -94,6 +96,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && + nullptr == relative_position_bias && parameters.hidden_size == parameters.v_hidden_size && has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length); @@ -111,6 +114,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_runner = !disable_fused_runner_ && fused_cross_attention_kernel == nullptr && + nullptr == relative_position_bias && (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && @@ -141,6 +145,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !disable_memory_efficient_attention_ && is_long_sequence && nullptr == key_padding_mask && // TODO: support 1D mask + nullptr == relative_position_bias && has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; @@ -169,7 +174,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past = nullptr; - data.extra_add_qk = nullptr; + data.extra_add_qk = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index b4ad4d64e7ddb..53128ff1f4f8c 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -302,6 +302,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)", "M", OpSchema::Optional) + .Input(5, + "relative_position_bias", + "relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)" + " or (1, num_heads, sequence_length, total_sequence_length)", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, v_hidden_size)", diff --git a/onnxruntime/python/tools/transformers/onnx_model_tulr.py b/onnxruntime/python/tools/transformers/onnx_model_tulr.py index 3f4a075de8cee..c7d84354f22f4 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tulr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tulr.py @@ -146,6 +146,20 @@ def create_attention_node( attention_node_name = self.model.create_node_name("Attention") + use_multi_head_attention = False + if not use_multi_head_attention: + weight = helper.make_tensor( + name=attention_node_name + "_qkv_weight", + data_type=TensorProto.FLOAT, + dims=[qw_in_size, qkv_weight_dim], + vals=qkv_weight.flatten().tolist(), + ) + + # Sometimes weights and bias are stored in fp16 + if q_weight.data_type == 10: + weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) + self.model.add_initializer(weight, self.this_graph_name) + bias = helper.make_tensor( name=attention_node_name + "_qkv_bias", data_type=TensorProto.FLOAT, @@ 
-156,24 +170,45 @@ def create_attention_node( bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) self.model.add_initializer(bias, self.this_graph_name) - attention_inputs = [ - q_matmul.output[0], - k_matmul.output[0], - v_matmul.output[0], - attention_node_name + "_qkv_bias", - ] - if mask_index is not None: - attention_inputs.append(mask_index) - if add_qk_str is not None: - attention_inputs.append(add_qk_str) - - attention_node = helper.make_node( - "MultiHeadAttention", - inputs=attention_inputs, - outputs=[output], - name=attention_node_name, - ) - + if use_multi_head_attention: + attention_inputs = [ + q_matmul.output[0], + k_matmul.output[0], + v_matmul.output[0], + attention_node_name + "_qkv_bias", + ] + if mask_index is not None: + attention_inputs.append(mask_index) + if add_qk_str is not None: + attention_inputs.append(add_qk_str) + + attention_node = helper.make_node( + "MultiHeadAttention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) + else: + attention_inputs = [ + input, + attention_node_name + "_qkv_weight", + attention_node_name + "_qkv_bias", + ] + if mask_index is not None: + attention_inputs.append(mask_index) + else: + attention_inputs.append("") + + if add_qk_str is not None: + attention_inputs.append("") # no past + attention_inputs.append(add_qk_str) + + attention_node = helper.make_node( + "Attention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) @@ -355,7 +390,6 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name - class TulrOnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size): super().__init__(model, num_heads, hidden_size) From 436faf905393f7fb752a8470bcf9348bc3843a16 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 31 Jan 2023 04:08:52 +0000 Subject: [PATCH 07/16] e --- onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h | 1 - onnxruntime/python/tools/transformers/onnx_model_tulr.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index e0b84849f9bcc..040ce78363fcd 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -99,7 +99,6 @@ Status CheckInputs(const T* query, "Input 'relative_position_bias' is expected to have 4 dimensions, got ", relative_position_bias_dims.size()); } - std::cout << "relative_position_bias_dims" << relative_position_bias_dims[0] << " " << relative_position_bias_dims[1] << " " << relative_position_bias_dims[2] << " " << relative_position_bias_dims[3] << std::endl; if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", diff --git a/onnxruntime/python/tools/transformers/onnx_model_tulr.py b/onnxruntime/python/tools/transformers/onnx_model_tulr.py index c7d84354f22f4..bede363a6c0c5 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tulr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tulr.py @@ -146,7 +146,7 @@ def create_attention_node( attention_node_name = 
self.model.create_node_name("Attention") - use_multi_head_attention = False + use_multi_head_attention = True if not use_multi_head_attention: weight = helper.make_tensor( name=attention_node_name + "_qkv_weight", From ea96f768ca80809c7867bec1eb78d5c4520f3061 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 31 Jan 2023 19:59:39 +0000 Subject: [PATCH 08/16] name change --- docs/ContribOperators.md | 186 +++++++++--------- docs/OperatorKernels.md | 8 +- onnxruntime/contrib_ops/cpu/bert/attention.cc | 6 +- .../contrib_ops/cpu/bert/attention_base.cc | 48 ++--- .../contrib_ops/cpu/bert/attention_base.h | 4 +- .../contrib_ops/cpu/bert/attention_cpu_base.h | 38 ++-- .../cpu/quantization/attention_quant.cc | 2 +- .../contrib_ops/cuda/bert/attention.cc | 12 +- .../contrib_ops/cuda/bert/attention_impl.cu | 8 +- .../contrib_ops/cuda/bert/attention_impl.h | 2 +- .../cuda/bert/multihead_attention.cc | 2 +- .../quantization/attention_quantization.cc | 4 +- .../qordered_ops/qordered_attention.cc | 2 +- .../qordered_attention_input_enum.h | 2 +- .../contrib_ops/rocm/bert/attention.cc | 6 +- .../contrib_ops/rocm/bert/attention_impl.cu | 14 +- .../contrib_ops/rocm/bert/attention_impl.h | 2 +- .../core/graph/contrib_ops/bert_defs.cc | 2 +- .../graph/contrib_ops/quantization_defs.cc | 4 +- .../core/providers/cpu/cpu_provider_shared.cc | 4 +- .../core/providers/cpu/cpu_provider_shared.h | 2 +- .../src/Operators/DmlOperatorAttention.cpp | 6 +- .../provider_bridge_provider.cc | 4 +- .../tools/transformers/fusion_attention.py | 2 +- .../tools/transformers/onnx_model_tnlr.py | 6 +- .../tools/transformers/onnx_model_tulr.py | 4 +- .../test/contrib_ops/attention_op_test.cc | 30 +-- .../contrib_ops/qordered_attention_test.cc | 2 +- 28 files changed, 205 insertions(+), 207 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1e6d46963cd21..a8d2079376ef4 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -94,24 +94,24 @@ Do not modify directly.* ### **com.microsoft.Attention** Multi-Head Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). - + The weights for input projection of Q, K and V are merged. The data is stacked on the second dimension. Its shape is (input_hidden_size, hidden_size + hidden_size + v_hidden_size). Here hidden_size is the hidden dimension of Q and K, and v_hidden_size is that of V. - + The mask_index is optional. Besides raw attention mask with shape (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) with value 0 for masked and 1 otherwise, we support other two formats: When input has right-side padding, mask_index is one dimension with shape (batch_size), where value is actual sequence length excluding padding. When input has left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. - + When unidirectional is 1, each token only attends to previous tokens. - + Both past and present state are optional. They shall be used together, and not allowed to use only one of them. The qkv_hidden_sizes is required only when K and V have different hidden sizes. - + When there is past state, hidden dimension for Q, K and V shall be the same. - + The total_sequence_length is past_sequence_length + kv_sequence_length. Here kv_sequence_length is the length of K or V. For self attention, kv_sequence_length equals to sequence_length (sequence length of Q). 
For cross attention, query and key might have different lengths. @@ -150,7 +150,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
-
extra_add (optional) : T
+
relative_position_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
@@ -179,133 +179,133 @@ This version of the operator has been available since version 1 of the 'com.micr Computes an one-layer RNN where its RNN Cell is an AttentionWrapper wrapped a LSTM Cell. The RNN layer contains following basic component: LSTM Cell, Bahdanau Attention Mechanism, AttentionWrapp. - + Activation functions: - + Relu(x) - max(0, x) - + Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) - + Sigmoid(x) - 1/(1 + e^{-x}) - + (NOTE: Below are optional) - + Affine(x) - alpha*x + beta - + LeakyRelu(x) - x if x >= 0 else alpha * x - + ThresholdedRelu(x) - x if x >= alpha else 0 - + ScaledTanh(x) - alpha*Tanh(beta*x) - + HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) - + Elu(x) - x if x >= 0 else alpha*(e^x - 1) - + Softsign(x) - x/(1 + |x|) - + Softplus(x) - log(1 + e^x) - + Softmax(x) - exp(x) / sum(exp(x)) - + Bahdanau Attention Mechanism: `M` - Memory tensor. - + `VALUES` - masked Memory by its real sequence length. - + `MW` - Memory layer weight. - + `KEYS` - Processed memory tensor by the memory layer. KEYS = M * MW - + `Query` - Query tensor, normally at specific time step in sequence. - + `QW` - Query layer weight in the attention mechanism - + `PQ` - processed query, = `Query` * `QW` - + `V' - attention vector - + `ALIGN` - calculated alignment based on Query and KEYS ALIGN = softmax(reduce_sum(`V` * Tanh(`KEYS` + `PQ`))) - + `CONTEXT` - context based on `ALIGN` and `VALUES` CONTEXT = `ALIGN` * `VALUES` - - + + LSTM Cell: `X` - input tensor concat with attention state in the attention wrapper - + `i` - input gate - + `o` - output gate - + `f` - forget gate - + `c` - cell gate - + `t` - time step (t-1 means previous time step) - + `W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates - + `R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates - + `Wb[iofc]` - W bias vectors for input, output, forget, and cell gates - + `Rb[iofc]` - R bias vectors for input, output, forget, and cell gates - + `P[iof]` - P peephole weight vector for input, output, and forget gates - + `WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates - + `RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates - + `WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates - + `RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates - + `PB[iof]` - P peephole weight vector for backward input, output, and forget gates - + `H` - Hidden state - + `num_directions` - 2 if direction == bidirectional else 1 - + Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): - + - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - + - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - + - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - + - Ct = ft (.) Ct-1 + it (.) ct - + - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - + - Ht = ot (.) h(Ct) - - + + AttentionWrapp Notations: `lstm()' - wrapped inner cell. Ht, Ct = lstm(concat(Xt, ATTNt-1), Ct-1) - + `am()` - attention mechanism the wrapper used. CONTEXTt, ALIGNt = am(Ht, ALIGNt-1) - + `AW` - attention layer weights, optional. - + `ATTN` - attention state, initial is zero. If `AW` provided, it is the output of the attention layer, ATTNt = concat(Ht, CONTEXTt) * AW otherwise, ATTNt = CONTEXTt - + RNN layer output: `Y` - if needed is the sequence of Ht from lstm cell. - + `Y_h` - is the last valid H from lstm cell. - + `Y_c` - is the last valid C from lstm cell. 
- + #### Version @@ -519,7 +519,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.BiasGelu** Bias Gelu. - It's an extension of Gelu. It takes the sum of input A and bias input B as the input of Gelu activation. + It's an extension of Gelu. It takes the sum of input A and bias input B as the input of Gelu activation. #### Version @@ -711,7 +711,7 @@ This version of the operator has been available since version 1 of the 'com.micr ``` scale = 1. / (1. - ratio). ``` - + This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt that the mask is output as a bit-packed uint32 tensor, instead of a boolean tensor. #### Version @@ -1768,7 +1768,7 @@ This version of the operator has been available since version 1 of the 'com.micr which are used to interpolate the output value `output[n, :, h, w]`. The GridSample operator is often used in doing grid generator and sampler in the [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/master/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample). - + #### Version @@ -1883,10 +1883,10 @@ This version of the operator has been available since version 1 of the 'com.micr Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token attends to its W previous tokens and W succeeding tokens with W being the window length. A selected few tokens attend globally to all other tokens. - + The attention mask is of shape (batch_size, sequence_length), where sequence_length is a multiple of 2W after padding. Mask value < 0 (like -10000.0) means the token is masked, 0 otherwise. - + Global attention flags have value 1 for the tokens attend globally and 0 otherwise. #### Version @@ -2071,11 +2071,11 @@ This version of the operator has been available since version 1 of the 'com.micr Performs element-wise binary quantized multiplication (with Numpy-style broadcasting support). "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**" The output of this op is the int32 accumulated result of the mul operation - + ``` C (int32) = (A - A_zero_point) * (B - B_zero_point) ``` - + #### Version @@ -2114,7 +2114,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.MultiHeadAttention** Multi-Head Self/Cross Attention. Bias from input projection is included. - + The key padding mask is optional. When its shape is (batch_size, kv_sequence_length), value 0 means padding or 1 otherwise. When key has right-side padding, its shape could be (batch_size): it is actual length of each key sequence excluding paddings. @@ -2358,7 +2358,7 @@ This version of the operator has been available since version 1 of the 'com.micr [0.0, 0.0, 4.5, 5.7], ], ] - + #### Version @@ -2536,7 +2536,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.QLinearAdd** Performs element-wise binary addition on 8 bit data types (with Numpy-style broadcasting support). 
- + C = (A_scale * (A - A_zero_point) + B_scale * (B - B_zero_point))/C_scale + C_zero_point #### Version @@ -2594,11 +2594,11 @@ This version of the operator has been available since version 1 of the 'com.micr output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1) ``` if ceil_mode is enabled - + ``` * pad_shape[i] is sum of pads along axis i ``` - + `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: ``` VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - kernel_spatial_shape[i] + 1) / strides_spatial_shape[i]) @@ -2608,9 +2608,9 @@ This version of the operator has been available since version 1 of the 'com.micr ``` pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i] ``` - + The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). - + Input and output scales and zero points are used to convert the output to a new quantization range. Output = Dequantize(Input) -> AveragePool on fp32 data -> Quantize(output) @@ -2878,7 +2878,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.QLinearMul** Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting support). - + C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point #### Version @@ -2929,10 +2929,10 @@ This version of the operator has been available since version 1 of the 'com.micr with the exception that numpy default keepdims to False instead of True. Input and Output scales and zero points are used to requantize the output in a new range. This helps to improve accuracy as after ReduceMean operation the range of the output is expected to decrease. - + ``` "Output = Dequantize(Input) -> ReduceMean on fp32 data -> Quantize(output)", - + ``` #### Version @@ -2982,7 +2982,7 @@ This version of the operator has been available since version 1 of the 'com.micr QLinearSigmoid takes quantized input data (Tensor), and quantize parameter for output, and produces one output data (Tensor) where the function `f(x) = quantize(Sigmoid(dequantize(x)))`, is applied to the data tensor elementwise. - Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` + Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` #### Version @@ -3131,7 +3131,7 @@ This version of the operator has been available since version 1 of the 'com.micr left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. - Current version does not support past/present, extra_add and qkv_hidden_sizes. + Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. #### Version @@ -3196,7 +3196,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : Q
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
-
extra_add (optional) : S
+
relative_position_bias (optional) : S
additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
@@ -3767,10 +3767,10 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.RemovePadding** Compress transformer input by removing paddings. It assumes padding is on the right side of sequence. - + The input has padding with shape (batch_size, sequence_length, hidden_size). This will generate two outputs: output has shape (total_tokens, hidden_size); token_offset with shape (batch_size, sequence_length). - + token_offset has offsets of all non-padding tokens first, then offset of all padding tokens. It is a list of batch_size * sequence_length elements, which is reshaped to 2D for convenience of shape inference. @@ -3813,7 +3813,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.RestorePadding** Restore paddings and fill padding with zeros. - + The input has padding with shape (total_tokens, hidden_size) and token_offset with shape (batch_size, sequence_length). The output has shape (batch_size, sequence_length, hidden_size). @@ -4269,7 +4269,7 @@ This version of the operator has been available since version 1 of the 'com.micr Based on Torch operator Embedding, creates a lookup table of embedding vectors of fixed size, for a dictionary of fixed size. - + #### Version @@ -4359,7 +4359,7 @@ This version of the operator has been available since version 1 of the 'com.micr the main diagonal. A negative k value includes as many diagonals below the main diagonal. If upper is set to false, a positive k retains the lower triangular matrix including k diagonals above the main diagonal. A negative k value excludes as many diagonals below the main diagonal. - + #### Version @@ -4410,7 +4410,7 @@ This version of the operator has been available since version 1 of the 'com.micr output_uniques = [2, 1, 3, 4] output_idx = [0, 1, 1, 2, 3, 2] output_counts = [1, 2, 2, 1] - + #### Version @@ -4609,5 +4609,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
- - diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 2dc4fbfb790b2..ada495b32842d 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -417,7 +417,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| @@ -785,7 +785,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| @@ -811,7 +811,7 @@ Do not modify directly.* |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* extra_add:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
+|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
 |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)|
 |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
 |QOrderedLongformerAttention|*in* input:**Q**
*in* scale_input:**S**
*in* weight:**Q**
*in* scale_weight:**S**
*in* bias:**S**
*in* scale_bias:**S**
*in* scale_qkv_gemm:**S**
*in* mask:**F**
*in* global_weight:**Q**
*in* scale_global_weight:**S**
*in* global_bias:**S**
*in* scale_global_gemm:**S**
*in* global:**G**
*in* scale_output:**S**
*out* output:**Q**|1+|**F** = tensor(float16)
**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
@@ -1156,7 +1156,7 @@ Do not modify directly.*
 | |
 | |
 |**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
+|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
 |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
 |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 47db3fe558ce8..6aa0e726afe1b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -198,7 +198,7 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_); @@ -208,7 +208,7 @@ Status Attention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, ¶meters)); const int batch_size = parameters.batch_size; @@ -331,7 +331,7 @@ Status Attention::Compute(OpKernelContext* context) const { return ApplyAttention(Q, K, V, mask_index, past, output, batch_size, sequence_length, parameters.head_size, parameters.v_head_size, parameters.v_hidden_size, - extra_add_qk, context); + relative_position_bias, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index affe7cab1d858..e75f68ea53c7c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -12,7 +12,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const Tensor* past_seq_len) const { // Abbreviation and Meanings: @@ -37,7 +37,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // extra_add_qk : (B, N, S, T) or NULL + // relative_position_bias : (B, N, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) @@ -49,9 +49,9 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger // than hidden dimension of Q, K and V. 
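The hunk below renames the optional fifth input without changing its shape contract: it must be a 4-D tensor of shape (batch_size, num_heads, sequence_length, total_sequence_length) and cannot be combined with `past`. A minimal Python sketch of that same contract, useful when preparing inputs for the renamed operator (the function and argument names here are illustrative, not part of the ONNX Runtime API):

```python
import numpy as np

def check_relative_position_bias(bias: np.ndarray, batch_size: int, num_heads: int,
                                 sequence_length: int, total_sequence_length: int,
                                 has_past: bool) -> None:
    """Mirrors the CheckInputs constraints shown in the C++ hunk below."""
    if has_past:
        # The kernel rejects past and relative_position_bias used together.
        raise ValueError("Attention cannot have both past and relative_position_bias")
    if bias.ndim != 4:
        raise ValueError(f"relative_position_bias must be 4-D, got {bias.ndim}-D")
    expected = (batch_size, num_heads, sequence_length, total_sequence_length)
    if tuple(bias.shape) != expected:
        raise ValueError(f"relative_position_bias shape {tuple(bias.shape)} does not match {expected}")
```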
- if (past != nullptr && extra_add_qk != nullptr) { - // past is used on GPT-2 model with past state, we don't have a case for extra add qk yet - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and extra_add_qk"); + if (past != nullptr && relative_position_bias != nullptr) { + // past is used on GPT-2 model with past state, we don't have a case for relative position bias yet + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and relative_position_bias"); } const auto& dims = input_shape.GetDims(); @@ -191,34 +191,34 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } } - if (extra_add_qk != nullptr) { - const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims(); + if (relative_position_bias != nullptr) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - if (extra_add_qk_dims.size() != 4) { + if (relative_position_bias_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' is expected to have 4 dimensions, got ", - extra_add_qk_dims.size()); + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); } - if (extra_add_qk_dims[0] != batch_size) { + if (relative_position_bias_dims[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ", - extra_add_qk_dims[0]); + "Input 'relative_position_bias' dimension 0 should be same as batch_size, got ", + relative_position_bias_dims[0]); } - if (extra_add_qk_dims[1] != num_heads_) { + if (relative_position_bias_dims[1] != num_heads_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ", - extra_add_qk_dims[1]); + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); } - if (extra_add_qk_dims[2] != sequence_length) { + if (relative_position_bias_dims[2] != sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ", - extra_add_qk_dims[2]); + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); } - if (extra_add_qk_dims[3] != total_sequence_length) { + if (relative_position_bias_dims[3] != total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'extra_add_qk' dimension 3 should be same as total_sequence_length, got ", - extra_add_qk_dims[3]); + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); } } @@ -320,7 +320,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { @@ -328,7 +328,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, extra_add_qk, parameters, past_seq_len); + return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, 
relative_position_bias, parameters, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 2c49f196d52d8..2e077da2853d0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -18,7 +18,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, // for CUDA const Tensor* past_seq_len = nullptr) const; @@ -61,7 +61,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const Tensor* past_seq_len = nullptr) const; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 0185fa9ea09a0..70d71ffb6ee40 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -19,18 +19,18 @@ class AttentionCPUBase : public AttentionBase { : AttentionBase(info, require_same_hidden_size) {} template - Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH - const T* K, // K data with shape BxNxSxH - const T* V, // V value with size BxNxSxH_v - const Tensor* mask_index, // mask index. nullptr if no mask or its size is B - const Tensor* past, // past state - Tensor* output, // output tensor - int batch_size, // batch size (B) - int sequence_length, // sequence length (S) - int qk_head_size, // head size of Q or K (H) - int v_head_size, // head size of V (H_v) - int v_hidden_size, // hidden size of V (D_v) - const Tensor* extra_add_qk, // extra add in QK. Its size is BxNxSxT + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxSxH + const T* V, // V value with size BxNxSxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past, // past state + Tensor* output, // output tensor + int batch_size, // batch size (B) + int sequence_length, // sequence length (S) + int qk_head_size, // head size of Q or K (H) + int v_head_size, // head size of V (H_v) + int v_hidden_size, // hidden size of V (D_v) + const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT OpKernelContext* context) const { const int kv_sequence_length = sequence_length; @@ -67,16 +67,16 @@ class AttentionCPUBase : public AttentionBase { const T* past_data = past != nullptr ? past->Data() : nullptr; T* present_data = present != nullptr ? present->MutableData() : nullptr; - const T* extra_add_qk_data = nullptr; - if (extra_add_qk != nullptr) { - extra_add_qk_data = extra_add_qk->Data(); + const T* relative_position_bias_data = nullptr; + if (relative_position_bias != nullptr) { + relative_position_bias_data = relative_position_bias->Data(); } ComputeAttentionProbs(static_cast(attention_probs), Q, K, mask_index_data, mask_index_dims, static_cast(mask_data), has_unidirectional, batch_size, sequence_length, past_sequence_length, qk_head_size == 0 ? 
v_head_size : qk_head_size, - past_data, present_data, tp, extra_add_qk_data); + past_data, present_data, tp, relative_position_bias_data); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -112,7 +112,7 @@ class AttentionCPUBase : public AttentionBase { const T* past, // past state T* present, // present state ThreadPool* tp, // thread pool - const T* extra_add_qk_data // extra add matrix with shape BxNxSxT + const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT ) const { const int total_sequence_length = past_sequence_length + sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H @@ -175,9 +175,9 @@ class AttentionCPUBase : public AttentionBase { } } - if (extra_add_qk_data != nullptr) { + if (relative_position_bias_data != nullptr) { for (int j = 0; j < sequence_length * total_sequence_length; j++) { - output[j] += extra_add_qk_data[output_offset + j]; + output[j] += relative_position_bias_data[output_offset + j]; } } } diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 64c17b7767e4f..e7df84c1b0066 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -160,7 +160,7 @@ Status QAttention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past_tensor, - nullptr, // extra_add_qk + nullptr, // relative_position_bias nullptr // parameters )); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 4a6d2dc137139..1ab89b525eae5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -59,7 +59,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(kPastInputIndex); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); @@ -69,7 +69,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, ¶meters, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -105,7 +105,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; bool use_causal_fused_runner = !disable_fused_runner_ && (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == extra_add_qk && + nullptr == relative_position_bias && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, @@ -125,7 +125,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { (nullptr == mask_index || is_mask_1d_seq_len) && nullptr == past && nullptr == present && - nullptr == extra_add_qk && + nullptr == relative_position_bias && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, enable_trt_flash_attention_, false); @@ -151,7 +151,7 @@ Status 
Attention::ComputeInternal(OpKernelContext* context) const { nullptr == mask_index && // TODO: support 1D mask nullptr == past && nullptr == present && - nullptr == extra_add_qk && + nullptr == relative_position_bias && (sizeof(T) == 2 || // sequence length threshold is 0 in FP16 parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && has_memory_efficient_attention(sm, sizeof(T) == 2); @@ -203,7 +203,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); - data.extra_add_qk = (nullptr == extra_add_qk) ? nullptr : reinterpret_cast(extra_add_qk->Data()); + data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 15da6db4de583..6999a269caae5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -637,10 +637,10 @@ Status QkvToContext( bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. - DUMP_ATTENTION_D("extra_add_qk", data.extra_add_qk, batch_size, num_heads, sequence_length, total_sequence_length); + DUMP_ATTENTION_D("relative_position_bias", data.relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length); ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, data.extra_add_qk, scratch1, scratch2, + mask_index, nullptr, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); @@ -650,10 +650,10 @@ Status QkvToContext( const int* mask_start = (mask_index_dims[0] > batch_size) ? 
mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, data.extra_add_qk, scratch1, scratch2, parameters.is_unidirectional)); + mask_index, mask_start, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( - ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.extra_add_qk, + ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, scratch1, scratch2, parameters.is_unidirectional)); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d98a0380c479b..2ecda71479c52 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -41,7 +41,7 @@ struct AttentionData { const int* mask_index; gsl::span mask_index_dims; const T* past; - const T* extra_add_qk; + const T* relative_position_bias; T* workspace; T* output; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 66aac47d9e3f6..a4b6ff116a3bd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -174,7 +174,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past = nullptr; - data.extra_add_qk = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index e5ea47a6a2a5b..7cd717efc9fba 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -52,7 +52,7 @@ Status QAttention::CheckInputs(const Tensor* input, auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor, - nullptr, // extra_add_qk + nullptr, // relative_position_bias parameters, device_prop.maxThreadsPerBlock)); @@ -198,7 +198,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data()); - data.extra_add_qk = nullptr; // add_qk is not supported in quantized attention + data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? 
nullptr : reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 204c786cc2c5d..8122b2de5916b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -212,7 +212,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), merged_weights_shape, merged_bias_shape, mask_index, nullptr, // past - nullptr, // extra_add_qk + nullptr, // relative_position_bias nullptr, // parameters device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h index 5fe62ef127800..5fb31be5fe86f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h @@ -17,4 +17,4 @@ DefineQOrderedAttentionInput(Scale_QK_Softmax, scale_QKT_softmax, 15), DefineQOrderedAttentionInput(Scale_Values_Gemm, scale_values_gemm, 16), DefineQOrderedAttentionInput(Mask_Index, mask_index, 17), DefineQOrderedAttentionInput(Past, past, 18), -DefineQOrderedAttentionInput(Extra_Add, extra_add, 19) +DefineQOrderedAttentionInput(relative_position_bias, relative_position_bias, 19) diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc index 756919834aef8..afc9fd9237ed7 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cc +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc @@ -39,7 +39,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* relative_position_bias = context->Input(5); auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), @@ -47,7 +47,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - extra_add_qk, + relative_position_bias, nullptr, device_prop.maxThreadsPerBlock)); @@ -129,7 +129,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), mask_filter_value_, nullptr == past ? nullptr : past->Data(), - nullptr == extra_add_qk ? nullptr : extra_add_qk->Data(), + nullptr == relative_position_bias ? nullptr : relative_position_bias->Data(), work_space.get(), output->MutableData(), nullptr == present ? 
nullptr : present->MutableData()); diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 954a129be1c65..fa6cce6a64132 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -89,7 +89,7 @@ Status QkvToContext( bool is_unidirectional, int past_sequence_length, const T* past, - const T* extra_add_qk, + const T* relative_position_bias, T* present, bool use_persistent_softmax) { const int all_sequence_length = past_sequence_length + sequence_length; @@ -158,7 +158,7 @@ Status QkvToContext( T* persistent_softmax_workspace = scratch1; // replace Q*K' in place if persistent softmax is selected. ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, extra_add_qk, scratch1, scratch2, + mask_index, nullptr, relative_position_bias, scratch1, scratch2, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index @@ -166,10 +166,10 @@ Status QkvToContext( // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)); + mask_index, mask_start, relative_position_bias, scratch1, scratch2, is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, - extra_add_qk, scratch1, scratch2, is_unidirectional)); + relative_position_bias, scratch1, scratch2, is_unidirectional)); } // compute P*V (as V*P), and store in scratch3: BxNxSxH @@ -206,7 +206,7 @@ Status LaunchAttentionKernel( gsl::span mask_index_dims, const float mask_filter_value, const void* past, - const void* extra_add_qk, + const void* relative_position_bias, void* workspace, void* output, void* present) { @@ -225,7 +225,7 @@ Status LaunchAttentionKernel( is_unidirectional, past_sequence_length, reinterpret_cast(past), - reinterpret_cast(extra_add_qk), + reinterpret_cast(relative_position_bias), reinterpret_cast<__half*>(present), use_persistent_softmax); } else { @@ -240,7 +240,7 @@ Status LaunchAttentionKernel( is_unidirectional, past_sequence_length, reinterpret_cast(past), - reinterpret_cast(extra_add_qk), + reinterpret_cast(relative_position_bias), reinterpret_cast(present), use_persistent_softmax); } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 7db692083f5e5..fdc46ce2e7729 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -42,7 +42,7 @@ Status LaunchAttentionKernel( gsl::span mask_index_dims, // Mask index shape const float mask_filter_value, // Mask value for filtered out positions const void* past, // Past state input - const void* extra_add_qk, // Additional Add + const void* relative_position_bias, // Additional Add void* workspace, // Temporary buffer void* output, // Output tensor void* present // Present state output diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc 
b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 53128ff1f4f8c..104c706c420cd 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -234,7 +234,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .Input(5, - "extra_add", + "relative_position_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", "T", OpSchema::Optional) diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 6111afbd5d817..91e4f5d8ff81a 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -1140,7 +1140,7 @@ where value of each element is the end position, or valid length of actual seque left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. -Current version does not support past/present, extra_add and qkv_hidden_sizes. +Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. )DOC"; @@ -1202,7 +1202,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(18, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "Q", OpSchema::Optional) - .Input(19, "extra_add", + .Input(19, "relative_position_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "S", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "Q") diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index b4a92019992b5..c0a75fc50b07e 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -198,12 +198,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, - extra_add_qk, + relative_position_bias, parameters, max_threads_per_block, past_seq_len); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 2490789dd31a2..f12e080adf30a 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -145,7 +145,7 @@ struct ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) = 0; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index 63bae80c51a67..af93808248032 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -401,15 +401,15 @@ class DmlOperatorAttention : public DmlOperator void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) { *isSupported = false; - // Fall back to CPU if input 'past' and 'extra_add' is present because there is no current use case for this. + // Fall back to CPU if input 'past' and 'relative_position_bias' is present because there is no current use case for this. // and it will make the implementation more complex. // Also fall back to CPU if output 'present' is present for same reason as above. if (context->GetInputCount() > 4 || context->GetOutputCount() > 1) { return; } - // Checking input count alone is not sufficient to fallback to CPU if input 'past' and 'extra_add' is present - // because input 'mask_index', 'past', and 'extra_add' all are optional. + // Checking input count alone is not sufficient to fallback to CPU if input 'past' and 'relative_position_bias' is present + // because input 'mask_index', 'past', and 'relative_position_bias' all are optional. if (context->IsInputValid(4) || context->IsInputValid(5)) { return; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index b772fe95d6ecc..30be10ea7e15f 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -563,12 +563,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* extra_add_qk, + const Tensor* relative_position_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { return g_host_cpu.AttentionBase__CheckInputs(this, input_shape, weights_shape, bias_shape, - mask_index, past, extra_add_qk, parameters, + mask_index, past, relative_position_bias, parameters, max_threads_per_block, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, const Tensor* past, int batch_size, int head_size, diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 245ea9322ad61..342d43306e699 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -337,7 +337,7 @@ def create_attention_node( # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights. 
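The comment above marks the branch point between the two fusion targets: the MultiHeadAttention path takes separated query/key/value inputs and, at this point, has no relative_position_bias input, so fusion is skipped when one is present, while the single-node Attention path appends the bias after an empty `past` slot so it lands in input 5. A rough sketch of that selection logic, using illustrative names rather than the actual fusion helpers:

```python
def assemble_fused_attention(use_multi_head_attention, q, k, v, qkv_weight, qkv_bias,
                             mask_index=None, relative_position_bias=None, layer_input=None):
    """Illustrative sketch of the two input layouts used by the fusion code in this diff."""
    if use_multi_head_attention:
        if relative_position_bias is not None:
            return None  # MultiHeadAttention has no such input here, so do not fuse
        inputs = [q, k, v, qkv_bias]
        if mask_index is not None:
            inputs.append(mask_index)
        return "MultiHeadAttention", inputs

    inputs = [layer_input, qkv_weight, qkv_bias, mask_index if mask_index is not None else ""]
    if relative_position_bias is not None:
        inputs.append("")                       # empty string leaves the optional past slot unused
        inputs.append(relative_position_bias)   # renamed optional input at index 5
    return "Attention", inputs
```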
if self.use_multi_head_attention: if add_qk_str is not None: - logger.debug("MultiHeadAttention does not support extra_add_qk: cannot fuse the attention.") + logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.") return None attention_inputs = [ diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py index dc8f6810914a7..85e510a828990 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py @@ -172,8 +172,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add = k_nodes[-2] matmul = k_nodes[-1] - extra_add_qk_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0]) - if extra_add_qk_nodes is None: + relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0]) + if relative_position_bias_nodes is None: return if matmul.input[0] == root_input: @@ -189,7 +189,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.hidden_size, root_input, attention_last_node.output[0], - extra_add_qk_nodes[0].input[0], + relative_position_bias_nodes[0].input[0], ) if new_node is None: return diff --git a/onnxruntime/python/tools/transformers/onnx_model_tulr.py b/onnxruntime/python/tools/transformers/onnx_model_tulr.py index bede363a6c0c5..a03693ebff232 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tulr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tulr.py @@ -286,8 +286,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_k = None matmul_k = k_nodes[-1] - extra_add_qk_nodes = self.model.match_parent_path(add_qk, ["Mul"]) - if extra_add_qk_nodes is None: + relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Mul"]) + if relative_position_bias_nodes is None: return mask_nodes = self.model.match_parent_path( diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index fb1d8fcfe451a..0ea85dfdaba4f 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -59,7 +59,7 @@ static void RunAttentionTest( const bool disable_cuda = false, const bool disable_rocm = false, std::vector qkv_sizes = {}, - const std::vector& extra_add_data = {}, + const std::vector& relative_position_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false) { @@ -199,12 +199,12 @@ static void RunAttentionTest( } } - std::vector extra_add_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; - if (extra_add_data.size() > 0) { + std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + if (relative_position_bias_data.size() > 0) { if (use_float16) { - tester.AddInput("extra_add_qk", extra_add_data_dims, ToFloat16(extra_add_data)); + tester.AddInput("relative_position_bias", relative_position_bias_data_dims, ToFloat16(relative_position_bias_data)); } else { - tester.AddInput("extra_add_qk", extra_add_data_dims, extra_add_data); + tester.AddInput("relative_position_bias", relative_position_bias_data_dims, relative_position_bias_data); } } else { if (use_float16) { @@ -264,7 +264,7 @@ static void RunAttentionTest( const bool disable_cuda = false, const bool disable_rocm = false, const std::vector qkv_sizes = {}, - const 
std::vector& extra_add_data = {}, + const std::vector& relative_position_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false) { @@ -272,13 +272,13 @@ static void RunAttentionTest( batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data, + disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias_data, kv_sequence_length, past_present_share_buffer, use_scale); RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data, + disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias_data, kv_sequence_length, past_present_share_buffer, use_scale); } @@ -390,7 +390,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { 0, false, false, disable_rocm, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1ExtraAdd) { +TEST(AttentionTest, AttentionBatch1RelativePositionBias) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -414,7 +414,7 @@ TEST(AttentionTest, AttentionBatch1ExtraAdd) { std::vector mask_index_data = {2L}; - std::vector extra_add_qk = { + std::vector relative_position_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; std::vector output_data = { @@ -427,10 +427,10 @@ TEST(AttentionTest, AttentionBatch1ExtraAdd) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_qk); + 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias); } -TEST(AttentionTest, AttentionBatch2ExtraAdd) { +TEST(AttentionTest, AttentionBatch2RelativePositionBias) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -456,7 +456,7 @@ TEST(AttentionTest, AttentionBatch2ExtraAdd) { std::vector mask_index_data = {2L, 2L}; - std::vector extra_add_qk = { + std::vector relative_position_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f, 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; @@ -472,7 +472,7 @@ TEST(AttentionTest, AttentionBatch2ExtraAdd) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_qk); + 0, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, relative_position_bias); } TEST(AttentionTest, AttentionBatch1_Float16) { @@ -1709,7 +1709,7 @@ TEST(AttentionTest, AttentionWithNormFactor) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, {} /*qkv_sizes*/, - {} /*extra_add_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, + {} 
/*relative_position_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/); } diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index cf14a7f9918c6..5257fbdb08809 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -278,7 +278,7 @@ TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { test_qorder.AddInput("scale_values_gemm", {}, {attn_out_scale}, true); test_qorder.AddInput("mask_index", {batch_size, sequence_len}, input_mask.data(), input_mask.size()); test_qorder.AddOptionalInputEdge(); // past - test_qorder.AddOptionalInputEdge(); // extra_add + test_qorder.AddOptionalInputEdge(); // relative_position_bias test_qorder.AddOutput("output", {batch_size, sequence_len, hidden_size}, attn_out_q8.data(), attn_out_q8.size()); From 8fb1ca68b9affefb812253223e419ed05d092fc9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 31 Jan 2023 20:04:56 +0000 Subject: [PATCH 09/16] remove fusion script --- .../tools/transformers/onnx_model_tulr.py | 406 ------------------ .../python/tools/transformers/optimizer.py | 2 - 2 files changed, 408 deletions(-) delete mode 100644 onnxruntime/python/tools/transformers/onnx_model_tulr.py diff --git a/onnxruntime/python/tools/transformers/onnx_model_tulr.py b/onnxruntime/python/tools/transformers/onnx_model_tulr.py deleted file mode 100644 index a03693ebff232..0000000000000 --- a/onnxruntime/python/tools/transformers/onnx_model_tulr.py +++ /dev/null @@ -1,406 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -import logging -from typing import Union - -from fusion_attention import AttentionMask, FusionAttention -from fusion_utils import NumpyHelper -from onnx import NodeProto, TensorProto, helper, numpy_helper -from onnx_model import OnnxModel -from onnx_model_bert import BertOnnxModel -from fusion_base import Fusion -import numpy as np - -logger = logging.getLogger(__name__) - -#python optimizer.py --input /home/wy/Turing/tulr/model.onnx --output /home/wy/Turing/tulr/opt/model.onnx --model_type tulr --num_heads 16 --hidden_size 1024 --use_external_data_format - -class FusionTulrAttention(FusionAttention): - """ - Fuse TULR Attention subgraph into one Attention node. - """ - - def __init__( - self, - model: OnnxModel, - hidden_size: int, - num_heads: int, - attention_mask: AttentionMask, - ): - super().__init__(model, hidden_size, num_heads, attention_mask) - - def create_attention_node( - self, - mask_index: str, - q_matmul: NodeProto, - k_matmul: NodeProto, - v_matmul: NodeProto, - q_add: NodeProto, - k_add: NodeProto, - v_add: NodeProto, - num_heads: int, - hidden_size: int, - input: str, - output: str, - add_qk_str: str, - ) -> Union[NodeProto, None]: - """Create an Attention node. - - Args: - mask_index (str): mask input - q_matmul (NodeProto): MatMul node in fully connection for Q - k_matmul (NodeProto): MatMul node in fully connection for K - v_matmul (NodeProto): MatMul node in fully connection for V - q_add (NodeProto): Add bias node in fully connection for Q - k_add (NodeProto): Add bias node in fully connection for K - v_add (NodeProto): Add bias node in fully connection for V - num_heads (int): number of attention heads. 
If a model is pruned, it is the number of heads after pruning. - hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. - input (str): input name - output (str): output name - - Returns: - Union[NodeProto, None]: the node created or None if failed. - """ - assert num_heads > 0 - - if hidden_size > 0 and (hidden_size % num_heads) != 0: - logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") - return None - - q_weight = self.model.get_initializer(q_matmul.input[1]) - k_weight = self.model.get_initializer(k_matmul.input[1]) - v_weight = self.model.get_initializer(v_matmul.input[1]) - q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0]) - #k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0]) - v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0]) - - if q_weight is None: - print( - f"{q_matmul.input[1]} is not an initializer. " - "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion" - ) - return None - if not (k_weight and v_weight and q_bias): - return None - - qw = NumpyHelper.to_array(q_weight) - kw = NumpyHelper.to_array(k_weight) - vw = NumpyHelper.to_array(v_weight) - - # assert q and k have same shape as expected - assert qw.shape == kw.shape - - qw_in_size = qw.shape[0] - kw_in_size = kw.shape[0] - vw_in_size = vw.shape[0] - - assert qw_in_size == kw_in_size == vw_in_size - - if hidden_size > 0 and hidden_size != qw_in_size: - logger.warning( - f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). " - "Please provide a correct input hidden size or pass in 0" - ) - - is_qkv_diff_dims = False - if qw.shape != vw.shape: - is_qkv_diff_dims = True - - # All the matrices can have the same shape or q, k matrics can have the same shape with v being different - # For 2d weights, the shapes would be [in_size, out_size]. 
- # For 3d weights, shape would be [in_size, a, b] where a*b = out_size - qw_out_size = np.prod(qw.shape[1:]) - kw_out_size = np.prod(kw.shape[1:]) - vw_out_size = np.prod(vw.shape[1:]) - - qkv_weight_dim = 0 - if is_qkv_diff_dims: - qkv_weight = np.concatenate((qw, kw, vw), axis=1) - qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size - else: - qkv_weight = np.stack((qw, kw, vw), axis=1) - qkv_weight_dim = 3 * qw_out_size - - qb = NumpyHelper.to_array(q_bias) - #kb = NumpyHelper.to_array(k_bias) - kb = np.zeros_like(qb) - vb = NumpyHelper.to_array(v_bias) - - q_bias_shape = np.prod(qb.shape) - k_bias_shape = q_bias_shape - v_bias_shape = np.prod(vb.shape) - - assert q_bias_shape == k_bias_shape == qw_out_size - assert v_bias_shape == vw_out_size - - qkv_bias_dim = 0 - if is_qkv_diff_dims: - qkv_bias = np.concatenate((qb, kb, vb), axis=0) - qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape - else: - qkv_bias = np.stack((qb, kb, vb), axis=0) - qkv_bias_dim = 3 * q_bias_shape - - attention_node_name = self.model.create_node_name("Attention") - - use_multi_head_attention = True - if not use_multi_head_attention: - weight = helper.make_tensor( - name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, - dims=[qw_in_size, qkv_weight_dim], - vals=qkv_weight.flatten().tolist(), - ) - - # Sometimes weights and bias are stored in fp16 - if q_weight.data_type == 10: - weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) - self.model.add_initializer(weight, self.this_graph_name) - - bias = helper.make_tensor( - name=attention_node_name + "_qkv_bias", - data_type=TensorProto.FLOAT, - dims=[qkv_bias_dim], - vals=qkv_bias.flatten().tolist(), - ) - if q_bias.data_type == 10: - bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) - self.model.add_initializer(bias, self.this_graph_name) - - if use_multi_head_attention: - attention_inputs = [ - q_matmul.output[0], - k_matmul.output[0], - v_matmul.output[0], - attention_node_name + "_qkv_bias", - ] - if mask_index is not None: - attention_inputs.append(mask_index) - if add_qk_str is not None: - attention_inputs.append(add_qk_str) - - attention_node = helper.make_node( - "MultiHeadAttention", - inputs=attention_inputs, - outputs=[output], - name=attention_node_name, - ) - else: - attention_inputs = [ - input, - attention_node_name + "_qkv_weight", - attention_node_name + "_qkv_bias", - ] - if mask_index is not None: - attention_inputs.append(mask_index) - else: - attention_inputs.append("") - - if add_qk_str is not None: - attention_inputs.append("") # no past - attention_inputs.append(add_qk_str) - - attention_node = helper.make_node( - "Attention", - inputs=attention_inputs, - outputs=[output], - name=attention_node_name, - ) - attention_node.domain = "com.microsoft" - attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) - - if is_qkv_diff_dims: - attention_node.attribute.extend( - [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])] - ) - - if self.mask_filter_value is not None: - attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))]) - - return attention_node - - def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): - # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm - # Conceptually we treat add before layernorm as 
skiplayernorm node since they share the same pattern - start_node = normalize_node - if normalize_node.op_type != "SkipLayerNormalization": - return - - # SkipLayerNormalization has two inputs, and one of them is the root input for attention. - qkv_nodes = self.model.match_parent_path( - start_node, - ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], - ) - if qkv_nodes is not None: - (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes - else: - return - - other_inputs = [] - for i, input in enumerate(start_node.input): - if input not in output_name_to_node: - continue - - if input == qkv_nodes[0].output[0]: - continue - other_inputs.append(input) - if len(other_inputs) != 1: - return - - root_input = other_inputs[0] - - v_nodes = self.model.match_parent_path( - matmul_qkv, - ["Transpose", "Reshape", "Add", "MatMul"], - ) - if v_nodes is None: - return - (_, _, add_v, matmul_v) = v_nodes - - qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Add", "Div", "MatMul"]) - if qk_nodes is None: - return - (_, add_qk, _, _, matmul_qk) = qk_nodes - - q_nodes = self.model.match_parent_path( - matmul_qk, - ["Transpose", "Reshape", "Add", "MatMul"], - [0, 0, 0, 1], - ) - if q_nodes is None: - return - add_q = q_nodes[-2] - matmul_q = q_nodes[-1] - - k_nodes = self.model.match_parent_path( - matmul_qk, - ["Transpose", "Reshape", "MatMul"], - [1, 0, 0], - ) - if k_nodes is None: - return - - add_k = None - matmul_k = k_nodes[-1] - - relative_position_bias_nodes = self.model.match_parent_path(add_qk, ["Mul"]) - if relative_position_bias_nodes is None: - return - - mask_nodes = self.model.match_parent_path( - add_qk, - ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], - ) - if len(mask_nodes) > 1 and mask_nodes[1].op_type == "Mul": - _, mul_val = self.model.get_constant_input(mask_nodes[0]) - if mul_val != -10000: - self.mask_filter_value = mul_val - - if matmul_q.input[0] == root_input: - mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) - # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads - # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately - new_node = self.create_attention_node( - mask_index, - matmul_q, - matmul_k, - matmul_v, - add_q, - add_k, - add_v, - self.num_heads, - self.hidden_size, - root_input, - reshape_qkv.output[0], - add_qk.input[1], - ) - if new_node is None: - return - self.nodes_to_add.append(new_node) - self.node_name_to_graph_name[new_node.name] = self.this_graph_name - - self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv]) - self.nodes_to_remove.extend(qk_nodes) - self.nodes_to_remove.extend(k_nodes[:-1]) - self.nodes_to_remove.extend(v_nodes[:-1]) - - # Use prune graph to remove mask nodes since they are shared by all attention nodes. - self.nodes_to_remove.extend(mask_nodes) - self.prune_graph = True - -# Attr("max_distance", "Max distance", AttributeProto::INT) -# Attr("is_bidirectional", "Default value is 0.", AttributeProto::INT, static_cast(0)) -# Input(0, "bias_table", "2D input tensor with shape (num_buckets, num_heads), COL-major(See UT for example)", "T") -# Input(1, "query_length", "The length of query. 
Self Attention requires query_length = key_length", "U") -# Input(2, "key_length", "The length of key.", "U") -# Output(0, "output", "4D output tensor with shape (1, num_heads, sequence_length, sequence_length)", "T") -class FusionRelativePositionBiasBlock(Fusion): - def __init__(self, model: OnnxModel, max_distance: int, is_bidirectional: bool): - super().__init__(model, "RelativePositionBias", "GatherElements") - self.max_distance = 128 - self.is_bidirectional = 1 - - def fuse(self, node, input_name_to_nodes, output_name_to_node): - stem_nodes = self.model.match_parent_path( - node, - ["Expand", "Where", "Equal", "Concat", "Unsqueeze", "Gather", "Shape", "Sub", "Unsqueeze", "Expand", "Unsqueeze", "Range"], - ) - if stem_nodes is None: - return - - expand = stem_nodes[0] - range = stem_nodes[-1] - rpb_nodes = self.model.match_parent_path( - expand, - ["Unsqueeze", "Unsqueeze", "Gemm"] - ) - if rpb_nodes is None: - return - - gemm = rpb_nodes[-1] - - self.nodes_to_remove.extend(stem_nodes) - self.nodes_to_remove.extend(rpb_nodes) - - table_weight = self.model.get_initializer(gemm.input[0]) - table_weight_np = NumpyHelper.to_array(table_weight) - bias_table = helper.make_tensor( - name="bias_table_weight", - data_type=TensorProto.FLOAT, - dims=[np.shape(table_weight_np)[1], np.shape(table_weight_np)[0]], - vals=table_weight_np.flatten().tolist(), - ) - self.model.add_initializer(bias_table, self.this_graph_name) - inputs = [bias_table.name, range.input[1], range.input[1]] - outputs = [node.output[0]] - rpb_node = helper.make_node( - "RelativePositionBias", - inputs=inputs, - outputs=outputs, - name=self.model.create_node_name("RelativePositionBias", name_prefix="RPB"), - ) - rpb_node.domain = "com.microsoft" - rpb_node.attribute.extend([helper.make_attribute("max_distance", self.max_distance)]) - rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", self.is_bidirectional)]) - - self.nodes_to_add.append(rpb_node) - self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name - - -class TulrOnnxModel(BertOnnxModel): - def __init__(self, model, num_heads, hidden_size): - super().__init__(model, num_heads, hidden_size) - self.attention_mask = AttentionMask(self) - self.attention_fusion = FusionTulrAttention(self, self.hidden_size, self.num_heads, self.attention_mask) - self.rpb_fusion = FusionRelativePositionBiasBlock(self, 32, True) - - def fuse_attention(self): - self.attention_fusion.apply() - - def postprocess(self): - self.rpb_fusion.apply() - self.clean_graph() - self.prune_graph() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 3c3df6d251971..56076eedda78a 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -31,7 +31,6 @@ from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_tnlr import TnlrOnnxModel -from onnx_model_tulr import TulrOnnxModel from onnx_model_unet import UnetOnnxModel logger = logging.getLogger(__name__) @@ -49,7 +48,6 @@ 0, ), # might add a class for GPT2OnnxModel for TF later. 
"tnlr": (TnlrOnnxModel, "pytorch", 1), - "tulr": (TulrOnnxModel, "pytorch", 1), "unet": (UnetOnnxModel, "pytorch", 1), } From 0d44a7b394801fe1294dfbf71d685aceb852350d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 31 Jan 2023 23:40:52 +0000 Subject: [PATCH 10/16] fix docs --- docs/ContribOperators.md | 184 ++++++++++++++++++++------------------- docs/OperatorKernels.md | 2 +- 2 files changed, 95 insertions(+), 91 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index a8d2079376ef4..241ff7d718fd0 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -94,24 +94,24 @@ Do not modify directly.* ### **com.microsoft.Attention** Multi-Head Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). - + The weights for input projection of Q, K and V are merged. The data is stacked on the second dimension. Its shape is (input_hidden_size, hidden_size + hidden_size + v_hidden_size). Here hidden_size is the hidden dimension of Q and K, and v_hidden_size is that of V. - + The mask_index is optional. Besides raw attention mask with shape (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) with value 0 for masked and 1 otherwise, we support other two formats: When input has right-side padding, mask_index is one dimension with shape (batch_size), where value is actual sequence length excluding padding. When input has left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. - + When unidirectional is 1, each token only attends to previous tokens. - + Both past and present state are optional. They shall be used together, and not allowed to use only one of them. The qkv_hidden_sizes is required only when K and V have different hidden sizes. - + When there is past state, hidden dimension for Q, K and V shall be the same. - + The total_sequence_length is past_sequence_length + kv_sequence_length. Here kv_sequence_length is the length of K or V. For self attention, kv_sequence_length equals to sequence_length (sequence length of Q). For cross attention, query and key might have different lengths. @@ -179,133 +179,133 @@ This version of the operator has been available since version 1 of the 'com.micr Computes an one-layer RNN where its RNN Cell is an AttentionWrapper wrapped a LSTM Cell. The RNN layer contains following basic component: LSTM Cell, Bahdanau Attention Mechanism, AttentionWrapp. - + Activation functions: - + Relu(x) - max(0, x) - + Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) - + Sigmoid(x) - 1/(1 + e^{-x}) - + (NOTE: Below are optional) - + Affine(x) - alpha*x + beta - + LeakyRelu(x) - x if x >= 0 else alpha * x - + ThresholdedRelu(x) - x if x >= alpha else 0 - + ScaledTanh(x) - alpha*Tanh(beta*x) - + HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) - + Elu(x) - x if x >= 0 else alpha*(e^x - 1) - + Softsign(x) - x/(1 + |x|) - + Softplus(x) - log(1 + e^x) - + Softmax(x) - exp(x) / sum(exp(x)) - + Bahdanau Attention Mechanism: `M` - Memory tensor. - + `VALUES` - masked Memory by its real sequence length. - + `MW` - Memory layer weight. - + `KEYS` - Processed memory tensor by the memory layer. KEYS = M * MW - + `Query` - Query tensor, normally at specific time step in sequence. 
- + `QW` - Query layer weight in the attention mechanism - + `PQ` - processed query, = `Query` * `QW` - + `V' - attention vector - + `ALIGN` - calculated alignment based on Query and KEYS ALIGN = softmax(reduce_sum(`V` * Tanh(`KEYS` + `PQ`))) - + `CONTEXT` - context based on `ALIGN` and `VALUES` CONTEXT = `ALIGN` * `VALUES` - - + + LSTM Cell: `X` - input tensor concat with attention state in the attention wrapper - + `i` - input gate - + `o` - output gate - + `f` - forget gate - + `c` - cell gate - + `t` - time step (t-1 means previous time step) - + `W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates - + `R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates - + `Wb[iofc]` - W bias vectors for input, output, forget, and cell gates - + `Rb[iofc]` - R bias vectors for input, output, forget, and cell gates - + `P[iof]` - P peephole weight vector for input, output, and forget gates - + `WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates - + `RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates - + `WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates - + `RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates - + `PB[iof]` - P peephole weight vector for backward input, output, and forget gates - + `H` - Hidden state - + `num_directions` - 2 if direction == bidirectional else 1 - + Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): - + - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - + - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - + - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - + - Ct = ft (.) Ct-1 + it (.) ct - + - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - + - Ht = ot (.) h(Ct) - - + + AttentionWrapp Notations: `lstm()' - wrapped inner cell. Ht, Ct = lstm(concat(Xt, ATTNt-1), Ct-1) - + `am()` - attention mechanism the wrapper used. CONTEXTt, ALIGNt = am(Ht, ALIGNt-1) - + `AW` - attention layer weights, optional. - + `ATTN` - attention state, initial is zero. If `AW` provided, it is the output of the attention layer, ATTNt = concat(Ht, CONTEXTt) * AW otherwise, ATTNt = CONTEXTt - + RNN layer output: `Y` - if needed is the sequence of Ht from lstm cell. - + `Y_h` - is the last valid H from lstm cell. - + `Y_c` - is the last valid C from lstm cell. - + #### Version @@ -519,7 +519,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.BiasGelu** Bias Gelu. - It's an extension of Gelu. It takes the sum of input A and bias input B as the input of Gelu activation. + It's an extension of Gelu. It takes the sum of input A and bias input B as the input of Gelu activation. #### Version @@ -711,7 +711,7 @@ This version of the operator has been available since version 1 of the 'com.micr ``` scale = 1. / (1. - ratio). ``` - + This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt that the mask is output as a bit-packed uint32 tensor, instead of a boolean tensor. #### Version @@ -1768,7 +1768,7 @@ This version of the operator has been available since version 1 of the 'com.micr which are used to interpolate the output value `output[n, :, h, w]`. The GridSample operator is often used in doing grid generator and sampler in the [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). 
See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/master/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample). - + #### Version @@ -1883,10 +1883,10 @@ This version of the operator has been available since version 1 of the 'com.micr Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token attends to its W previous tokens and W succeeding tokens with W being the window length. A selected few tokens attend globally to all other tokens. - + The attention mask is of shape (batch_size, sequence_length), where sequence_length is a multiple of 2W after padding. Mask value < 0 (like -10000.0) means the token is masked, 0 otherwise. - + Global attention flags have value 1 for the tokens attend globally and 0 otherwise. #### Version @@ -2071,11 +2071,11 @@ This version of the operator has been available since version 1 of the 'com.micr Performs element-wise binary quantized multiplication (with Numpy-style broadcasting support). "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**" The output of this op is the int32 accumulated result of the mul operation - + ``` C (int32) = (A - A_zero_point) * (B - B_zero_point) ``` - + #### Version @@ -2114,7 +2114,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.MultiHeadAttention** Multi-Head Self/Cross Attention. Bias from input projection is included. - + The key padding mask is optional. When its shape is (batch_size, kv_sequence_length), value 0 means padding or 1 otherwise. When key has right-side padding, its shape could be (batch_size): it is actual length of each key sequence excluding paddings. @@ -2132,7 +2132,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (4 - 5) +#### Inputs (4 - 6)
query : T
@@ -2145,6 +2145,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)
+
relative_position_bias (optional) : T
+
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
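For orientation, this input is added to the raw QxK' attention scores before the softmax. The fragment below is a schematic PyTorch sketch of that step only (self-attention case, so total_sequence_length equals sequence_length); the 1/sqrt(head_size) scaling and the omission of the key padding mask are simplifying assumptions for illustration, not part of this specification.

```python
# Schematic sketch: where relative_position_bias enters the attention computation.
# Assumes the conventional 1/sqrt(head_size) scaling; key_padding_mask is omitted.
import math
import torch

batch_size, num_heads, seq_len, head_size = 1, 2, 3, 4
q = torch.rand(batch_size, num_heads, seq_len, head_size)
k = torch.rand(batch_size, num_heads, seq_len, head_size)
v = torch.rand(batch_size, num_heads, seq_len, head_size)
rel_pos_bias = torch.rand(1, num_heads, seq_len, seq_len)  # broadcasts over the batch dimension

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)  # QxK'
probs = torch.softmax(scores + rel_pos_bias, dim=-1)                  # bias added before softmax
context = torch.matmul(probs, v)                                      # (batch_size, num_heads, seq_len, head_size)
```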
#### Outputs @@ -2358,7 +2360,7 @@ This version of the operator has been available since version 1 of the 'com.micr [0.0, 0.0, 4.5, 5.7], ], ] - + #### Version @@ -2536,7 +2538,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.QLinearAdd** Performs element-wise binary addition on 8 bit data types (with Numpy-style broadcasting support). - + C = (A_scale * (A - A_zero_point) + B_scale * (B - B_zero_point))/C_scale + C_zero_point #### Version @@ -2594,11 +2596,11 @@ This version of the operator has been available since version 1 of the 'com.micr output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1) ``` if ceil_mode is enabled - + ``` * pad_shape[i] is sum of pads along axis i ``` - + `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: ``` VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - kernel_spatial_shape[i] + 1) / strides_spatial_shape[i]) @@ -2608,9 +2610,9 @@ This version of the operator has been available since version 1 of the 'com.micr ``` pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i] ``` - + The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). - + Input and output scales and zero points are used to convert the output to a new quantization range. Output = Dequantize(Input) -> AveragePool on fp32 data -> Quantize(output) @@ -2878,7 +2880,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.QLinearMul** Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting support). - + C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point #### Version @@ -2929,10 +2931,10 @@ This version of the operator has been available since version 1 of the 'com.micr with the exception that numpy default keepdims to False instead of True. Input and Output scales and zero points are used to requantize the output in a new range. This helps to improve accuracy as after ReduceMean operation the range of the output is expected to decrease. - + ``` "Output = Dequantize(Input) -> ReduceMean on fp32 data -> Quantize(output)", - + ``` #### Version @@ -2982,7 +2984,7 @@ This version of the operator has been available since version 1 of the 'com.micr QLinearSigmoid takes quantized input data (Tensor), and quantize parameter for output, and produces one output data (Tensor) where the function `f(x) = quantize(Sigmoid(dequantize(x)))`, is applied to the data tensor elementwise. - Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` + Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` #### Version @@ -3767,10 +3769,10 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.RemovePadding** Compress transformer input by removing paddings. It assumes padding is on the right side of sequence. - + The input has padding with shape (batch_size, sequence_length, hidden_size). This will generate two outputs: output has shape (total_tokens, hidden_size); token_offset with shape (batch_size, sequence_length). - + token_offset has offsets of all non-padding tokens first, then offset of all padding tokens. 
It is a list of batch_size * sequence_length elements, which is reshaped to 2D for convenience of shape inference. @@ -3813,7 +3815,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.RestorePadding** Restore paddings and fill padding with zeros. - + The input has padding with shape (total_tokens, hidden_size) and token_offset with shape (batch_size, sequence_length). The output has shape (batch_size, sequence_length, hidden_size). @@ -4269,7 +4271,7 @@ This version of the operator has been available since version 1 of the 'com.micr Based on Torch operator Embedding, creates a lookup table of embedding vectors of fixed size, for a dictionary of fixed size. - + #### Version @@ -4359,7 +4361,7 @@ This version of the operator has been available since version 1 of the 'com.micr the main diagonal. A negative k value includes as many diagonals below the main diagonal. If upper is set to false, a positive k retains the lower triangular matrix including k diagonals above the main diagonal. A negative k value excludes as many diagonals below the main diagonal. - + #### Version @@ -4410,7 +4412,7 @@ This version of the operator has been available since version 1 of the 'com.micr output_uniques = [2, 1, 3, 4] output_idx = [0, 1, 1, 2, 3, 2] output_counts = [1, 2, 2, 1] - + #### Version @@ -4609,3 +4611,5 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
+ + diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ada495b32842d..d1a83213e8eeb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -808,7 +808,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| From da875fa2474a2e81d413a60bbf3270f1ca358a00 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Thu, 2 Feb 2023 19:02:38 -0800 Subject: [PATCH 11/16] Zhalei/gated rel position bias (#14553) ### Description ### Motivation and Context --------- Co-authored-by: Lei Zhang --- .../cuda/bert/relative_attn_bias.cc | 121 +++++++- .../cuda/bert/relative_attn_bias.h | 12 + .../cuda/bert/relative_attn_bias_impl.cu | 116 ++++++++ .../cuda/bert/relative_attn_bias_impl.h | 15 + .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 + .../core/graph/contrib_ops/bert_defs.cc | 34 +++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../relative_attention_bias_test.cc | 266 +++++++++++++++++- 8 files changed, 566 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index af13efe0e2fbc..c08bcf35d076a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -3,7 +3,15 @@ #include "core/providers/cuda/cuda_common.h" #include "relative_attn_bias.h" +#include "core/common/safeint.h" #include "relative_attn_bias_impl.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + namespace onnxruntime { namespace contrib { @@ -20,7 +28,16 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 1) \ .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - RelPosAttnBias); + RelPosAttnBias); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GatedRelativePositionBias, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + GatedRelativePositionBias); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) @@ -69,6 +86,108 @@ Status RelPosAttnBias::ComputeInternal(OpKernelContext* context) const { device_prop.maxThreadsPerBlock); } +template +GatedRelativePositionBias::GatedRelativePositionBias(const OpKernelInfo& info) : CudaKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = SafeInt(num_heads); +} + +template +Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) const { + const Tensor& query_tensor = *context->Input(0); + const Tensor& query_bias_tensor = *context->Input(1); + const Tensor& rel_pos_tensor = *context->Input(2); + const Tensor& weight_tensor = *context->Input(3); + const Tensor& bias_tensor = *context->Input(4); + const Tensor& eco_a_tensor = *context->Input(5); + + const auto& query_dims = query_tensor.Shape().GetDims(); + ORT_ENFORCE(query_dims.size() == 3); + ORT_ENFORCE(query_dims[2] > 0); + ORT_ENFORCE(query_dims[2] % num_heads_ == 0); + const auto batch_size = SafeInt(query_dims[0]); + const auto seq_len = SafeInt(query_dims[1]); + const auto head_size = SafeInt(query_dims[2] / num_heads_); + + ORT_ENFORCE(query_bias_tensor.Shape().NumDimensions() == 1); + ORT_ENFORCE(query_bias_tensor.Shape()[0] == query_dims[2]); + + const auto& rel_pos_dims = rel_pos_tensor.Shape().GetDims(); + ORT_ENFORCE(rel_pos_dims.size() == 4); + ORT_ENFORCE(rel_pos_dims[0] == 1); + ORT_ENFORCE(rel_pos_dims[1] == num_heads_); + ORT_ENFORCE(rel_pos_dims[2] == seq_len); + 
ORT_ENFORCE(rel_pos_dims[3] == seq_len); + + const auto& weight_dims = weight_tensor.Shape().GetDims(); + ORT_ENFORCE(weight_dims.size() == 2); + ORT_ENFORCE(weight_dims[0] == head_size); + ORT_ENFORCE((weight_dims[1] > 0) && (weight_dims[1] % 2 == 0)); + + ORT_ENFORCE(bias_tensor.Shape().NumDimensions() == 1); + ORT_ENFORCE(bias_tensor.Shape()[0] == weight_dims[1]); + + const auto D = SafeInt(weight_dims[1]); + + const auto& eco_a_dims = eco_a_tensor.Shape().GetDims(); + ORT_ENFORCE(eco_a_dims.size() == 4); + ORT_ENFORCE(eco_a_dims[0] == 1); + ORT_ENFORCE(eco_a_dims[1] == num_heads_); + ORT_ENFORCE(eco_a_dims[2] == 1); + ORT_ENFORCE(eco_a_dims[3] == 1); + + Tensor* output = context->Output(0, {batch_size, num_heads_, seq_len, seq_len}); + + auto& device_prop = GetDeviceProp(); + cublasHandle_t cublas = GetCublasHandle(context); + + typedef typename ToCudaType::MappedType CudaT; + const auto BNS = batch_size * num_heads_ * seq_len; + const size_t elements_in_query = (size_t)BNS * (size_t)head_size; + const size_t elements_after_gemm = (size_t)BNS *(size_t)D; + size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0); + auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream()); + + // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH) + constexpr int format = 1; + constexpr int total_maxtrix = 1; + constexpr int num_matrix_to_transpose = 1; + LaunchAddBiasTranspose(Stream(context), num_matrix_to_transpose, format, device_prop.maxThreadsPerBlock, + batch_size, seq_len, num_heads_, head_size, + reinterpret_cast(query_tensor.template Data()), + reinterpret_cast(query_bias_tensor.template Data()), + reinterpret_cast(workspace.get()), + false, head_size, reinterpret_cast(nullptr), total_maxtrix); + + // reuse output if possible + CudaT* gemm_output = (seq_len < D) ? 
(reinterpret_cast(workspace.get()) + elements_in_query) + : reinterpret_cast(output->template MutableData()); + int ld_gemm_output = max(seq_len, D); + + const CudaT one = ToCudaType::FromFloat(1.0f); + const CudaT zero = ToCudaType::FromFloat(0.0f); + + // ([b*n*s, h] * [h, D]), CUDA assumes col-major + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + D, BNS, head_size, &one, + reinterpret_cast(weight_tensor.template Data()), (int)D, + reinterpret_cast(workspace.get()), (int)head_size, + &zero, gemm_output, ld_gemm_output, device_prop)); + + auto status = LaunchGatedRelativePositionBiasKernel( + device_prop, Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(rel_pos_tensor.template Data()), + reinterpret_cast(gemm_output), + reinterpret_cast(bias_tensor.template Data()), + reinterpret_cast(eco_a_tensor.template Data()), + batch_size, num_heads_, seq_len, D, ld_gemm_output); + + return status; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h index b9674f6f35091..3bf4e730e29f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.h @@ -22,6 +22,18 @@ class RelPosAttnBias final : public CudaKernel { bool is_bidirectional_; }; +template +class GatedRelativePositionBias final : public CudaKernel { + public: + GatedRelativePositionBias(const OpKernelInfo& op_kernel_info); + + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int num_heads_; +}; + + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index 308ef770fd857..938496b058025 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -149,6 +149,122 @@ template Status LaunchRelPosAttnBiasKernel(cudaStream_t stream, const bool is_bidirectional, const int max_threads_per_block); +template +__global__ void GatedRelativePositionBiasKernelSmallD( + T* output, // (batch_size, num_heads, seq_len, seq_len) + const T* rel_pos, // (1, num_heads, seq_len, seq_len) + const T* qw, // (batch_size, num_heads, seq_len, D) + const T* bias, // (D) + const T* eco_a, // (1, num_heads, 1, 1) + const int D, + const int ldqw) { + __shared__ float gate[1]; + + const int seq_len = gridDim.x; + const int num_heads = gridDim.y; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + rel_pos += ((int64_t)n * seq_len + s) * seq_len; + output += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * seq_len; + qw += ((int64_t)b * num_heads * seq_len + (int64_t)n * seq_len + s) * ldqw; + + float val = 0.0f; + if (threadIdx.x < D) { + val = (float)qw[threadIdx.x] + (bias ? (float)bias[threadIdx.x] : 0.0f); + } + + float u = (threadIdx.x < D / 2) ? val : 0.0f; +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + u += __shfl_down_sync(0xffffffff, u, offset); + } + + float r = (threadIdx.x >= D / 2) ? 
val : 0.0f; +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + r += __shfl_down_sync(0xffffffff, r, offset); + } + + if (threadIdx.x == 0) { + u = 1.0f / (1.0f + expf(-u)); + r = 1.0f / (1.0f + expf(-r)); + gate[0] = u * (r * (float)eco_a[n] - 1.0f) + 2.0f; + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < seq_len; idx += blockDim.x) { + output[idx] = (T)(gate[0] * (float)rel_pos[idx]); + } +} + +template +Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + T* output, + const T* rel_pos, + const T* qw, // query * weight + const T* bias, + const T* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw) { + ORT_ENFORCE(D <= 32 && D > 0 && (D % 2 == 0)); + ORT_ENFORCE(ldqw == seq_len || ldqw == D); + + int tpb = std::max(32, std::max(D, seq_len)); + tpb = std::min(tpb, device_prop.maxThreadsPerBlock); + + // round up tpb to power of 2 + --tpb; + tpb |= (tpb >> 1); + tpb |= (tpb >> 2); + tpb |= (tpb >> 4); + tpb |= (tpb >> 8); + tpb |= (tpb >> 16); + tpb++; + + dim3 block(tpb); + dim3 grid(seq_len, num_heads, batch_size); + + GatedRelativePositionBiasKernelSmallD<<>>( + output, rel_pos, qw, bias, eco_a, D, ldqw); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + float* output, + const float* rel_pos, + const float* qw, + const float* bias, + const float* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + +template Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + half* output, + const half* rel_pos, + const half* qw, + const half* bias, + const half* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h index 5a1a229ab6077..5c7c98f55f3f5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.h @@ -22,6 +22,21 @@ Status LaunchRelPosAttnBiasKernel( const int max_threads_per_block ); +template +Status LaunchGatedRelativePositionBiasKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + T* output, + const T* rel_pos, + const T* qw, // from query * weight + const T* bias, + const T* eco_a, + const int batch_size, + const int num_heads, + const int seq_len, + const int D, + const int ldqw); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 38bcbc298b939..e9c2b4ba89dfb 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -30,6 +30,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding); @@ -155,6 +157,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 104c706c420cd..bdab663a9954f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -663,5 +663,39 @@ ONNX_MS_OPERATOR_SET_SCHEMA( RestorePaddingTypeAndShapeInference(ctx); })); +constexpr const char* GatedRelativePositionBias_ver1_doc = R"DOC( + query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) + gate_u, gate_r = torch.sigmoid( + self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0 + rel_pos_bias = gate_u_1 * rel_pos +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + GatedRelativePositionBias, 1, + OpSchema() + .SetDoc(GatedRelativePositionBias_ver1_doc) + .Attr("num_heads", "Number of attention heads", AttributeProto::INT) + .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T") + .Input(1, "query_bias", "1-d tensor with shape (num_heads x head_size)", "T") + .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T") + .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T") + .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T") + .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T") + .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + int64_t num_heads = getAttribute(ctx, "num_heads", -1L); + auto& query_layer_shape = getInputShape(ctx, 0); + TensorShapeProto output_shape; + *output_shape.add_dim() = query_layer_shape.dim(0); + output_shape.add_dim()->set_dim_value(num_heads); + *output_shape.add_dim() = query_layer_shape.dim(1); + *output_shape.add_dim() = query_layer_shape.dim(1); + updateOutputShape(ctx, 0, output_shape); + })); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 1f0af31a4bdd0..a272af1e6998e 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -79,6 +79,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class 
ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RelativePositionBias); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBias); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); @@ -167,6 +168,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc index 7722291bee653..646e18605f118 100644 --- a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc +++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc @@ -10,9 +10,9 @@ namespace onnxruntime { namespace test { static void RunRelativePositionBiasTest( - const std::vector& bias_table, // Shape = [num_buckets, num_heads] - const std::vector& sequence_length, // Shape = [1] - const std::vector& output_data, // Shape = [1, num_heads, sequence_length, sequence_length] + const std::vector& bias_table, // Shape = [num_buckets, num_heads] + const std::vector& sequence_length, // Shape = [1] + const std::vector& output_data, // Shape = [1, num_heads, sequence_length, sequence_length] int max_distance, int num_buckets, int num_heads, @@ -155,5 +155,265 @@ TEST(RelativePositionBiasTest, RelativePositionBiasTest_FP16_No_Bidirectional) { true); } +/***************Following scripts is used to generate test data, for your reference************* +import torch + +batch_size = 2 +num_heads = 2 +seq_len = 3 +head_size = 4 +D = 8 + +def dim_string_of(tensor): + return "{" + ", ".join([str(d) for d in tensor.shape]) + "}" + +def value_string_of(tensor): + arr = tensor.flatten().numpy() + lines = ["f, ".join([str(v) for v in arr[i : min(i+8, arr.size)]]) for i in range(0, arr.size, 8)] + return "{\n " + "f,\n ".join(lines) + "f}" + +def print_tensor(name, tensor): + print(f"const std::vector {name}_dim = {dim_string_of(tensor)};") + print(f"const std::vector {name} = {value_string_of(tensor)};") + +torch.manual_seed(0) +query_layer = torch.rand(batch_size, seq_len, num_heads * head_size) +query_bias = torch.rand(num_heads * head_size) +rel_pos = torch.rand(1, num_heads, seq_len, seq_len) +weight = torch.rand(head_size, D) +bias = torch.rand(D) +eco_a = torch.rand(1, num_heads, 1, 1) + +qw = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) +gate_u,gate_r = torch.sigmoid( + (torch.matmul(qw, weight) + bias).view(batch_size, num_heads, seq_len,2, D//2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) +gate_u_1 = gate_u * (gate_r * eco_a - 1.0) + 2.0 +output = gate_u_1 * rel_pos + +# output for test case +print(f"const int batch_size = {batch_size};") +print(f"const int num_heads = {num_heads};") +print(f"const int seq_len = {seq_len};") +print(f"const int head_size = {head_size};") +print(f"const int D = {D};") + +print_tensor("query_layer", query_layer) +print_tensor("query_bias", query_bias) +print_tensor("rel_pos", rel_pos) +print_tensor("weight", weight) +print_tensor("bias", bias) +print_tensor("eco_a", eco_a) +print_tensor("output", output) +****************/ + +// .Input(0, "query_layer", "tensor with shape (batch_size, seq_len, num_heads x head_size)", "T") +// .Input(1, "query_bias", "1-d tensor with shape 
(num_heads x head_size)", "T") +// .Input(2, "rel_pos", "tensor with shape (1, num_head, seq_len, seq_len)", "T") +// .Input(3, "weight", "gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2", "T") +// .Input(4, "bias", "bias for the gated_ur_linear, shape (D)", "T") +// .Input(5, "eco_a", "tensor of shape (1, num_heads, 1, 1)", "T") +// .Output(0, "output", "output tensor with shape (batch_size, num_heads, seq_len, seq_len)", "T") +static void RunGatedRelativePositionBiasTest( + const std::vector& query_layer, + const std::vector& query_bias, + const std::vector& rel_pos, + const std::vector& weight, + const std::vector& bias, + const std::vector& eco_a, + const std::vector& output, + int batch_size, + int seq_len, + int num_heads, + int head_size, + int D, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 530 : 0; + + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = false; + if (enable_cuda) { + OpTester tester("GatedRelativePositionBias", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + + std::vector query_layer_dims = {batch_size, seq_len, num_heads * head_size}; + std::vector query_bias_dims = {num_heads * head_size}; + std::vector rel_pos_dims = {1, num_heads, seq_len, seq_len}; + std::vector weight_dims = {head_size, D}; + std::vector bias_dims = {D}; + std::vector eco_a_dims = {1, num_heads, 1, 1}; + std::vector output_dims = {batch_size, num_heads, seq_len, seq_len}; + + if (use_float16) { + tester.AddInput("query_layer", query_layer_dims, ToFloat16(query_layer)); + tester.AddInput("query_bias", query_bias_dims, ToFloat16(query_bias)); + tester.AddInput("rel_pos", rel_pos_dims, ToFloat16(rel_pos)); + tester.AddInput("weight", weight_dims, ToFloat16(weight)); + tester.AddInput("bias", bias_dims, ToFloat16(bias)); + tester.AddInput("eco_a", eco_a_dims, ToFloat16(eco_a)); + tester.AddOutput("output", output_dims, ToFloat16(output)); + } else { + tester.AddInput("query_layer", query_layer_dims, query_layer); + tester.AddInput("query_bias", query_bias_dims, query_bias); + tester.AddInput("rel_pos", rel_pos_dims, rel_pos); + tester.AddInput("weight", weight_dims, weight); + tester.AddInput("bias", bias_dims, bias); + tester.AddInput("eco_a", eco_a_dims, eco_a); + tester.AddOutput("output", output_dims, output); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) { + const int batch_size = 1; + const int num_heads = 2; + const int seq_len = 3; + const int head_size = 4; + const int D = 8; + const std::vector query_layer_dim = {1, 3, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f}; + const std::vector rel_pos_dim = {1, 2, 3, 3}; + const std::vector rel_pos = { + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 
0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f, + 0.27825248f, 0.4819588f}; + const std::vector weight_dim = {4, 8}; + const std::vector weight = { + 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f, 0.593172f, 0.112347245f, + 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, 0.774486f, 0.43689132f, + 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, 0.69650495f, 0.9142747f, + 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f, 0.034022927f, 0.94424623f}; + const std::vector bias_dim = {8}; + const std::vector bias = { + 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f, 0.6922781f, 0.20384824f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.68329567f, 0.75285405f}; + const std::vector output_dim = {1, 2, 3, 3}; + const std::vector output = { + 0.29608122f, 0.45416728f, 0.25361493f, 0.053390637f, 0.3503264f, 1.5650483f, 1.2171557f, 1.2495192f, + 0.88587445f, 0.42708054f, 1.0246648f, 0.05810945f, 0.2430356f, 0.4244021f, 1.428723f, 1.3902748f, + 0.48772895f, 0.84479123f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, true); +} + +TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) { + const int batch_size = 2; + const int num_heads = 2; + const int seq_len = 3; + const int head_size = 4; + const int D = 8; + const std::vector query_layer_dim = {2, 3, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f, + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f, + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.27825248f, 0.4819588f, 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f}; + const std::vector rel_pos_dim = {1, 2, 3, 3}; + const std::vector rel_pos = { + 0.593172f, 0.112347245f, 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, + 0.774486f, 0.43689132f, 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, + 0.69650495f, 0.9142747f}; + const std::vector weight_dim = {4, 8}; + const std::vector weight = { + 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f, 0.034022927f, 0.94424623f, + 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f, 0.6922781f, 0.20384824f, + 0.68329567f, 0.75285405f, 0.8579358f, 0.6869556f, 0.005132377f, 0.17565155f, 0.7496575f, 0.6046507f, + 0.10995799f, 0.21209025f, 0.97037464f, 0.83690894f, 0.28198743f, 0.3741576f, 0.023700953f, 0.49101293f}; + const std::vector bias_dim = {8}; + const std::vector bias = { + 0.123470545f, 0.11432165f, 0.4724502f, 0.5750725f, 0.29523486f, 0.7966888f, 0.19573045f, 0.95368505f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.84264994f, 0.07835853f}; + const std::vector output_dim = {2, 2, 3, 3}; + const 
std::vector output = { + 1.0928818f, 0.20699267f, 0.28273466f, 0.44534987f, 1.3380982f, 1.2917475f, 0.3755537f, 1.1995932f, + 1.4270226f, 0.47112367f, 0.5597638f, 0.6641071f, 0.87368786f, 1.0569134f, 0.12367705f, 0.34158573f, + 0.75108063f, 0.98591405f, 1.0929474f, 0.2070051f, 0.28275162f, 0.4451845f, 1.3376014f, 1.2912678f, + 0.37552574f, 1.1995038f, 1.4269164f, 0.47112313f, 0.5597632f, 0.6641063f, 0.87367094f, 1.056893f, + 0.12367466f, 0.34158388f, 0.7510766f, 0.98590875f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, false); +} + +TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) { + const int batch_size = 2; + const int num_heads = 2; + const int seq_len = 5; + const int head_size = 4; + const int D = 4; + const std::vector query_layer_dim = {2, 5, 8}; + const std::vector query_layer = { + 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, + 0.45562798f, 0.6323063f, 0.34889346f, 0.4017173f, 0.022325754f, 0.16885895f, 0.29388845f, 0.5185218f, + 0.6976676f, 0.8000114f, 0.16102946f, 0.28226858f, 0.68160856f, 0.915194f, 0.3970999f, 0.8741559f, + 0.41940832f, 0.55290705f, 0.9527381f, 0.03616482f, 0.18523103f, 0.37341738f, 0.30510002f, 0.9320004f, + 0.17591017f, 0.26983356f, 0.15067977f, 0.031719506f, 0.20812976f, 0.929799f, 0.7231092f, 0.7423363f, + 0.5262958f, 0.24365824f, 0.58459234f, 0.03315264f, 0.13871688f, 0.242235f, 0.81546897f, 0.7931606f, + 0.27825248f, 0.4819588f, 0.81978035f, 0.99706656f, 0.6984411f, 0.5675464f, 0.83524317f, 0.20559883f, + 0.593172f, 0.112347245f, 0.15345693f, 0.24170822f, 0.7262365f, 0.7010802f, 0.20382375f, 0.65105355f, + 0.774486f, 0.43689132f, 0.5190908f, 0.61585236f, 0.8101883f, 0.98009706f, 0.11468822f, 0.31676513f, + 0.69650495f, 0.9142747f, 0.93510365f, 0.9411784f, 0.5995073f, 0.06520867f, 0.54599625f, 0.18719733f}; + const std::vector query_bias_dim = {8}; + const std::vector query_bias = { + 0.034022927f, 0.94424623f, 0.8801799f, 0.0012360215f, 0.593586f, 0.41577f, 0.41771942f, 0.27112156f}; + const std::vector rel_pos_dim = {1, 2, 5, 5}; + const std::vector rel_pos = { + 0.6922781f, 0.20384824f, 0.68329567f, 0.75285405f, 0.8579358f, 0.6869556f, 0.005132377f, 0.17565155f, + 0.7496575f, 0.6046507f, 0.10995799f, 0.21209025f, 0.97037464f, 0.83690894f, 0.28198743f, 0.3741576f, + 0.023700953f, 0.49101293f, 0.123470545f, 0.11432165f, 0.4724502f, 0.5750725f, 0.29523486f, 0.7966888f, + 0.19573045f, 0.95368505f, 0.84264994f, 0.07835853f, 0.37555784f, 0.5225613f, 0.57295054f, 0.61858714f, + 0.69621414f, 0.5299501f, 0.25603563f, 0.7365945f, 0.02037555f, 0.20364666f, 0.37483507f, 0.25644332f, + 0.32508332f, 0.09018916f, 0.39364243f, 0.6068782f, 0.17426711f, 0.47434032f, 0.8579254f, 0.44859987f, + 0.5138961f, 0.45686555f}; + const std::vector weight_dim = {4, 4}; + const std::vector weight = { + 0.6011907f, 0.81791973f, 0.9736231f, 0.81752795f, 0.97470677f, 0.46383917f, 0.050839245f, 0.2629614f, + 0.8404526f, 0.49675876f, 0.25147682f, 0.11684412f, 0.032073975f, 0.0779959f, 0.39858162f, 0.774203f}; + const std::vector bias_dim = {4}; + const std::vector bias = { + 0.77032053f, 0.017784059f, 0.811891f, 0.10874528f}; + const std::vector eco_a_dim = {1, 2, 1, 1}; + const std::vector eco_a = { + 0.39429486f, 0.29726368f}; + const std::vector output_dim = {2, 2, 5, 5}; + const std::vector output = { + 0.9534052f, 0.28073975f, 0.9410346f, 1.0368304f, 1.181549f, 0.94923383f, 0.0070919087f, 0.24271497f, + 
1.0358753f, 0.8355051f, 0.15224966f, 0.29366368f, 1.3435968f, 1.158798f, 0.3904445f, 0.5147038f, + 0.03260383f, 0.67545396f, 0.16985025f, 0.15726471f, 0.64280313f, 0.7824283f, 0.40168867f, 1.0839535f, + 0.26630563f, 1.2391479f, 1.0948771f, 0.101813294f, 0.48797214f, 0.6789776f, 0.7492329f, 0.8089107f, + 0.91042155f, 0.6930023f, 0.3348113f, 0.95611423f, 0.026447866f, 0.2643374f, 0.48654333f, 0.3328685f, + 0.4239932f, 0.117630124f, 0.5134121f, 0.7915271f, 0.22728965f, 0.61497897f, 1.1122944f, 0.5816067f, + 0.6662628f, 0.59232306f, 0.95294285f, 0.2806036f, 0.9405782f, 1.0363276f, 1.1809759f, 0.95289487f, + 0.007119261f, 0.24365108f, 1.0398705f, 0.83872753f, 0.15201466f, 0.29321042f, 1.3415229f, 1.1570094f, + 0.38984182f, 0.51978874f, 0.032925934f, 0.682127f, 0.17152825f, 0.15881838f, 0.6571103f, 0.79984313f, + 0.4106292f, 1.1080796f, 0.2722329f, 1.2398669f, 1.0955123f, 0.101872355f, 0.4882552f, 0.6793715f, + 0.7427765f, 0.8019401f, 0.9025762f, 0.6870305f, 0.33192614f, 0.9568577f, 0.026468432f, 0.26454294f, + 0.48692167f, 0.33312735f, 0.4217717f, 0.117013805f, 0.5107221f, 0.78737986f, 0.22609876f, 0.6166911f, + 1.1153911f, 0.5832259f, 0.6681177f, 0.59397215f}; + + RunGatedRelativePositionBiasTest(query_layer, query_bias, rel_pos, weight, bias, eco_a, output, + batch_size, seq_len, num_heads, head_size, D, false); +} + } // namespace test } // namespace onnxruntime From 22169b2ca5074b961bf6b94af090a50bdfbb188f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 3 Feb 2023 04:31:12 +0000 Subject: [PATCH 12/16] fix build --- .../contrib_ops/cuda/bert/relative_attn_bias.cc | 2 +- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 16 +++++++++------- .../contrib_ops/relative_attention_bias_test.cc | 1 - 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index c08bcf35d076a..111fed04639e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -158,7 +158,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c reinterpret_cast(query_tensor.template Data()), reinterpret_cast(query_bias_tensor.template Data()), reinterpret_cast(workspace.get()), - false, head_size, reinterpret_cast(nullptr), total_maxtrix); + false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix); // reuse output if possible CudaT* gemm_output = (seq_len < D) ? 
(reinterpret_cast(workspace.get()) + elements_in_query) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index bdab663a9954f..c41e7059874ca 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -688,13 +688,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); int64_t num_heads = getAttribute(ctx, "num_heads", -1L); - auto& query_layer_shape = getInputShape(ctx, 0); - TensorShapeProto output_shape; - *output_shape.add_dim() = query_layer_shape.dim(0); - output_shape.add_dim()->set_dim_value(num_heads); - *output_shape.add_dim() = query_layer_shape.dim(1); - *output_shape.add_dim() = query_layer_shape.dim(1); - updateOutputShape(ctx, 0, output_shape); + if (hasInputShape(ctx, 0)) { + auto& query_layer_shape = getInputShape(ctx, 0); + TensorShapeProto output_shape; + *output_shape.add_dim() = query_layer_shape.dim(0); + output_shape.add_dim()->set_dim_value(num_heads); + *output_shape.add_dim() = query_layer_shape.dim(1); + *output_shape.add_dim() = query_layer_shape.dim(1); + updateOutputShape(ctx, 0, output_shape); + } })); } // namespace contrib diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc index 646e18605f118..21ff8d8064f52 100644 --- a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc +++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc @@ -231,7 +231,6 @@ static void RunGatedRelativePositionBiasTest( int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); - bool enable_cpu = false; if (enable_cuda) { OpTester tester("GatedRelativePositionBias", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); From 6e3070cc09c315e3e528168713583d9a46cf9a7d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 3 Feb 2023 06:48:13 +0000 Subject: [PATCH 13/16] update docs --- docs/ContribOperators.md | 53 ++++++++++++++++++++++++++++++++++++++++ docs/OperatorKernels.md | 1 + 2 files changed, 54 insertions(+) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 241ff7d718fd0..ed57ade3238c5 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -29,6 +29,7 @@ Do not modify directly.* * com.microsoft.FusedConv * com.microsoft.FusedGemm * com.microsoft.FusedMatMul + * com.microsoft.GatedRelativePositionBias * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu @@ -1573,6 +1574,58 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GatedRelativePositionBias** + + query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2) + gate_u, gate_r = torch.sigmoid( + self.gate_ur_linear(query_layer).view(batch_size, num_head, seq_len, 2, D/2).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0 + rel_pos_bias = gate_u_1 * rel_pos + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
num_heads : int (required)
+
Number of attention heads
+
+ +#### Inputs + +
+
query_layer : T
+
tensor with shape (batch_size, seq_len, num_heads x head_size)
+
query_bias : T
+
1-d tensor with shape (num_heads x head_size)
+
rel_pos : T
+
tensor with shape (1, num_heads, seq_len, seq_len)
+
weight : T
+
gemm weight for the gated_ur_linear, shape (head_size, D), D is divisible by 2
+
bias : T
+
bias for the gated_ur_linear, shape (D)
+
eco_a : T
+
tensor of shape (1, num_heads, 1, 1)
+
+ +#### Outputs + +
+
output : T
+
output tensor with shape (batch_size, num_heads, seq_len, seq_len)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
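The pseudocode at the top of this section refers to module state (`gate_ur_linear`, `eco_a`). As a self-contained illustration, the minimal PyTorch sketch below spells out the same computation with explicit tensors; shapes follow the Inputs table above, and the random values are placeholders rather than part of the operator definition.

```python
# Minimal PyTorch sketch of the GatedRelativePositionBias math (reference only).
import torch

batch_size, seq_len, num_heads, head_size, D = 1, 3, 2, 4, 8

query_layer = torch.rand(batch_size, seq_len, num_heads * head_size)
query_bias = torch.rand(num_heads * head_size)
rel_pos = torch.rand(1, num_heads, seq_len, seq_len)
weight = torch.rand(head_size, D)        # gated_ur_linear weight
bias = torch.rand(D)                     # gated_ur_linear bias
eco_a = torch.rand(1, num_heads, 1, 1)

# (batch_size, num_heads, seq_len, head_size)
qw = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)

# gated_ur_linear projection, pairwise sum over the last axis, sigmoid, then split into the u/r gates
gate_u, gate_r = torch.sigmoid(
    (torch.matmul(qw, weight) + bias).view(batch_size, num_heads, seq_len, 2, D // 2).sum(-1)
).chunk(2, dim=-1)

gate = gate_u * (gate_r * eco_a - 1.0) + 2.0   # (batch_size, num_heads, seq_len, 1)
output = gate * rel_pos                        # (batch_size, num_heads, seq_len, seq_len)
```

The same reference computation appears as the data-generation script embedded in `onnxruntime/test/contrib_ops/relative_attention_bias_test.cc` earlier in this series.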
+ + ### **com.microsoft.GatherND** Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d1a83213e8eeb..e32214233cc78 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -802,6 +802,7 @@ Do not modify directly.* |FastGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| From b3b423132a62b083302c4e125e98e3b5c9765fbe Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 3 Feb 2023 07:05:06 +0000 Subject: [PATCH 14/16] exclude rocm --- cmake/onnxruntime_rocm_hipify.cmake | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 7e0bb9f6fb419..bffe75e798aa9 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -19,6 +19,10 @@ set(contrib_ops_excluded_files "bert/fast_gelu_impl.h" "bert/fast_gelu.cc" "bert/fast_gelu.h" + "bert/relative_attn_bias.cc" + "bert/relative_attn_bias.h" + "bert/relative_attn_bias_impl.cu" + "bert/relative_attn_bias_impl.h" "bert/skip_layer_norm.cc" "bert/skip_layer_norm.h" "bert/skip_layer_norm_impl.cu" From 3a27a7004d0256e5a431b2986dbf7283521c94ce Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 3 Feb 2023 08:15:19 +0000 Subject: [PATCH 15/16] fix build --- onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index dc933d7e34a05..fcf86637350b6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -663,7 +663,6 @@ Status QkvToContext( bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. - DUMP_ATTENTION_D("relative_position_bias", data.relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length); ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, data.relative_position_bias, scratch1, scratch2, From d4020ae864d4163a1ebf16325b01ae8cc0c30636 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 6 Feb 2023 18:40:56 +0000 Subject: [PATCH 16/16] use constexpr --- .../relative_attention_bias_test.cc | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc index 21ff8d8064f52..ba0299e4f3808 100644 --- a/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc +++ b/onnxruntime/test/contrib_ops/relative_attention_bias_test.cc @@ -268,11 +268,11 @@ static void RunGatedRelativePositionBiasTest( } TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) { - const int batch_size = 1; - const int num_heads = 2; - const int seq_len = 3; - const int head_size = 4; - const int D = 8; + constexpr int batch_size = 1; + constexpr int num_heads = 2; + constexpr int seq_len = 3; + constexpr int head_size = 4; + constexpr int D = 8; const std::vector query_layer_dim = {1, 3, 8}; const std::vector query_layer = { 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, @@ -309,11 +309,11 @@ TEST(GatedRelativePositionBiasTest, FP16_BSNHD_1x3x2x4x8) { } TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) { - const int batch_size = 2; - const int num_heads = 2; - const int seq_len = 3; - const int head_size = 4; - const int D = 8; + constexpr int batch_size = 2; + constexpr int num_heads = 2; + constexpr int seq_len = 3; + constexpr int head_size = 4; + constexpr int D = 8; const std::vector query_layer_dim = {2, 3, 8}; const std::vector 
query_layer = { 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f, @@ -355,11 +355,11 @@ TEST(GatedRelativePositionBiasTest, FP32_BSNHD_2x3x2x4x8) { } TEST(GatedRelativePositionBiasTest, FP32_LongSeq_BSNHD_2x5x2x4x4) { - const int batch_size = 2; - const int num_heads = 2; - const int seq_len = 5; - const int head_size = 4; - const int D = 4; + constexpr int batch_size = 2; + constexpr int num_heads = 2; + constexpr int seq_len = 5; + constexpr int head_size = 4; + constexpr int D = 4; const std::vector query_layer_dim = {2, 5, 8}; const std::vector query_layer = { 0.4962566f, 0.7682218f, 0.08847743f, 0.13203049f, 0.30742282f, 0.6340787f, 0.4900934f, 0.89644474f,