From 6d8402ec53e077b7e867dbe74b7b47ecfd0a4c40 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Wed, 25 Jan 2023 07:06:48 +0000
Subject: [PATCH 01/27] Add benchmark

---
 .../transformers/fusion_attention_unet.py     |  27 +-
 .../__init__.py                               |   0
 .../models/stable_diffusion/benchmark.py      | 244 ++++++++++++++++++
 .../optimize_pipeline.py}                     |  85 ++++--
 .../python/tools/transformers/onnx_model.py   |   4 +
 5 files changed, 331 insertions(+), 29 deletions(-)
 rename onnxruntime/python/tools/transformers/models/{diffusion => stable_diffusion}/__init__.py (100%)
 create mode 100755 onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
 rename onnxruntime/python/tools/transformers/models/{diffusion/convert_to_fp16.py => stable_diffusion/optimize_pipeline.py} (56%)

diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
index 2151e6a21c5e7..41daad4283b4c 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -103,8 +103,22 @@ def create_attention_node(
         is_self_attention = not self.is_cross_attention

         if is_self_attention:
-            if q_matmul.input[0] != input or k_matmul.input[0] != input or q_matmul.input[0] != input:
-                logger.debug("q_matmul.input[0] != input or k_matmul.input[0] != input or q_matmul.input[0] != input")
+            if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
+                logger.debug(
+                    "For self attention, input hidden state for q, k and v shall be the same. Got %s, %s, %s",
+                    q_matmul.input[0],
+                    k_matmul.input[0],
+                    v_matmul.input[0],
+                )
+                return None
+        else:
+            if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
+                logger.debug(
+                    "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
+                    q_matmul.input[0],
+                    k_matmul.input[0],
+                    v_matmul.input[0],
+                )
                 return None

         if hidden_size > 0 and (hidden_size % num_heads) != 0:
@@ -203,9 +217,12 @@ def create_attention_node(
         return attention_node

     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
-        node_before_layernorm = self.model.match_parent(
-            normalize_node, "Add" if self.is_cross_attention else "Reshape", 0
-        )
+        node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
+
+        # In SD 1.5, for self attention, LayerNorm has parent Reshape
+        if node_before_layernorm is None and not self.is_cross_attention:
+            node_before_layernorm = self.model.match_parent(normalize_node, "Reshape", 0)
+
         if node_before_layernorm is None:
             return

diff --git a/onnxruntime/python/tools/transformers/models/diffusion/__init__.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/__init__.py
similarity index 100%
rename from onnxruntime/python/tools/transformers/models/diffusion/__init__.py
rename to onnxruntime/python/tools/transformers/models/stable_diffusion/__init__.py
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
new file mode 100755
index 0000000000000..580c5ef4c3cca
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py
@@ -0,0 +1,244 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------- + +import argparse +import os +import time + +SD_MODELS = { + "1.5": "runwayml/stable-diffusion-v1-5", + "2.0": "stabilityai/stable-diffusion-2", + "2.1": "stabilityai/stable-diffusion-2-1", +} + + +def get_test_settings(): + height = 512 + width = 512 + num_inference_steps = 50 + prompts = [ + "a photo of an astronaut riding a horse on mars", + "cute grey cat with blue eyes, wearing a bowtie, acrylic painting", + "a cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital painting", + "an illustration of a house with large barn with many cute flower pots and beautiful blue sky scenery", + "one apple sitting on a table, still life, reflective, full color photograph, centered, close-up product", + "background texture of stones, masterpiece, artistic, stunning photo, award winner photo", + "new international organic style house, tropical surroundings, architecture, 8k, hdr", + "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation", + "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic", + "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k", + ] + + return height, width, num_inference_steps, prompts + + +def get_ort_pipeline(model_name: str, directory: str, provider: str, disable_safety_checker: bool): + from diffusers import OnnxStableDiffusionPipeline + + import onnxruntime + + if directory is not None: + assert os.path.exists(directory) + session_options = onnxruntime.SessionOptions() + pipe = OnnxStableDiffusionPipeline.from_pretrained( + directory, + provider=provider, + sess_options=session_options, + ) + else: + pipe = OnnxStableDiffusionPipeline.from_pretrained( + model_name, + revision="onnx", + provider=provider, + use_auth_token=True, + ) + + if disable_safety_checker: + pipe.safety_checker = None + pipe.feature_extractor = None + + return pipe + + +def get_torch_pipeline(model_name: str, disable_channels_last: bool, disable_safety_checker: bool): + from diffusers import StableDiffusionPipeline + from torch import channels_last, float16 + + pipe = StableDiffusionPipeline.from_pretrained( + model_name, torch_dtype=float16, revision="fp16", use_auth_token=True + ).to("cuda") + + if not disable_channels_last: + pipe.unet.to(memory_format=channels_last) # in-place operation + + if disable_safety_checker: + pipe.safety_checker = None + pipe.feature_extractor = None + + return pipe + + +def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, disable_safety_checker: bool): + short_model_name = model_name.split("/")[-1].replace("stable-diffusion-", "sd") + return f"{engine}_{short_model_name}_b{batch_size}" + ("" if disable_safety_checker else "_safe") + + +def run_ort_pipeline(pipe, batch_size: int, image_filename_prefix: str): + from diffusers import OnnxStableDiffusionPipeline + + assert isinstance(pipe, OnnxStableDiffusionPipeline) + + height, width, num_inference_steps, prompts = get_test_settings() + + pipe("warm up", height, width, num_inference_steps=2) + + latency_list = [] + for i, prompt in enumerate(prompts): + input_prompts = [prompt] * batch_size + inference_start = time.time() + image = pipe(input_prompts, height, width, num_inference_steps).images[0] + inference_end = time.time() + + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference 
took {latency} seconds") + image.save(f"{image_filename_prefix}_{i}.jpg") + print("Average latency in seconds:", sum(latency_list) / len(latency_list)) + + +def run_torch_pipeline(pipe, batch_size: int, image_filename_prefix: str): + import torch + + height, width, num_inference_steps, prompts = get_test_settings() + + pipe("warm up", height, width, num_inference_steps=2) + + torch.set_grad_enabled(False) + + latency_list = [] + for i, prompt in enumerate(prompts): + input_prompts = [prompt] * batch_size + torch.cuda.synchronize() + inference_start = time.time() + image = pipe(input_prompts, height, width, num_inference_steps).images[0] + torch.cuda.synchronize() + inference_end = time.time() + + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency} seconds") + image.save(f"{image_filename_prefix}_{i}.jpg") + + print("Average latency in seconds:", sum(latency_list) / len(latency_list)) + + +def run_ort(model_name: str, directory: str, provider: str, batch_size: int, disable_safety_checker: bool): + load_start = time.time() + pipe = get_ort_pipeline(model_name, directory, provider, disable_safety_checker) + load_end = time.time() + print(f"Model loading took {load_end - load_start} seconds") + + image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, disable_safety_checker) + run_ort_pipeline(pipe, batch_size, image_filename_prefix) + + +def run_torch(model_name: str, batch_size: int, disable_channels_last: bool, disable_safety_checker: bool): + import torch + + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + # torch.backends.cuda.matmul.allow_tf32 = True + + torch.set_grad_enabled(False) + + load_start = time.time() + pipe = get_torch_pipeline(model_name, disable_channels_last, disable_safety_checker) + load_end = time.time() + print(f"Model loading took {load_end - load_start} seconds") + + image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, disable_safety_checker) + ( + "" if disable_channels_last else "_channels_last" + ) + with torch.inference_mode(): + run_torch_pipeline(pipe, batch_size, image_filename_prefix) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-e", + "--engine", + required=False, + type=str, + default="onnxruntime", + choices=["onnxruntime", "torch"], + help="Engines to benchmark. Default is onnxruntime.", + ) + + parser.add_argument( + "-v", + "--version", + required=True, + type=str, + choices=list(SD_MODELS.keys()), + help="Stable diffusion version like 1.5, 2.0 or 2.1", + ) + + parser.add_argument( + "-p", + "--pipeline", + required=False, + type=str, + default=None, + help="Directory of saved onnx pipeline. It could be output directory of optimize_pipeline.py.", + ) + + parser.add_argument( + "-c", + "--disable_channels_last", + required=False, + action="store_true", + help="Disable channels last for torch. 
It will be ignored for the onnxruntime engine",
+    )
+    parser.set_defaults(disable_channels_last=False)
+
+    parser.add_argument(
+        "--enable_safety_checker",
+        required=False,
+        action="store_true",
+        help="Enable safety checker",
+    )
+    parser.set_defaults(enable_safety_checker=False)
+
+    parser.add_argument("-b", "--batch_size", type=int, default=1)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_arguments()
+    print(args)
+
+    sd_model = SD_MODELS[args.version]
+    if args.engine == "onnxruntime":
+        assert args.pipeline, "--pipeline should be specified for onnxruntime engine"
+
+        if args.batch_size > 1:
+            # Need to remove a line https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L307
+            # in diffusers to run batch_size > 1.
+            assert (
+                not args.enable_safety_checker
+            ), "batch_size > 1 is not compatible with safety checker due to a bug in diffusers"
+
+        provider = "CUDAExecutionProvider"  # TODO: use ["CUDAExecutionProvider", "CPUExecutionProvider"] in diffusers
+        run_ort(sd_model, args.pipeline, provider, args.batch_size, not args.enable_safety_checker)
+    else:
+        run_torch(sd_model, args.batch_size, args.disable_channels_last, not args.enable_safety_checker)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
similarity index 56%
rename from onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py
rename to onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
index 8e20f58dd75d4..dd38103e995b1 100644
--- a/onnxruntime/python/tools/transformers/models/diffusion/convert_to_fp16.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -10,12 +10,24 @@
 # cd diffusers
 # pip install -e .
 # huggingface-cli login
-# python3 scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5
+# python3 scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5-fp32
+# Or use diffusers packages:
+# export ONNX_ROOT=./sd_onnx
+# pip install diffusers==0.11.1 transformers==4.21.2
+# huggingface-cli login
+# wget https://raw.githubusercontent.com/huggingface/diffusers/v0.11.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
+# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/stable-diffusion-v1-5-fp32
+# python convert_stable_diffusion_checkpoint_to_onnx.py --model_path stabilityai/stable-diffusion-2-1 --output_path $ONNX_ROOT/stable-diffusion-v2-1-fp32
 #
 # Then you can use this script to convert them to float16 like the following:
-# pip3 install -U onnxruntime-gpu >= 1.14
-# python3 -m onnxruntime.transformers.models.diffusion.convert_to_fp16 -i ../stable-diffusion-v1-5 -o ../stable-diffusion-v1-5-fp16
-# Note that float16 model is intended for CUDA Execution Provider. It might not run in CPU Execution Provider.
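+# (Run fusion optimization on the float32 models first and convert to float16 afterwards;
+# fused patterns are no longer recognizable once the graph has been converted.)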
+# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 +# python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16 +# Or +# pip install -U onnxruntime-gpu >= 1.14 +# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 +# python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16 + +# Note that float16 model is for CUDA Execution Provider. It might not run in CPU Execution Provider. import argparse import logging @@ -32,14 +44,17 @@ logger = logging.getLogger(__name__) -def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool): - """Convert a model to float16 +def optimize_stable_diffusion_onnx_pipeline( + source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool +): + """Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16. Args: - source_dir (Path): source directory - target_dir (Path): target directory - overwrite (bool): overwrite if exists - use_external_data_format (bool): save model to two files: one for onnx graph, another for weights + source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models. + target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models. + overwrite (bool): Overwrite files if exists. + use_external_data_format (bool): save onnx model to two files: one for onnx graph, another for weights + float16 (bool): use half precision Raises: RuntimeError: input onnx model does not exist @@ -50,13 +65,17 @@ def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_ext onnx_model_path = source_dir / name / "model.onnx" if not os.path.exists(onnx_model_path): - raise RuntimeError(f"input onnx model does not exist: {onnx_model_path}") + message = f"input onnx model does not exist: {onnx_model_path}." + if name not in ["safety_checker", "feature_extractor"]: + raise RuntimeError(message) + continue num_heads = 0 hidden_size = 0 # Graph fusion before fp16 conversion, otherwise they cannot be fused later. # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead. + logger.info(f"optimize {onnx_model_path}...") m = optimize_model( str(onnx_model_path), model_type="unet", @@ -67,11 +86,14 @@ def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_ext use_gpu=False, ) - # VAE-decoder in fp16 reduced quality thus we exclude it here - if name != "vae_decoder": - m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize"]) - else: - print("skip convert vae_decoder to fp16.") + if float16: + # VAE-decoder in fp16 reduced quality thus we exclude it here + # TODO: enable mixed precision conversion for VAE-decoder. 
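+            # Ops in op_block_list stay in float32; converting them is either unsupported
+            # or hurts output quality.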
+ if name != "vae_decoder": + logger.info(f"convert to float16 ...") + m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize"]) + else: + logger.info("skip convert vae_decoder to fp16.") optimized_model_path = target_dir / name / "model.onnx" output_dir = optimized_model_path.parent @@ -87,8 +109,8 @@ def convert_to_fp16(source_dir: Path, target_dir: Path, overwrite: bool, use_ext print(f"{onnx_model_path} => {optimized_model_path}") -def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool): - """Copy extra directory. +def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): + """Copy extra directory that does not have onnx model Args: source_dir (Path): source directory @@ -102,8 +124,12 @@ def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool): extra_dirs = ["scheduler", "tokenizer", "feature_extractor"] for name in extra_dirs: source_path = source_dir / name + if not os.path.exists(source_path): - raise RuntimeError(f"source path does not exist: {source_path}") + message = f"source path does not exist: {source_path}" + if name not in ["safety_checker", "feature_extractor"]: + raise RuntimeError(message) + continue target_path = target_dir / name if target_path.exists(): @@ -126,7 +152,7 @@ def copy_extra(source_dir: Path, target_dir: Path, overwrite: bool): raise RuntimeError(f"output path existed: {target_path}") os.remove(target_path) shutil.copyfile(source_path, target_path) - print(f"{source_path} => {target_path}") + logger.info(f"{source_path} => {target_path}") def parse_arguments(): @@ -150,9 +176,17 @@ def parse_arguments(): "--output", required=True, type=str, - help="Root of output directory of stable diffusion onnx pipeline with float16 models.", + help="Root of output directory of stable diffusion onnx pipeline with optimized models.", ) + parser.add_argument( + "--float16", + required=False, + action="store_true", + help="Output models of half or mixed precision.", + ) + parser.set_defaults(float16=False) + parser.add_argument( "--overwrite", required=False, @@ -166,7 +200,8 @@ def parse_arguments(): "--use_external_data_format", required=False, action="store_true", - help="Onnx model larger than 2GB need to use external data format.", + help="Onnx model larger than 2GB need to use external data format. 
" + "Save onnx model to two files: one for onnx graph, another for large weights.", ) parser.set_defaults(use_external_data_format=False) @@ -177,8 +212,10 @@ def parse_arguments(): def main(): coloredlogs.install(fmt="%(funcName)20s: %(message)s") args = parse_arguments() - copy_extra(Path(args.input), Path(args.output), args.overwrite) - convert_to_fp16(Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format) + copy_extra_directory(Path(args.input), Path(args.output), args.overwrite) + optimize_stable_diffusion_onnx_pipeline( + Path(args.input), Path(args.output), args.overwrite, args.use_external_data_format, args.float16 + ) main() diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 4827facd78100..cfca8fb0578b5 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -977,6 +977,10 @@ def save_model_to_file(self, output_path, use_external_data_format=False, all_te logger.info("Sort graphs in topological order") self.topological_sort() + # Note: After the model is saved to another directory with external data, + # You need reload the onnx model if you want to read tensor from self.model object. + # It is because the base directory is not updated for self.model object so attemp to read tensor data + # might encounter error since external data cannot be located. OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file) logger.info(f"Model saved to {output_path}") From 0ac952e9f07c7428dafcebe3386cd141646520da Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Jan 2023 22:12:04 +0000 Subject: [PATCH 02/27] Add GroupNorm fusion --- .../python/tools/symbolic_shape_infer.py | 5 + .../tools/transformers/fusion_group_norm.py | 152 ++++++++++++++++++ .../tools/transformers/fusion_options.py | 11 ++ .../stable_diffusion/optimize_pipeline.py | 8 +- .../tools/transformers/onnx_model_unet.py | 5 + 5 files changed, 177 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/fusion_group_norm.py diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ed94a01f562ef..ee8b7923ca62d 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -200,6 +200,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "PythonOp": self._infer_PythonOp, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + "GroupNorm": self._infer_GroupNorm, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -434,6 +435,7 @@ def _onnx_infer_single_node(self, node): "SkipLayerNormalization", "PythonOp", "MultiHeadAttention", + "GroupNorm", ] if not skip_infer: @@ -2056,6 +2058,9 @@ def _infer_SkipLayerNormalization(self, node): if len(node.output) > 3: self._propagate_shape_and_type(node, 0, 3) + def _infer_GroupNorm(self, node): + self._propagate_shape_and_type(node) + def _infer_PythonOp(self, node): output_tensor_types = get_attribute(node, "output_tensor_types") assert output_tensor_types diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py new file mode 100644 index 0000000000000..24dc518e6e894 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py 
@@ -0,0 +1,152 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Dict + +import numpy as np +from fusion_base import Fusion +from onnx import TensorProto, helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionGroupNorm(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "GroupNorm", "Add") + + def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + Fuse Group Normalization subgraph into one node GroupNorm. + The following is the pattern with swish activation: + +----------------Shape-------------------------------+ + | | + | (0, 32, -1) v (512x1x1) (512x1x1) + [Root] --> Reshape -------> InstanceNormalization --> Reshape ---> Mul --> Add --> Mul--> [output] + Bx512xHxW (scale=ones(32), B=zeros(32)) | ^ Bx512xHxW + | | + +--->Sigmoid + The following is the pattern without swish activation: + +----------------Shape-------------------------------+ + | | + | (0, 32, -1) v (512x1x1) (512x1x1) + [Root] --> Reshape -------> InstanceNormalization --> Reshape ---> Mul --> Add -->[output] + Bx512xHxW (scale=ones(32), B=zeros(32)) Bx512xHxW + """ + nodes = self.model.match_parent_path( + add_node, ["Mul", "Reshape", "InstanceNormalization", "Reshape"], [0, 0, 0, 0], output_name_to_node + ) + if nodes is None: + return + + weight_mul, reshape_4d, instance_norm, reshape_3d = nodes + root = reshape_3d.input[0] + + parents = self.model.match_parent_path(reshape_4d, ["Shape"], [1], output_name_to_node) + if parents is None: + return + if parents[0].input[0] != root: + return + shape_node = parents[0] + + # Check whether it has swish activation. 
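+        # Swish is x * Sigmoid(x): look for a Mul consuming the Add output whose other
+        # input comes through a Sigmoid.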
+        swish_mul = self.model.find_first_child_by_type(add_node, "Mul")
+        swish_sigmoid = None
+        if swish_mul is not None:
+            sigmoid_path = self.model.match_parent_path(swish_mul, ["Sigmoid"], [None], output_name_to_node)
+            if sigmoid_path is not None:
+                swish_sigmoid = sigmoid_path[0]
+
+        weight_input = weight_mul.input[1 - self.model.input_index(reshape_4d.output[0], weight_mul)]
+        if not self.model.is_constant_with_specified_dimension(weight_input, 3, "group norm weight"):
+            return
+
+        bias_input = add_node.input[1 - self.model.input_index(weight_mul.output[0], add_node)]
+        if not self.model.is_constant_with_specified_dimension(bias_input, 3, "group norm bias"):
+            return
+
+        weight = self.model.get_constant_value(weight_input)
+        if not (len(weight.shape) == 3 and weight.shape[1] == 1 and weight.shape[2] == 1):
+            return
+
+        bias = self.model.get_constant_value(bias_input)
+        if not (len(bias.shape) == 3 and bias.shape[1] == 1 and bias.shape[2] == 1):
+            return
+
+        weight_elements = np.prod(weight.shape)
+        bias_elements = np.prod(bias.shape)
+        if weight_elements != bias_elements:
+            return
+
+        instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
+        if instance_norm_scale is None:
+            return
+        instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
+        if instance_norm_bias is None:
+            return
+
+        if not (
+            len(instance_norm_scale.shape) == 1
+            and len(instance_norm_bias.shape) == 1
+            and instance_norm_scale.shape == instance_norm_bias.shape
+            and instance_norm_scale.shape[0] == 32
+        ):
+            return
+
+        if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
+            return
+        if not np.allclose(np.zeros_like(instance_norm_bias), instance_norm_bias):
+            return
+
+        group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")
+
+        gamma = helper.make_tensor(
+            name=group_norm_name + "_gamma",
+            data_type=TensorProto.FLOAT,
+            dims=[weight_elements],
+            vals=weight.flatten().tolist(),
+        )
+        self.model.add_initializer(gamma, self.this_graph_name)
+
+        beta = helper.make_tensor(
+            name=group_norm_name + "_beta",
+            data_type=TensorProto.FLOAT,
+            dims=[bias_elements],
+            vals=bias.flatten().tolist(),
+        )
+        self.model.add_initializer(beta, self.this_graph_name)
+
+        last_node = add_node
+        subgraph_nodes = [add_node, weight_mul, reshape_4d, instance_norm, reshape_3d, shape_node]
+        has_swish = swish_mul and swish_sigmoid
+        if swish_mul and swish_sigmoid:
+            subgraph_nodes.extend([swish_mul, swish_sigmoid])
+            last_node = swish_mul
+
+        if not self.model.is_safe_to_fuse_nodes(
+            subgraph_nodes,
+            last_node.output,
+            input_name_to_nodes,
+            output_name_to_node,
+        ):
+            self.nodes_to_remove.extend([last_node])
+        else:
+            self.nodes_to_remove.extend(subgraph_nodes)
+
+        # instance_norm_scale might come from a Constant node. Use prune graph to clear it.
+ self.prune_graph = True + + new_node = helper.make_node( + "GroupNorm", + inputs=[root, group_norm_name + "_gamma", group_norm_name + "_beta"], + outputs=[last_node.output[0]], + ) + + new_node.attribute.extend(instance_norm.attribute) + new_node.attribute.extend([helper.make_attribute("groups", 32)]) + new_node.attribute.extend([helper.make_attribute("swish", 1 if has_swish else 0)]) + new_node.domain = "com.microsoft" + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 9a5359b58caa6..f3104d86130c8 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -36,6 +36,7 @@ def __init__(self, model_type): self.enable_shape_inference = True self.enable_gemm_fast_gelu = False + self.enable_group_norm = model_type == "unet" self.attention_mask_format = AttentionMaskFormat.AttentionMask def use_raw_attention_mask(self, use_raw_mask=True): @@ -76,6 +77,8 @@ def parse(args): options.use_raw_attention_mask(False) if args.no_attention_mask: options.disable_attention_mask() + if args.enable_group_norm: + options.enable_group_norm = True return options @staticmethod @@ -185,3 +188,11 @@ def add_arguments(parser: ArgumentParser): "MultiHeadAttention has only CUDA implementation so the model can only run with cuda execution provider.", ) parser.set_defaults(use_multi_head_attention=False) + + parser.add_argument( + "--enable_group_norm", + required=False, + action="store_true", + help="fuse GroupNorm. Only works for model_type=unet", + ) + parser.set_defaults(enable_group_norm=False) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index dd38103e995b1..e27cf5f8f1391 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -10,7 +10,7 @@ # cd diffusers # pip install -e . 
# huggingface-cli login -# python3 scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5-fp32 +# python scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5-fp32 # Or use diffusers packages: # export ONNX_ROOT=./sd_onnx # pip install diffusers==0.11.1 transformers==4.21.2 @@ -60,7 +60,7 @@ def optimize_stable_diffusion_onnx_pipeline( RuntimeError: input onnx model does not exist RuntimeError: output onnx model path existed """ - dirs_with_onnx = ["vae_encoder", "vae_decoder", "text_encoder", "safety_checker", "unet"] + dirs_with_onnx = ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"] for name in dirs_with_onnx: onnx_model_path = source_dir / name / "model.onnx" @@ -106,7 +106,7 @@ def optimize_stable_diffusion_onnx_pipeline( output_dir.mkdir(parents=True, exist_ok=True) m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format) - print(f"{onnx_model_path} => {optimized_model_path}") + logger.info(f"{onnx_model_path} => {optimized_model_path}") def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): @@ -138,7 +138,7 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): shutil.rmtree(target_path) shutil.copytree(source_path, target_path) - print(f"{source_path} => {target_path}") + logger.info(f"{source_path} => {target_path}") extra_files = ["model_index.json"] for name in extra_files: diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 7872cf68e7366..b59db5087df31 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -7,6 +7,7 @@ from typing import Optional from fusion_attention_unet import FusionAttentionUnet +from fusion_group_norm import FusionGroupNorm from fusion_options import FusionOptions from onnx import ModelProto from onnx_model_bert import BertOnnxModel @@ -52,6 +53,10 @@ def optimize(self, options: Optional[FusionOptions] = None): self.fuse_reshape() + if (options is None) or options.enable_group_norm: + group_norm_fusion = FusionGroupNorm(self) + group_norm_fusion.apply() + if (options is None) or options.enable_attention: self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False) self_attention_fusion.apply() From 29c47dc3ce060c655933dd958454ac2e2a682ed0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Jan 2023 22:04:40 +0000 Subject: [PATCH 03/27] [CUDA] Add GroupNormalization operator --- cmake/onnxruntime_rocm_hipify.cmake | 4 + .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 + .../contrib_ops/cuda/diffusion/group_norm.cc | 106 +++++ .../contrib_ops/cuda/diffusion/group_norm.h | 28 ++ .../cuda/diffusion/group_norm_impl.cu | 416 ++++++++++++++++++ .../cuda/diffusion/group_norm_impl.h | 43 ++ .../core/graph/contrib_ops/diffusion_defs.cc | 67 +++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + 8 files changed, 668 insertions(+) create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm.h create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h create mode 100644 onnxruntime/core/graph/contrib_ops/diffusion_defs.cc 
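Note for reviewers: the sketch below (not part of the patch) shows one way to exercise the new operator end to end once a CUDA build of onnxruntime includes this kernel. Everything here is illustrative: the model and tensor names are made up, and it follows the constraints of this initial version (float16 X/Y, float32 gamma/beta, groups fixed at 32 by the kernel, batch size at most 32, and the Swish flag exposed as the swish attribute; the fusion in a later commit emits activation instead).

    import numpy as np
    from onnx import TensorProto, helper, numpy_helper

    import onnxruntime

    batch, channels, height, width = 1, 512, 64, 64  # channels must be divisible by groups (32)

    node = helper.make_node(
        "GroupNorm",
        inputs=["X", "gamma", "beta"],
        outputs=["Y"],
        domain="com.microsoft",
        epsilon=1e-5,
        groups=32,
        swish=0,  # 1 would fuse x * sigmoid(x) after normalization
    )
    graph = helper.make_graph(
        [node],
        "group_norm_smoke_test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT16, [batch, channels, height, width])],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [batch, channels, height, width])],
        initializer=[
            numpy_helper.from_array(np.ones(channels, dtype=np.float32), "gamma"),
            numpy_helper.from_array(np.zeros(channels, dtype=np.float32), "beta"),
        ],
    )
    model = helper.make_model(
        graph, opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)]
    )

    session = onnxruntime.InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider"])
    x = np.random.randn(batch, channels, height, width).astype(np.float16)
    y = session.run(None, {"X": x})[0]
    print(y.shape, y.dtype)  # (1, 512, 64, 64) float16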
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index d3b8f5ebfcc26..376779cc01179 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -27,6 +27,10 @@ set(contrib_ops_excluded_files
   "bert/tensorrt_fused_multihead_attention/*"
   "bert/transformer_common.h"
   "bert/transformer_common.cc"
+  "diffusion/group_norm.h"
+  "diffusion/group_norm.cc"
+  "diffusion/group_norm_impl.cu"
+  "diffusion/group_norm_impl.h"
   "math/complex_mul.cc"
   "math/complex_mul.h"
   "math/complex_mul_impl.cu"
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 38bcbc298b939..e054b44f0d53f 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -71,6 +71,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupNorm);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler);
@@ -192,6 +193,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupNorm)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler)>,
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
new file mode 100644
index 0000000000000..56bf68285bb65
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
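+//
+// GroupNorm contrib op (float16 input/output, float32 gamma/beta, optional fused
+// Swish activation). Validation lives here; the CUDA kernels are in group_norm_impl.cu.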
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/diffusion/group_norm.h"
+#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T)                                  \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
+      GroupNorm,                                                  \
+      kMSDomain,                                                  \
+      1,                                                          \
+      T,                                                          \
+      kCudaExecutionProvider,                                     \
+      (*KernelDefBuilder::Create())                               \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      GroupNorm<T>);
+
+REGISTER_KERNEL_TYPED(MLFloat16);
+
+using namespace ONNX_NAMESPACE;
+
+template <typename T>
+GroupNorm<T>::GroupNorm(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
+  ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
+  ORT_ENFORCE(epsilon_ >= 0);
+
+  int64_t num_groups;
+  ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("groups", &num_groups).IsOK());
+  ORT_ENFORCE(num_groups >= 0);
+  num_groups_ = static_cast<int>(num_groups);
+
+  int64_t swish = 0;
+  ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("swish", &swish).IsOK());
+  swish_ = (swish != 0);
+}
+
+template <typename T>
+Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* gamma = context->Input<Tensor>(1);
+  const Tensor* beta = context->Input<Tensor>(2);
+  Tensor* output = context->Output(0, input->Shape());
+
+  const auto& input_dims = input->Shape().GetDims();
+  if (input_dims.size() != 4) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "input is expected to have 4 dimensions, got ", input_dims.size());
+  }
+
+  const auto& gamma_dims = gamma->Shape().GetDims();
+  if (gamma_dims.size() != 1) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "gamma is expected to have 1 dimension, got ", gamma_dims.size());
+  }
+  if (gamma_dims[0] != input_dims[1]) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Number of channels in gamma and input does not match");
+  }
+
+  const auto& beta_dims = beta->Shape().GetDims();
+  if (beta_dims.size() != 1) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "beta is expected to have 1 dimension, got ", beta_dims.size());
+  }
+  if (beta_dims[0] != input_dims[1]) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Number of channels in beta and input does not match");
+  }
+
+  int batch_size = static_cast<int>(input_dims[0]);
+  int num_channels = static_cast<int>(input_dims[1]);
+  int height = static_cast<int>(input_dims[2]);
+  int width = static_cast<int>(input_dims[3]);
+
+  if (num_channels % num_groups_ != 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "number of channels should be divisible by num_groups");
+  }
+
+  auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());
+
+  typedef typename ToCudaType<T>::MappedType CudaT;
+
+  return LaunchGroupNormKernel<CudaT>(
+      Stream(context),
+      reinterpret_cast<CudaT*>(output->MutableData<T>()),
+      reinterpret_cast<const CudaT*>(input->Data<T>()),
+      gamma->Data<float>(),
+      beta->Data<float>(),
+      workspace.get(),
+      epsilon_,
+      batch_size,
+      num_channels,
+      height,
+      width,
+      num_groups_,
+      swish_);
+}
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
new file mode 100644
index 0000000000000..85341a60a5e87
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class GroupNorm final : public CudaKernel {
+ public:
+  GroupNorm(const OpKernelInfo& op_kernel_info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  bool swish_;
+  float epsilon_;
+  int num_groups_;
+};
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
new file mode 100644
index 0000000000000..dc3b631c58585
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -0,0 +1,416 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5
+#include <cub/cub.cuh>
+#include <cuda_fp16.h>
+#include <cmath>
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+static inline int32_t divUp(int32_t m, int32_t n) {
+  return (m + n - 1) / n;
+}
+
+static inline __device__ __host__ float sigmoid(float x) {
+  return 1.F / (1.F + expf(-x));
+}
+
+struct GroupSums {
+  // Is it the 1st element of the group?
+  int32_t flag;
+  // The sum.
+  float sum;
+  // The sum of squares.
+  float sumSq;
+};
+
+struct GroupSumsOp {
+  inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) {
+    GroupSums dst;
+    dst.sum = b.flag ? b.sum : (a.sum + b.sum);
+    dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
+    dst.flag = a.flag + b.flag;
+    return dst;
+  }
+};
+
+struct GroupNormNHWCParams {
+  // The output buffer. Layout NHWC.
+  __half* dst;
+  // The input buffer. Layout NHWC.
+  __half const* src;
+  // The gamma scaling factor.
+  float const* gamma;
+  // The beta term to add in GN.
+  float const* beta;
+  // The temporary buffer to do the global parallel reduction. Size:
+  // BLOCKS_PER_BATCH x C x 2.
+  float* redBuffer;
+
+  // The number of instances in the batch.
+  int32_t n;
+  // The height and width of each activation map.
+  int32_t h;
+  int32_t w;
+  // The number of channels.
+  int32_t c;
+  // The number of groups.
+  int32_t groups;
+  // Do we apply the Swish activation function?
+  bool withSwish;
+
+  // Precomputed values and parameters to control the execution of the kernels.
+
+  // The number of activations per instance (h * w) and the number of
+  // activations per block.
+  int32_t hw;
+  int32_t hwPerBlock;
+  // The number of channels per group and blocks per activation in the C
+  // dimension.
+  int32_t cPerBlock;
+  int32_t cPerGroup;
+
+  // The precomputed stride between instances.
+  int32_t hwc;
+  // The inverse of hwc in floats (to compute mean/var).
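+  // (Specifically 1.F / (hw * cPerGroup): the element count of one group per image.)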
+  float invHWC;
+  // The precomputed number of groups per block.
+  int32_t groupsPerBlock;
+};
+
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) {
+  // The object in charge of doing the sums for the different blocks.
+  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
+
+  // Allocate shared memory for BlockScan.
+  __shared__ typename BlockScan::TempStorage tempStorage;
+  // Allocate shared memory for the groups. We could reduce the amount of shared
+  // memory reserved.
+  __shared__ float2 smem[tTHREADS_PER_BLOCK];
+
+  // The instance in the batch.
+  int32_t ni = blockIdx.z;
+  // The channel loaded by that thread (2 channels per thread for F16x2).
+  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+
+  // The first activation loaded by that block.
+  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+  // The last activation loaded by that block.
+  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+  // The sums.
+  float sum = 0.F;
+  float sumSq = 0.F;
+
+  // Iterate over the activations to compute the sums.
+  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+    // The offset.
+    int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;
+
+    // Fetch two channels per thread.
+    __half2 h2(0, 0);
+    if (ci < params.c) {
+      h2 = *reinterpret_cast<__half2 const*>(&params.src[offset]);
+    }
+
+    // Extract the two half values.
+    float2 f2 = __half22float2(h2);
+
+    // Update the sum.
+    sum += f2.x + f2.y;
+    // Update the sum of squares.
+    sumSq += f2.x * f2.x + f2.y * f2.y;
+  }
+
+  // The group that thread works on and the channel in the group (modulus).
+  int32_t gi = threadIdx.x * 2 / params.cPerGroup;
+  int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
+
+  // The data for the summations.
+  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
+
+  // Do the segmented scan.
+  GroupSums out;
+  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
+
+  // Store the results for the groups in shared memory (to produce coalesced
+  // stores later).
+  if (cj == params.cPerGroup - 2 /* 2 channels per thread */) {
+    smem[gi] = make_float2(out.sum, out.sumSq);
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The global group index.
+  int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
+
+  // Threads that have nothing left to do, exit.
+  if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
+    return;
+  }
+
+  // The first threads (those storing to global memory, load the values).
+  float2 sums = smem[threadIdx.x];
+
+  // Store to global memory.
+  atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
+  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
+}
+
+void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) {
+  // Make sure the values are as we expect.
+  ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0);
+  // Make sure a group does not span multiple blocks.
+  ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
+
+  dim3 grid;
+
+  // The number of blocks to compute all the channels.
+  grid.x = params.c / params.cPerBlock;
+  // The number of blocks to compute all the activations in a given instance.
+  grid.y = divUp(params.hw, params.hwPerBlock);
+  // The number of instances.
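+  // (x: channel blocks, y: activation blocks, z: one grid layer per batch instance.)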
+  grid.z = params.n;
+
+  switch (params.cPerBlock) {
+    case 320:
+      groupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
+      break;
+    case 480:
+      groupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
+      break;
+    case 256:
+      groupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
+      break;
+    case 128:
+      groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
+      break;
+    default:
+      ORT_NOT_IMPLEMENTED("Not implemented");
+  }
+}
+
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) {
+  // The instance in the batch.
+  int32_t ni = blockIdx.z;
+  // The channel loaded by that thread (2 channels per thread for F16x2).
+  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+  // The group that thread works on and the channel in the group (modulus).
+  int32_t gi = ci / params.cPerGroup;
+
+  // Load the sum and sum of squares for the group.
+  float sum = 0.F, sumSq = 0.F;
+  if (gi < params.groups) {
+    sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
+    sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
+  }
+
+  // Load gamma/beta.
+  float2 gammaF2, betaF2;
+  if (ci < params.c) {
+    gammaF2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
+    betaF2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
+  }
+
+  // Compute the mean.
+  float mean = sum * params.invHWC;
+  // Compute the variance.
+  float var = sumSq * params.invHWC - (mean * mean);
+  // Compute the inverse of the stddev.
+  float invStdDev = var <= 0.F ? 1.F : rsqrtf(var);
+
+  // The first activation loaded by that block.
+  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+  // The last activation loaded by that block.
+  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+  // Iterate over the activations to compute the sums.
+  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+    // The src/dst offset.
+    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
+
+    // Fetch two channels per thread.
+    __half2 h2(0, 0);
+    if (ci < params.c) {
+      h2 = *reinterpret_cast<__half2 const*>(&params.src[offset]);
+    }
+
+    // Extract the two half values.
+    float2 f2 = __half22float2(h2);
+
+    // Normalize the channels.
+    f2.x = (f2.x - mean) * invStdDev;
+    f2.y = (f2.y - mean) * invStdDev;
+
+    // Scale by gamma and add beta.
+    f2.x = gammaF2.x * f2.x + betaF2.x;
+    f2.y = gammaF2.y * f2.y + betaF2.y;
+
+    // Apply Swish if needed.
+    if (params.withSwish) {
+      f2.x = f2.x * sigmoid(f2.x);
+      f2.y = f2.y * sigmoid(f2.y);
+    }
+
+    // Store the scaled values.
+    if (ci < params.c) {
+      *reinterpret_cast<__half2*>(&params.dst[offset]) = __float22half2_rn(f2);
+    }
+  }
+}
+
+void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) {
+  // Make sure the dimensions are aligned with what we expect.
+  ORT_ENFORCE(params.c % params.cPerBlock == 0);
+  // Make sure a group does not span multiple blocks.
+  ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
+
+  dim3 grid;
+
+  // The number of blocks to compute all the channels.
+  grid.x = params.c / params.cPerBlock;
+  // The number of blocks to compute all the activations in a given instance.
+  grid.y = divUp(params.hw, params.hwPerBlock);
+  // The number of instances.
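+  // (Same grid decomposition as in groupNormNHWCSum.)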
+  grid.z = params.n;
+
+  switch (params.cPerBlock) {
+    case 320:
+      groupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
+      break;
+    case 480:
+      groupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
+      break;
+    case 256:
+      groupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
+      break;
+    case 128:
+      groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
+      break;
+    default:
+      ORT_NOT_IMPLEMENTED("Not implemented");
+  }
+}
+
+int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
+  int32_t maxDivisor = -1;
+  for (int32_t i = 1; i <= std::sqrt(n); i++) {
+    if (n % i == 0) {
+      int32_t divisor1 = n / i;
+      int32_t divisor2 = i;
+
+      if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
+        maxDivisor = divisor1;
+      }
+      if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
+        maxDivisor = divisor2;
+      }
+    }
+  }
+  return maxDivisor;
+}
+
+template <typename T>
+Status LaunchGroupNormKernel(
+    cudaStream_t stream,
+    T* output,
+    const T* input,
+    const float* gamma,
+    const float* beta,
+    void* workspace,
+    float epsilon,
+    int batch_size,
+    int num_channels,
+    int height,
+    int width,
+    int num_groups,
+    bool swish) {
+  if (batch_size > static_cast<int>(kMaxGroupNormBatchSize)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+                           "only support batch_size <= 32. Got", batch_size);
+  }
+
+  if (num_groups != static_cast<int>(kGroupNormNumberOfGroups)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+                           "only num_groups=32 is supported. Got", num_groups);
+  }
+
+  GroupNormNHWCParams params;
+  int32_t cPerBlock = 320;
+  int32_t maxBlocksPerHW = 1024;
+  switch (num_channels) {
+    case 960:
+    case 1920:
+      cPerBlock = 480;
+      break;
+    case 512:
+    case 256:
+      cPerBlock = 256;
+      break;
+    case 128:
+      cPerBlock = 128;
+      break;
+    default:
+      cPerBlock = 320;
+  }
+
+  params.withSwish = bool(swish);
+  params.dst = static_cast<half*>(output);
+  params.src = static_cast<half const*>(input);
+  params.gamma = static_cast<float const*>(gamma);
+  params.beta = static_cast<float const*>(beta);
+  params.redBuffer = static_cast<float*>(workspace);
+  params.n = batch_size;
+  params.h = height;
+  params.w = width;
+  params.c = num_channels;
+  params.groups = num_groups;
+  params.hw = params.h * params.w;
+  const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW);
+  params.hwPerBlock = divUp(params.hw, blocksPerHW);
+  params.cPerBlock = cPerBlock;
+  params.cPerGroup = params.c / params.groups;
+  params.hwc = params.hw * params.c;
+  params.invHWC = 1.F / (float)(params.hw * params.cPerGroup);
+  params.groupsPerBlock = cPerBlock / params.cPerGroup;
+
+  cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream);
+
+  groupNormNHWCSum(params, stream);
+  CUDA_RETURN_IF_ERROR(cudaGetLastError());
+
+  groupNormNHWCScale(params, stream);
+  CUDA_RETURN_IF_ERROR(cudaGetLastError());
+
+  return Status::OK();
+}
+
+template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output,
+                                            const half* input, const float* gamma, const float* beta, void* workspace,
+                                            float epsilon, int batch_size, int num_channels,
+                                            int height, int width, int num_groups, bool swish);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
new file mode 100644
index 0000000000000..5595d9ffb5082
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
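+//
+// Declarations for the GroupNorm CUDA launcher. The float32 workspace holds
+// partial sums and sums of squares, one pair per (batch, group).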
+
+#pragma once
+#include "core/common/common.h"
+#include "core/common/status.h"
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+constexpr size_t kMaxGroupNormBatchSize = 32;
+constexpr size_t kGroupNormNumberOfGroups = 32;
+
+constexpr size_t GetGroupNormWorkspaceSizeInBytes() {
+  // Two buffers for sum and squared sum
+  return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups;
+}
+
+template <typename T>
+Status LaunchGroupNormKernel(
+    cudaStream_t stream,
+    T* output,           // normalized output tensor
+    const T* input,      // input tensor
+    const float* gamma,  // gamma (also known as weight or scale)
+    const float* beta,   // beta (also known as bias)
+    void* workspace,     // workspace
+    float epsilon,       // epsilon used in normalization
+    int batch_size,      // N
+    int num_channels,    // C
+    int height,          // H
+    int width,           // W
+    int num_groups,      // number of groups
+    bool swish           // whether there is a Swish activation after group normalization
+);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
new file mode 100644
index 0000000000000..9a53c59e9043e
--- /dev/null
+++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
@@ -0,0 +1,67 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/graph/constants.h"
+#include "core/graph/contrib_ops/contrib_defs.h"
+#include "core/graph/contrib_ops/onnx_function_util.h"
+#include "core/graph/contrib_ops/shape_inference_functions.h"
+
+// Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from
+// ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build
+#if defined(_WIN32) && !defined(NDEBUG)
+#pragma warning(disable : 26426)
+#endif
+
+namespace onnxruntime {
+namespace contrib {
+using ONNX_NAMESPACE::AttributeProto;
+using ONNX_NAMESPACE::OpSchema;
+#ifndef NDEBUG
+using ONNX_NAMESPACE::DbgOperatorSetTracker;
+#endif
+
+constexpr const char* GroupNorm_ver1_doc = R"DOC(
+Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494).
+
+This operator transforms input according to
+  y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
+
+The input channels are separated into num_groups groups, each containing num_channels / num_groups channels.
+num_channels must be divisible by num_groups. The mean and standard-deviation are calculated separately over each group.
+The weight and bias are per-channel affine transform parameter vectors of size num_channels.
+
+The swish attribute can be used to enable Swish activation after group normalization.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+    GroupNorm, 1,
+    OpSchema()
+        .SetDoc(GroupNorm_ver1_doc)
+        .Attr("epsilon", "The epsilon value to use to avoid division by zero", AttributeProto::FLOAT, static_cast<float>(1e-5))
+        .Attr("groups",
+              "The number of groups of channels. It should be a divisor of the number of channels C",
+              AttributeProto::INT)
+        .Attr("swish",
+              "Whether to use Swish activation after group normalization",
+              AttributeProto::INT)
+        .Input(0,
+               "X",
+               "Input data tensor. 
Dimensions are (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and width of the data",
+               "T")
+        .Input(1,
+               "gamma",
+               "1D gamma tensor for normalization with shape (C), where C is number of channels",
+               "M")
+        .Input(2,
+               "beta",
+               "1D beta tensor for normalization with shape (C), where C is number of channels",
+               "M")
+        .Output(0,
+                "Y",
+                "The output tensor of the same shape as X",
+                "T")
+        .TypeConstraint("T", {"tensor(float16)"}, "Constrain input X and output Y types to half tensors.")
+        .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.")
+        .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
+
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 1f0af31a4bdd0..d82da5e65db71 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -69,6 +69,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupNorm);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite);
@@ -155,6 +156,7 @@ class OpSet_Microsoft_ver1 {
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample)>());
+      fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupNorm)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft)>());
       fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite)>());

From 5f40aa8cc86daae4b6e4357f6905c73dfc58d8f9 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Thu, 26 Jan 2023 22:59:23 +0000
Subject: [PATCH 04/27] Add Cast for fp16 group_norm

---
 .../tools/transformers/fusion_group_norm.py   | 40 +++++++++++++------
 .../python/tools/transformers/fusion_utils.py | 21 ++++++++--
 .../stable_diffusion/optimize_pipeline.py     |  2 +-
 3 files changed, 47 insertions(+), 16 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py
index 24dc518e6e894..ea94baa690e34 100644
--- a/onnxruntime/python/tools/transformers/fusion_group_norm.py
+++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py
@@ -7,6 +7,7 @@

 import numpy as np
 from fusion_base import Fusion
+from fusion_utils import FusionUtils
 from onnx import TensorProto, helper
 from onnx_model import OnnxModel

@@ -23,17 +24,12 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         The following is the pattern with swish activation:
                  +----------------Shape-------------------------------+
                  |                                                    |
-                 | (0, 32, -1)                                        v (512x1x1)   (512x1x1)
+                 | (0, 32, -1)                                        v (512x1x1)   (512x1x1)  (optional)
         [Root] --> Reshape -------> InstanceNormalization --> Reshape ---> Mul --> Add --> Mul--> [output]
            Bx512xHxW    (scale=ones(32), B=zeros(32))                       |       ^    Bx512xHxW
                                                                             |       |
-                                                                            +--->Sigmoid
-        The following is the pattern without swish activation:
-                 +----------------Shape-------------------------------+
-                 |                                                    |
-                 | (0, 32, -1)                                        v (512x1x1)   (512x1x1)
-        [Root] --> Reshape -------> InstanceNormalization --> Reshape ---> Mul --> Add -->[output]
-           Bx512xHxW    (scale=ones(32), B=zeros(32))                                    Bx512xHxW
+                                                                            +--->Sigmoid  (optional)
+        The Mul and Sigmoid before output are for Swish activation. They are optional.
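+        After fusion, the matched subgraph is replaced by a single com.microsoft
+        GroupNorm node; for float32 graphs, Cast nodes are added around it since
+        the kernel currently only supports float16.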
""" nodes = self.model.match_parent_path( add_node, ["Mul", "Reshape", "InstanceNormalization", "Reshape"], [0, 0, 0, 0], output_name_to_node @@ -93,6 +89,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): and instance_norm_scale.shape == instance_norm_bias.shape and instance_norm_scale.shape[0] == 32 ): + logger.info(f"InstanceNormalization groups={instance_norm_scale.shape[0]}") return if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): @@ -102,6 +99,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") + logger.info(f"GroupNorm channels={weight_elements}") + gamma = helper.make_tensor( name=group_norm_name + "_gamma", data_type=TensorProto.FLOAT, @@ -120,7 +119,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): last_node = add_node subgraph_nodes = [add_node, weight_mul, reshape_4d, instance_norm, reshape_3d, shape_node] - has_swish = swish_mul and swish_sigmoid + has_swish_activation = swish_mul and swish_sigmoid if swish_mul and swish_sigmoid: subgraph_nodes.extend([swish_mul, swish_sigmoid]) last_node = swish_mul @@ -138,15 +137,32 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): # instance_norm_scale might from Constant node. Use prune graph to clear it. self.prune_graph = True + # Right now GroupNorm only support float16 input. Need add a Cast in fp32 model. + utils = FusionUtils(self.model) + + input = root + output = last_node.output[0] + if weight.dtype == np.float32: + # Add a Cast node to get float16 input for GroupNorm + cast_input, _cast_node = utils.cast_input(root, "float16") + input = cast_input + + # Add a Cast node to convert back to float32 after GroupNorm + output = group_norm_name + "_out" + cast_node = helper.make_node("Cast", inputs=[group_norm_name + "_out"], outputs=[last_node.output[0]]) + cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.FLOAT))]) + self.model.add_node(cast_node) + new_node = helper.make_node( "GroupNorm", - inputs=[root, group_norm_name + "_gamma", group_norm_name + "_beta"], - outputs=[last_node.output[0]], + inputs=[input, group_norm_name + "_gamma", group_norm_name + "_beta"], + outputs=[output], + name=group_norm_name, ) new_node.attribute.extend(instance_norm.attribute) new_node.attribute.extend([helper.make_attribute("groups", 32)]) - new_node.attribute.extend([helper.make_attribute("swish", 1 if has_swish else 0)]) + new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)]) new_node.domain = "com.microsoft" self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index 865c1542c1cc9..b5e390d835b18 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -28,8 +28,8 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]: logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}") return False, input_name - def cast_input_to_int32(self, input_name: str): - cast_output = input_name + "_int32" + def cast_input(self, input_name: str, target_type = "int32"): + cast_output = input_name + "_" + target_type # Avoid consequent Cast nodes. 
        inputs = [input_name]

@@ -40,11 +40,26 @@ def cast_input_to_int32(self, input_name: str):
             inputs = [parent_node.input[0]]

         cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output])
-        cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.INT32))])
+
+        to_type = -1
+        if target_type == "int32":
+            to_type = int(TensorProto.INT32)
+        elif target_type == "float32":
+            to_type = int(TensorProto.FLOAT)
+        elif target_type == "float16":
+            to_type = int(TensorProto.FLOAT16)
+        else:
+            raise ValueError(f"Invalid target_type: {target_type}")
+
+        cast_node.attribute.extend([helper.make_attribute("to", to_type)])
         self.model.add_node(cast_node)

         return cast_output, cast_node

+    def cast_input_to_int32(self, input_name: str):
+        return self.cast_input(input_name, "int32")
+
     def remove_cast_int32(self, input_name: str):
         input_name_to_nodes = self.model.input_name_to_nodes()
         nodes = input_name_to_nodes[input_name]
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
index e27cf5f8f1391..77de06c509e12 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -91,7 +91,7 @@ def optimize_stable_diffusion_onnx_pipeline(
         # TODO: enable mixed precision conversion for VAE-decoder.
         if name != "vae_decoder":
             logger.info(f"convert to float16 ...")
-            m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize"])
+            m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize", "GroupNorm"])
         else:
             logger.info("skip convert vae_decoder to fp16.")

From a4c4302d6b16ee09eb919a400156f237bdaa3680 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 27 Jan 2023 22:48:35 +0000
Subject: [PATCH 05/27] Add SplitGelu fusion

---
 .../python/tools/symbolic_shape_infer.py      | 11 +++
 .../tools/transformers/fusion_group_norm.py   |  6 +-
 .../tools/transformers/fusion_options.py      |  1 +
 .../tools/transformers/fusion_splitgelu.py    | 99 +++++++++++++++++++
 .../tools/transformers/onnx_model_unet.py     |  5 +
 5 files changed, 120 insertions(+), 2 deletions(-)
 create mode 100644 onnxruntime/python/tools/transformers/fusion_splitgelu.py

diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index ee8b7923ca62d..b0c07b762e092 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -201,6 +201,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "SkipLayerNormalization": self._infer_SkipLayerNormalization,
             "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
             "GroupNorm": self._infer_GroupNorm,
+            "SplitGelu": self._infer_SplitGelu,
         }
         self.aten_op_dispatcher_ = {
             "embedding": self._infer_Gather,
@@ -436,6 +437,7 @@ def _onnx_infer_single_node(self, node):
             "PythonOp",
             "MultiHeadAttention",
             "GroupNorm",
+            "SplitGelu",
         ]

         if not skip_infer:
@@ -2061,6 +2063,15 @@ def _infer_SkipLayerNormalization(self, node):
     def _infer_GroupNorm(self, node):
         self._propagate_shape_and_type(node)

+    def _infer_SplitGelu(self, node):
+        input_shape = self._get_shape(node, 0)
+        if input_shape:
+            output_shape = input_shape
+            output_shape[2] = int(input_shape[2] / 2)
+            vi = self.known_vi_[node.output[0]]
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
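+            # SplitGelu maps [B, S, 2*D] to [B, S, D]: the input is split in half
+            # along axis 2, Gelu is applied to the second half, and the result
+            # gates the first half by elementwise multiplication.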
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))
+
     def _infer_PythonOp(self, node):
         output_tensor_types = get_attribute(node, "output_tensor_types")
         assert output_tensor_types
diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py
index ea94baa690e34..6f0874bc1c450 100644
--- a/onnxruntime/python/tools/transformers/fusion_group_norm.py
+++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py
@@ -99,7 +99,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):

         group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")

-        logger.info(f"GroupNorm channels={weight_elements}")
+        if weight_elements not in [320, 640, 960, 1280, 1920, 2560] + [128, 256, 512]:
+            logger.info(f"GroupNorm channels={weight_elements}")

         gamma = helper.make_tensor(
@@ -130,7 +131,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
                 input_name_to_nodes,
                 output_name_to_node,
             ):
-            self.node_to_remove.extend([last_node])
+            self.nodes_to_remove.extend([last_node])
         else:
             self.nodes_to_remove.extend(subgraph_nodes)

@@ -166,3 +167,4 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         new_node.domain = "com.microsoft"
         self.nodes_to_add.append(new_node)
         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+        return True
diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py
index f3104d86130c8..17306762c118e 100644
--- a/onnxruntime/python/tools/transformers/fusion_options.py
+++ b/onnxruntime/python/tools/transformers/fusion_options.py
@@ -37,6 +37,7 @@ def __init__(self, model_type):
         self.enable_shape_inference = True
         self.enable_gemm_fast_gelu = False
         self.enable_group_norm = model_type == "unet"
+        self.enable_splitgelu = model_type == "unet"
         self.attention_mask_format = AttentionMaskFormat.AttentionMask

     def use_raw_attention_mask(self, use_raw_mask=True):
diff --git a/onnxruntime/python/tools/transformers/fusion_splitgelu.py b/onnxruntime/python/tools/transformers/fusion_splitgelu.py
new file mode 100644
index 0000000000000..6e2e086717aed
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/fusion_splitgelu.py
@@ -0,0 +1,99 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
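+#
+# FusionSplitGelu below replaces the GEGLU subgraph x[..., :d] * Gelu(x[..., d:])
+# (two Slice nodes halving the last axis, one of them feeding Gelu) with a
+# single com.microsoft SplitGelu node; see the pattern diagram in fuse().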
+# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Dict, Optional + +from fusion_base import Fusion +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionSplitGelu(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "SplitGelu", "Gelu") + + def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + """ + [root] --------------------> Slice ---------------> Mul --> + | ^ ^ + | | | + +----------------------------+---Slice --> Gelu---+ + | | ^ + | |-----| + | | | + | Mul Mul + | ^ ^ + v | | + Shape ---> Gather --> Add --> Div --+ + """ + if gelu_node.output[0] not in input_name_to_nodes: + return + children = input_name_to_nodes[gelu_node.output[0]] + if len(children) != 1 or children[0].op_type != "Mul": + return + mul_after_gelu = children[0] + + slice_before_gelu = self.model.match_parent(gelu_node, "Slice", 0, output_name_to_node) + if slice_before_gelu is None: + return + + if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3: + return + + subgraph_input = slice_before_gelu.input[0] + + start_index_nodes = self.model.match_parent_path( + slice_before_gelu, + ["Div", "Add", "Gather", "Shape"], + [1, 0, 0, 0], + output_name_to_node, # Mul(1) is optional + ) + if start_index_nodes is None: + start_index_nodes = self.model.match_parent_path( + slice_before_gelu, ["Mul", "Div", "Add", "Gather", "Shape"], [1, 0, 0, 0, 0], output_name_to_node + ) + + if start_index_nodes is None or start_index_nodes[-1].input[0] != subgraph_input: + return + + end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node) + + if ( + end_index_nodes is None or end_index_nodes[1] not in start_index_nodes + ): # the Div is parent of both two Mul nodes + return + + slice_before_mul = self.model.match_parent(mul_after_gelu, "Slice", 0, output_name_to_node) + if slice_before_mul is None: + return + + if ( + slice_before_mul.input[2] != slice_before_gelu.input[1] + ): # end index of slice_before_mul is start index of slice_before_gelu + return + + subgraph_nodes = start_index_nodes + [ + end_index_nodes[0], + mul_after_gelu, + gelu_node, + slice_before_mul, + slice_before_gelu, + ] + subgraph_output = mul_after_gelu.output[0] + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node + ): + logger.info("Skip fuse SplitGelu since it is not safe to fuse the subgraph.") + return + + self.nodes_to_remove.extend(subgraph_nodes) + node_name = self.model.create_node_name("SplitGelu", name_prefix="SplitGelu") + fused_node = helper.make_node("SplitGelu", inputs=[subgraph_input], outputs=[subgraph_output], name=node_name) + fused_node.domain = "com.microsoft" + self.nodes_to_add.append(fused_node) + self.node_name_to_graph_name[node_name] = self.this_graph_name + return True diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index b59db5087df31..19657b4345255 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -9,6 +9,7 @@ from fusion_attention_unet import FusionAttentionUnet from fusion_group_norm import FusionGroupNorm from fusion_options import FusionOptions +from fusion_splitgelu import FusionSplitGelu from onnx import ModelProto from onnx_model_bert import 
BertOnnxModel
@@ -57,6 +58,10 @@ def optimize(self, options: Optional[FusionOptions] = None):
             group_norm_fusion = FusionGroupNorm(self)
             group_norm_fusion.apply()

+        if (options is None) or options.enable_splitgelu:
+            split_gelu_fusion = FusionSplitGelu(self)
+            split_gelu_fusion.apply()
+
         if (options is None) or options.enable_attention:
             self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False)
             self_attention_fusion.apply()

From f722c5aa581e976d4d4362956b4992339984b7ce Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 27 Jan 2023 22:56:05 +0000
Subject: [PATCH 06/27] support float type in GroupNorm

---
 .../contrib_ops/cuda/cuda_contrib_kernels.cc  |   2 +
 .../contrib_ops/cuda/diffusion/group_norm.cc  |  27 +-
 .../contrib_ops/cuda/diffusion/group_norm.h   |   2 +-
 .../cuda/diffusion/group_norm_impl.cu         | 208 +++++----
 .../cuda/diffusion/group_norm_impl.h          |   2 +-
 .../core/graph/contrib_ops/diffusion_defs.cc  |  11 +-
 .../test/contrib_ops/group_norm_op_test.cc    | 433 ++++++++++++++++++
 7 files changed, 589 insertions(+), 96 deletions(-)
 create mode 100644 onnxruntime/test/contrib_ops/group_norm_op_test.cc

diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index e054b44f0d53f..ad78e1bd8e960 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -72,6 +72,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupNorm);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GroupNorm);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler);
@@ -194,6 +195,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupNorm)>,
+     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GroupNorm)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler)>,
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
index 56bf68285bb65..594d78907a769 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -21,21 +21,24 @@ namespace cuda {
       GroupNorm);

 REGISTER_KERNEL_TYPED(MLFloat16);
+REGISTER_KERNEL_TYPED(float);

 using namespace ONNX_NAMESPACE;

 template <typename T>
-GroupNorm<T>::GroupNorm(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
-  ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK());
+GroupNorm<T>::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
+  epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
   ORT_ENFORCE(epsilon_ >= 0);

   int64_t num_groups;
-  ORT_ENFORCE(op_kernel_info.GetAttr("groups", &num_groups).IsOK());
+  ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK());
   ORT_ENFORCE(num_groups >= 0);
   num_groups_ = static_cast<int>(num_groups);
-
-  ORT_ENFORCE(op_kernel_info.GetAttr("swish", &swish_).IsOK());
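+  // The "activation" attribute replaces the earlier boolean "swish" attribute:
+  // 0 means no activation, 1 means Swish (x * sigmoid(x)) after normalization.
+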
int64_t activation; + ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); + ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish + use_swish_activation_ = (activation == 1); } template @@ -56,9 +59,9 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[2]) { + if (gamma_dims[0] != input_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); + "Number of channels in gamma and input does not match"); } const auto& beta_dims = beta->Shape().GetDims(); @@ -66,9 +69,9 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[2]) { + if (beta_dims[0] != input_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); + "Number of channels in beta and input does not match"); } int batch_size = static_cast(input_dims[0]); @@ -89,8 +92,8 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { Stream(context), reinterpret_cast(output->MutableData()), reinterpret_cast(input->Data()), - reinterpret_cast(gamma->Data()), - reinterpret_cast(beta->Data()), + gamma->Data(), + beta->Data(), reinterpret_cast(workspace.get()), epsilon_, batch_size, @@ -98,7 +101,7 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { height, width, num_groups_, - swish_); + use_swish_activation_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 85341a60a5e87..099c083084527 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -18,7 +18,7 @@ class GroupNorm final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; private: - bool swish_; + bool use_swish_activation_; float epsilon_; int num_groups_; }; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index dc3b631c58585..7108d6e03d60f 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -39,7 +39,7 @@ struct GroupSums { int32_t flag; // The sum. float sum; - // The sum of squares. + // The sum of squares. float sumSq; }; @@ -53,11 +53,12 @@ struct GroupSumsOp { } }; +template struct GroupNormNHWCParams { // The output buffer. Layout NHWC. - __half* dst; + T* dst; // The input buffer. Layout NHWC. - __half const* src; + T const* src; // The gamma scaling factor. float const* gamma; // The beta term to add in GN. @@ -97,8 +98,37 @@ struct GroupNormNHWCParams { int32_t groupsPerBlock; }; -template -__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); + +template <> +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + float2 f2 = __half22float2(h2); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. 
+ sumSq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sumSq += f2.x * f2.x + f2.y * f2.y; +} + +template +__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -123,23 +153,12 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { float sumSq = 0.F; // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; - - // Fetch two channels per thread. - __half2 h2(0, 0); - if (ci < params.c) { - h2 = *reinterpret_cast<__half2 const*>(¶ms.src[offset]); + if (ci < params.c) { + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The offset. + int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; + UpdateSum(params.src, offset, sum, sumSq); } - - // Extract the two half values. - float2 f2 = __half22float2(h2); - - // Update the sum. - sum += f2.x + f2.y; - // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; } // The group that thread works on and the channel in the group (modulus). @@ -155,7 +174,7 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { // Store the results for the groups in shared memory (to produce coalesced // stores later). - if (cj == params.cPerGroup - 2 /* 2 channels per thread */) { + if (cj == params.cPerGroup - 2) { //2 channels per thread smem[gi] = make_float2(out.sum, out.sumSq); } @@ -178,7 +197,8 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); } -void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { +template +void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { // Make sure the values are as we expect. ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); // Make sure a group does not span multiple blocks. @@ -195,28 +215,85 @@ void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { switch (params.cPerBlock) { case 320: - groupNormNHWCSumKernel<160><<>>(params); + groupNormNHWCSumKernel<<>>(params); break; case 480: - groupNormNHWCSumKernel<256><<>>(params); + groupNormNHWCSumKernel<<>>(params); break; case 256: - groupNormNHWCSumKernel<128><<>>(params); + groupNormNHWCSumKernel<<>>(params); break; case 128: - groupNormNHWCSumKernel<64><<>>(params); + groupNormNHWCSumKernel<<>>(params); break; default: ORT_NOT_IMPLEMENTED("Not implemented"); } } -template -__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The instance in the batch. - int32_t ni = blockIdx.z; +template +__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); + +template <> +__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, + float2& gammaF2, float2& betaF2, bool swish) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + // Extract the two half values. 
+ float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Swish if needed. + if (swish) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); +} + +template <> +__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, + float2& gammaF2, float2& betaF2, bool swish) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Normalize the channels. + f2.x = (f2.x - mean) * invStdDev; + f2.y = (f2.y - mean) * invStdDev; + + // Scale by gamma and add beta. + f2.x = gammaF2.x * f2.x + betaF2.x; + f2.y = gammaF2.y * f2.y + betaF2.y; + + // Apply Swish if needed. + if (swish) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast(&dst[offset]) = f2; +} + +template +__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { // The channel loaded by that thread (2 channels per thread for F16x2). int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + if (ci >= params.c) { + return; + } + + // The instance in the batch. + int32_t ni = blockIdx.z; + // The group that thread works on and the channel in the group (modulus). int32_t gi = ci / params.cPerGroup; @@ -228,11 +305,8 @@ __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { } // Load gamma/beta. - float2 gammaF2, betaF2; - if (ci < params.c) { - gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); - betaF2 = *reinterpret_cast(¶ms.beta[ci]); - } + float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); // Compute the mean. float mean = sum * params.invHWC; @@ -252,36 +326,12 @@ __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; // Fetch two channels per thread. - __half2 h2(0, 0); - if (ci < params.c) { - h2 = *reinterpret_cast<__half2 const*>(¶ms.src[offset]); - } - - // Extract the two half values. - float2 f2 = __half22float2(h2); - - // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; - - // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; - - // Apply Swish if needed. - if (params.withSwish) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - // Store the scaled values. - if (ci < params.c) { - *reinterpret_cast<__half2*>(¶ms.dst[offset]) = __float22half2_rn(f2); - } + computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); } } -void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { +template +void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { // Make sure the dimensions are aligned with what we expect. ORT_ENFORCE(params.c % params.cPerBlock == 0); // Make sure a group does not span multiple blocks. 
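
  // Two-pass scheme shared by both kernels: groupNormNHWCSumKernel first
  // accumulates per-(batch, group) sums and sums of squares (cub::BlockScan
  // within a block, atomicAdd into params.redBuffer across blocks), then this
  // scale kernel derives the mean and inverse standard deviation from those
  // totals and applies gamma, beta and the optional Swish.
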
@@ -298,16 +348,16 @@ void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) switch (params.cPerBlock) { case 320: - groupNormNHWCScaleKernel<160><<>>(params); + groupNormNHWCScaleKernel<<>>(params); break; case 480: - groupNormNHWCScaleKernel<256><<>>(params); + groupNormNHWCScaleKernel<<>>(params); break; case 256: - groupNormNHWCScaleKernel<128><<>>(params); + groupNormNHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<64><<>>(params); + groupNormNHWCScaleKernel<<>>(params); break; default: ORT_NOT_IMPLEMENTED("Not implemented"); @@ -346,7 +396,7 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool swish) { + bool use_swish_activation) { if (batch_size > static_cast(kMaxGroupNormBatchSize)) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "only support batch_size <= 32. Got", batch_size); @@ -357,7 +407,7 @@ Status LaunchGroupNormKernel( "only num_groups=32 is supported. Got", num_groups); } - GroupNormNHWCParams params; + GroupNormNHWCParams params; int32_t cPerBlock = 320; int32_t maxBlocksPerHW = 1024; switch (num_channels) { @@ -376,12 +426,12 @@ Status LaunchGroupNormKernel( cPerBlock = 320; } - params.withSwish = bool(swish); - params.dst = static_cast(output); - params.src = static_cast(input); - params.gamma = static_cast(gamma); - params.beta = static_cast(beta); - params.redBuffer = static_cast(workspace); + params.withSwish = use_swish_activation; + params.dst = output; + params.src = input; + params.gamma = gamma; + params.beta = beta; + params.redBuffer = reinterpret_cast(workspace); params.n = batch_size; params.h = height; params.w = width; @@ -397,10 +447,10 @@ Status LaunchGroupNormKernel( params.groupsPerBlock = cPerBlock / params.cPerGroup; cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); - groupNormNHWCSum(params, stream); + groupNormNHWCSum(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - groupNormNHWCScale(params, stream); + groupNormNHWCScale(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); @@ -411,6 +461,10 @@ template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, bool swish); +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, + const float* input, const float* gamma, const float* beta, void* workspace, + float epsilon, int batch_size, int num_channels, + int height, int width, int num_groups, bool swish); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index 5595d9ffb5082..347c4624ac9f1 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -35,7 +35,7 @@ Status LaunchGroupNormKernel( int height, // H int width, // W int num_groups, // number of groups - bool swish // Whether there is Swish activation after group normalization + bool use_swish_activation // Whether there is Swish activation after group normalization ); } // namespace cuda diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index 9a53c59e9043e..60f7538ef7056 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -29,7 +29,7 @@ This 
operator transforms input according to

 The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard deviation are calculated separately over each group.
 The weight and bias are per-channel affine transform parameter vectors of size num_channels.

-The swish attribute can be used to enable Swish activation after group normalization.
+The activation attribute can be used to enable an activation function after group normalization.
 )DOC";

 ONNX_MS_OPERATOR_SET_SCHEMA(
@@ -40,8 +40,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
         .Attr("groups",
               "The number of groups of channels. It should be a divisor of the number of channels C",
               AttributeProto::INT)
-        .Attr("swish",
-              "Whether to use Swish activation after group normalization",
+        .Attr("activation",
+              "Activation after group normalization: 0 for None, 1 for Swish",
               AttributeProto::INT)
         .Input(0,
                "X",
@@ -50,16 +50,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
         .Input(1,
                "gamma",
                "1D gamma tensor for normalization with shape (C), where C is number of channels",
-               "T")
+               "M")
         .Input(2,
                "beta",
                "1D beta tensor for normalization with shape (C), where C is number of channels",
-               "T")
+               "M")
         .Output(0,
                "Y",
                "The output tensor of the same shape as X",
                "T")
         .TypeConstraint("T", {"tensor(float16)"}, "Constrain input X and output Y types to half tensors.")
+        //.TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.")
         .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.")
         .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));

diff --git a/onnxruntime/test/contrib_ops/group_norm_op_test.cc b/onnxruntime/test/contrib_ops/group_norm_op_test.cc
new file mode 100644
index 0000000000000..c3b43f708ccfa
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/group_norm_op_test.cc
@@ -0,0 +1,433 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
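+//
+// GroupNorm_128 below feeds a 2x128x2x2 input through the GroupNorm kernel;
+// norm_data holds the expected output without Swish and swish_data the
+// expected output with Swish applied.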
+ +#include +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/provider_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace std; + +namespace onnxruntime { +namespace test { + +TEST(GroupNormTest, GroupNorm_128) { + constexpr int64_t B = 2; + constexpr int64_t C = 128; + constexpr int64_t H = 2; + constexpr int64_t W = 2; + + std::vector dims{B, C, H, W}; + std::vector input_data = { + 0.696469f, 0.286139f, 0.226851f, 0.551315f, 0.719469f, 0.423106f, 0.980764f, 0.684830f, 0.480932f, 0.392118f, + 0.343178f, 0.729050f, 0.438572f, 0.059678f, 0.398044f, 0.737995f, 0.182492f, 0.175452f, 0.531551f, 0.531828f, + 0.634401f, 0.849432f, 0.724455f, 0.611024f, 0.722443f, 0.322959f, 0.361789f, 0.228263f, 0.293714f, 0.630976f, + 0.092105f, 0.433701f, 0.430863f, 0.493685f, 0.425830f, 0.312261f, 0.426351f, 0.893389f, 0.944160f, 0.501837f, + 0.623953f, 0.115618f, 0.317285f, 0.414826f, 0.866309f, 0.250455f, 0.483034f, 0.985560f, 0.519485f, 0.612895f, + 0.120629f, 0.826341f, 0.603060f, 0.545068f, 0.342764f, 0.304121f, 0.417022f, 0.681301f, 0.875457f, 0.510422f, + 0.669314f, 0.585937f, 0.624904f, 0.674689f, 0.842342f, 0.083195f, 0.763683f, 0.243666f, 0.194223f, 0.572457f, + 0.095713f, 0.885327f, 0.627249f, 0.723416f, 0.016129f, 0.594432f, 0.556785f, 0.158960f, 0.153071f, 0.695530f, + 0.318766f, 0.691970f, 0.554383f, 0.388951f, 0.925132f, 0.841670f, 0.357398f, 0.043591f, 0.304768f, 0.398186f, + 0.704959f, 0.995358f, 0.355915f, 0.762548f, 0.593177f, 0.691702f, 0.151127f, 0.398876f, 0.240856f, 0.343456f, + 0.513128f, 0.666625f, 0.105908f, 0.130895f, 0.321981f, 0.661564f, 0.846506f, 0.553257f, 0.854452f, 0.384838f, + 0.316788f, 0.354265f, 0.171082f, 0.829113f, 0.338671f, 0.552370f, 0.578551f, 0.521533f, 0.002688f, 0.988345f, + 0.905342f, 0.207636f, 0.292489f, 0.520010f, 0.901911f, 0.983631f, 0.257542f, 0.564359f, 0.806969f, 0.394370f, + 0.731073f, 0.161069f, 0.600699f, 0.865864f, 0.983522f, 0.079366f, 0.428347f, 0.204543f, 0.450636f, 0.547764f, + 0.093327f, 0.296861f, 0.927584f, 0.569004f, 0.457412f, 0.753526f, 0.741862f, 0.048579f, 0.708697f, 0.839243f, + 0.165938f, 0.780998f, 0.286537f, 0.306470f, 0.665261f, 0.111392f, 0.664872f, 0.887857f, 0.696311f, 0.440328f, + 0.438214f, 0.765096f, 0.565642f, 0.084904f, 0.582671f, 0.814844f, 0.337066f, 0.927577f, 0.750717f, 0.574064f, + 0.751644f, 0.079149f, 0.859389f, 0.821504f, 0.909872f, 0.128631f, 0.081780f, 0.138416f, 0.399379f, 0.424307f, + 0.562218f, 0.122244f, 0.201400f, 0.811644f, 0.467988f, 0.807938f, 0.007426f, 0.551593f, 0.931932f, 0.582175f, + 0.206096f, 0.717758f, 0.378986f, 0.668384f, 0.029320f, 0.635900f, 0.032198f, 0.744781f, 0.472913f, 0.121754f, + 0.542636f, 0.066774f, 0.653365f, 0.996086f, 0.769397f, 0.573774f, 0.102635f, 0.699834f, 0.661168f, 0.049097f, + 0.792299f, 0.518717f, 0.425868f, 0.788187f, 0.411569f, 0.481026f, 0.181629f, 0.321319f, 0.845533f, 0.186904f, + 0.417291f, 0.989035f, 0.236600f, 0.916832f, 0.918397f, 0.091296f, 0.463653f, 0.502216f, 0.313669f, 0.047340f, + 0.241686f, 0.095530f, 0.238250f, 0.807791f, 0.894978f, 0.043223f, 0.301947f, 0.980582f, 0.539505f, 0.626309f, + 0.005545f, 0.484909f, 0.988329f, 0.375186f, 0.097038f, 0.461909f, 0.963004f, 0.341831f, 0.798923f, 0.798846f, + 0.208248f, 0.443368f, 0.715601f, 0.410520f, 0.191007f, 0.967494f, 0.650750f, 0.865460f, 0.025242f, 0.266906f, + 0.502071f, 0.067449f, 0.993033f, 0.236462f, 0.374292f, 0.214012f, 0.105446f, 0.232480f, 0.300610f, 0.634442f, + 
0.281235f, 0.362277f, 0.005943f, 0.365719f, 0.533886f, 0.162016f, 0.597433f, 0.293152f, 0.632050f, 0.026197f, + 0.887593f, 0.016119f, 0.126958f, 0.777162f, 0.045895f, 0.710999f, 0.971046f, 0.871683f, 0.710162f, 0.958510f, + 0.429813f, 0.872879f, 0.355958f, 0.929764f, 0.148778f, 0.940029f, 0.832716f, 0.846055f, 0.123923f, 0.596487f, + 0.016392f, 0.721184f, 0.007738f, 0.084822f, 0.225498f, 0.875125f, 0.363576f, 0.539960f, 0.568103f, 0.225463f, + 0.572147f, 0.660952f, 0.298245f, 0.418627f, 0.453089f, 0.932351f, 0.587494f, 0.948252f, 0.556035f, 0.500561f, + 0.003532f, 0.480889f, 0.927455f, 0.198366f, 0.052091f, 0.406779f, 0.372396f, 0.857153f, 0.026611f, 0.920149f, + 0.680903f, 0.904226f, 0.607529f, 0.811953f, 0.335544f, 0.349566f, 0.389874f, 0.754797f, 0.369291f, 0.242220f, + 0.937668f, 0.908011f, 0.348797f, 0.634638f, 0.273842f, 0.206115f, 0.336340f, 0.327100f, 0.882276f, 0.822304f, + 0.709623f, 0.959345f, 0.422543f, 0.245033f, 0.117398f, 0.301053f, 0.145264f, 0.092186f, 0.602932f, 0.364187f, + 0.564570f, 0.191336f, 0.676906f, 0.215505f, 0.278024f, 0.741760f, 0.559738f, 0.334836f, 0.542989f, 0.693985f, + 0.912132f, 0.580713f, 0.232686f, 0.746698f, 0.777769f, 0.200401f, 0.820574f, 0.464935f, 0.779767f, 0.237478f, + 0.332580f, 0.953697f, 0.657815f, 0.772878f, 0.688374f, 0.204304f, 0.470689f, 0.808964f, 0.675035f, 0.006028f, + 0.087408f, 0.346795f, 0.944366f, 0.491190f, 0.270176f, 0.360424f, 0.210653f, 0.421200f, 0.218035f, 0.845753f, + 0.456271f, 0.279802f, 0.932892f, 0.314351f, 0.909715f, 0.043418f, 0.707115f, 0.483889f, 0.444221f, 0.036323f, + 0.040683f, 0.332754f, 0.947120f, 0.617660f, 0.368875f, 0.611977f, 0.206132f, 0.165066f, 0.361817f, 0.863353f, + 0.509402f, 0.296902f, 0.950252f, 0.815966f, 0.322974f, 0.972098f, 0.987351f, 0.408660f, 0.655923f, 0.405653f, + 0.257348f, 0.082653f, 0.263610f, 0.271480f, 0.398639f, 0.184886f, 0.953818f, 0.102880f, 0.625209f, 0.441697f, + 0.423518f, 0.371992f, 0.868315f, 0.280477f, 0.020576f, 0.918097f, 0.864480f, 0.276902f, 0.523488f, 0.109088f, + 0.093427f, 0.837466f, 0.410266f, 0.661717f, 0.943201f, 0.245131f, 0.013160f, 0.024148f, 0.709386f, 0.924552f, + 0.467330f, 0.375109f, 0.542860f, 0.858917f, 0.652154f, 0.232980f, 0.774580f, 0.134613f, 0.165560f, 0.612682f, + 0.238783f, 0.704779f, 0.349519f, 0.277424f, 0.998918f, 0.040616f, 0.645823f, 0.038700f, 0.760210f, 0.230090f, + 0.089832f, 0.648450f, 0.732601f, 0.678095f, 0.051901f, 0.294307f, 0.451088f, 0.287103f, 0.810513f, 0.131115f, + 0.612179f, 0.988215f, 0.902557f, 0.222157f, 0.000082f, 0.980597f, 0.882713f, 0.919472f, 0.415504f, 0.744615f, + 0.212831f, 0.392304f, 0.851548f, 0.127612f, 0.893865f, 0.496508f, 0.426096f, 0.305646f, 0.916849f, 0.517623f, + 0.804026f, 0.857652f, 0.922382f, 0.303381f, 0.339811f, 0.595074f, 0.441324f, 0.932843f, 0.397564f, 0.477778f, + 0.617186f, 0.404739f, 0.992478f, 0.098851f, 0.220603f, 0.322655f, 0.147723f, 0.284219f, 0.779245f, 0.522892f, + 0.033954f, 0.982623f, 0.616006f, 0.058939f, 0.661169f, 0.378369f, 0.135673f, 0.563665f, 0.727080f, 0.671127f, + 0.247513f, 0.524866f, 0.537663f, 0.716803f, 0.359867f, 0.797733f, 0.627922f, 0.038332f, 0.546479f, 0.861912f, + 0.567574f, 0.175828f, 0.510376f, 0.756946f, 0.110105f, 0.817099f, 0.167482f, 0.534076f, 0.385743f, 0.248624f, + 0.647433f, 0.037392f, 0.760046f, 0.526941f, 0.875771f, 0.520718f, 0.035033f, 0.143601f, 0.795605f, 0.491976f, + 0.441879f, 0.318435f, 0.284549f, 0.965886f, 0.432969f, 0.884003f, 0.648163f, 0.858428f, 0.852450f, 0.956312f, + 0.697942f, 0.805397f, 0.733128f, 0.605227f, 0.717354f, 0.715750f, 0.040908f, 
0.516111f, 0.792651f, 0.242962f, + 0.465148f, 0.434986f, 0.402787f, 0.121840f, 0.525712f, 0.446248f, 0.663393f, 0.549413f, 0.027543f, 0.031918f, + 0.701360f, 0.707581f, 0.959939f, 0.876705f, 0.468060f, 0.625907f, 0.457182f, 0.222946f, 0.376677f, 0.103884f, + 0.666527f, 0.192030f, 0.475468f, 0.967437f, 0.031669f, 0.151730f, 0.298579f, 0.941807f, 0.908842f, 0.162001f, + 0.981118f, 0.750748f, 0.539977f, 0.931703f, 0.880607f, 0.391316f, 0.656343f, 0.647385f, 0.326968f, 0.179390f, + 0.466810f, 0.263281f, 0.355065f, 0.954144f, 0.461138f, 0.684891f, 0.336230f, 0.995861f, 0.658768f, 0.196009f, + 0.098184f, 0.943181f, 0.944778f, 0.621328f, 0.016991f, 0.225535f, 0.801277f, 0.875460f, 0.453990f, 0.365521f, + 0.274225f, 0.116971f, 0.115745f, 0.952603f, 0.808626f, 0.164779f, 0.207050f, 0.655552f, 0.764664f, 0.810315f, + 0.163338f, 0.984128f, 0.227802f, 0.589415f, 0.587616f, 0.967362f, 0.657667f, 0.584904f, 0.518773f, 0.764658f, + 0.106055f, 0.002092f, 0.952489f, 0.498658f, 0.328335f, 0.368053f, 0.803843f, 0.382370f, 0.770169f, 0.440462f, + 0.844077f, 0.076204f, 0.481128f, 0.466850f, 0.264328f, 0.943615f, 0.905028f, 0.443596f, 0.097160f, 0.206783f, + 0.271492f, 0.484220f, 0.338377f, 0.774136f, 0.476027f, 0.870371f, 0.995782f, 0.219836f, 0.611671f, 0.847502f, + 0.945237f, 0.290086f, 0.727043f, 0.015016f, 0.879142f, 0.063939f, 0.733395f, 0.994610f, 0.501190f, 0.209334f, + 0.594644f, 0.624150f, 0.668073f, 0.172612f, 0.898713f, 0.620991f, 0.043569f, 0.684041f, 0.196084f, 0.027341f, + 0.550953f, 0.813314f, 0.859941f, 0.103521f, 0.663043f, 0.710075f, 0.294517f, 0.971364f, 0.278687f, 0.069982f, + 0.519280f, 0.694315f, 0.244660f, 0.338582f, 0.563628f, 0.886678f, 0.747326f, 0.209592f, 0.251777f, 0.523881f, + 0.768959f, 0.618762f, 0.501324f, 0.597125f, 0.756060f, 0.537080f, 0.897753f, 0.947067f, 0.915355f, 0.754518f, + 0.246321f, 0.385271f, 0.280000f, 0.657660f, 0.324222f, 0.754392f, 0.113509f, 0.775365f, 0.585902f, 0.835389f, + 0.430876f, 0.624964f, 0.554412f, 0.975671f, 0.755474f, 0.544813f, 0.174032f, 0.904114f, 0.205838f, 0.650043f, + 0.936472f, 0.223580f, 0.225924f, 0.851819f, 0.827655f, 0.351703f, 0.265096f, 0.127388f, 0.987936f, 0.835343f, + 0.899392f, 0.513679f, 0.114385f, 0.052580f, 0.330582f, 0.920330f, 0.947582f, 0.841164f, 0.158679f, 0.419923f, + 0.246243f, 0.205350f, 0.684826f, 0.486112f, 0.324910f, 0.100214f, 0.544763f, 0.347025f, 0.391096f, 0.310509f, + 0.387195f, 0.555860f, 0.014144f, 0.847647f, 0.921920f, 0.550530f, 0.268021f, 0.990239f, 0.383194f, 0.693655f, + 0.689953f, 0.434309f, 0.199158f, 0.966579f, 0.063691f, 0.485149f, 0.220731f, 0.293974f, 0.828527f, 0.367266f, + 0.083348f, 0.196309f, 0.860373f, 0.977029f, 0.267982f, 0.675409f, 0.081199f, 0.723466f, 0.416437f, 0.918160f, + 0.311536f, 0.941467f, 0.503247f, 0.348893f, 0.647020f, 0.249746f, 0.229764f, 0.196346f, 0.959900f, 0.492914f, + 0.751615f, 0.473992f, 0.587540f, 0.584139f, 0.979886f, 0.668433f, 0.239769f, 0.015198f, 0.218682f, 0.455520f, + 0.393420f, 0.812326f, 0.785557f, 0.089096f, 0.952011f, 0.527457f, 0.596404f, 0.405057f, 0.649501f, 0.871326f, + 0.673936f, 0.970099f, 0.701122f, 0.821721f, 0.045040f, 0.672699f, 0.654753f, 0.101746f, 0.842387f, 0.614172f, + 0.098328f, 0.594467f, 0.478416f, 0.233294f, 0.019756f, 0.365567f, 0.619851f, 0.329279f, 0.307255f, 0.751121f, + 0.758625f, 0.718766f, 0.101182f, 0.516166f, 0.557799f, 0.744805f, 0.903178f, 0.369039f, 0.428663f, 0.732767f, + 0.662636f, 0.557870f, 0.350140f, 0.195352f, 0.183807f, 0.081583f, 0.081201f, 0.845798f, 0.383673f, 0.060740f, + 0.896426f, 0.223270f, 0.268124f, 0.194498f, 
0.967501f, 0.112540f, 0.722163f, 0.932089f, 0.668001f, 0.858727f, + 0.242447f, 0.673928f, 0.700871f, 0.458333f, 0.870546f, 0.694386f, 0.894878f, 0.753204f, 0.520290f, 0.498688f, + 0.453728f, 0.021647f, 0.535141f, 0.422973f, 0.157534f, 0.119070f, 0.449352f, 0.039913f, 0.986580f, 0.378121f, + 0.382109f, 0.051126f, 0.426672f, 0.015745f, 0.030094f, 0.339099f, 0.820969f, 0.458821f, 0.014841f, 0.163220f, + 0.739923f, 0.738294f, 0.754523f, 0.351669f, 0.352277f, 0.802076f, 0.398138f, 0.727191f, 0.581123f, 0.364342f, + 0.080007f, 0.116125f, 0.889559f, 0.452341f, 0.994005f, 0.363897f, 0.249954f, 0.350539f, 0.343086f, 0.637357f, + 0.012738f, 0.763269f, 0.416415f, 0.432239f, 0.481115f, 0.449212f, 0.497471f, 0.345904f, 0.453346f, 0.404651f, + 0.518243f, 0.623269f, 0.241041f, 0.508437f, 0.594622f, 0.016948f, 0.520494f, 0.239293f, 0.404539f, 0.826530f, + 0.326236f, 0.483217f, 0.024741f, 0.308751f, 0.639721f, 0.315162f, 0.205798f, 0.290656f, 0.954378f, 0.086802f, + 0.463358f, 0.058387f, 0.538658f, 0.146036f, 0.634085f, 0.264397f, 0.690915f, 0.347146f, 0.004168f, 0.294895f, + 0.081894f, 0.495040f, 0.288890f, 0.639992f, 0.499936f, 0.036045f, 0.318634f, 0.489059f, 0.572204f, 0.104871f, + 0.649971f, 0.343698f, 0.182921f, 0.805327f, 0.068623f, 0.929770f, 0.706266f, 0.475591f, 0.011161f, 0.390125f, + 0.645798f, 0.858913f, 0.617764f, 0.397707f}; + + std::vector gamma_data = { + 0.447359f, 0.873295f, 0.351357f, 0.065158f, 0.442673f, 0.998459f, 0.379773f, 0.193055f, 0.045130f, 0.170969f, + 0.324064f, 0.574278f, 0.665588f, 0.042819f, 0.936180f, 0.235638f, 0.149062f, 0.530829f, 0.677586f, 0.307253f, + 0.669441f, 0.294294f, 0.902172f, 0.880695f, 0.071194f, 0.150403f, 0.698059f, 0.000120f, 0.821814f, 0.356240f, + 0.744620f, 0.044237f, 0.209264f, 0.070805f, 0.179824f, 0.384421f, 0.491552f, 0.916091f, 0.627174f, 0.706480f, + 0.082111f, 0.286787f, 0.991732f, 0.560422f, 0.787817f, 0.032482f, 0.084076f, 0.109233f, 0.015286f, 0.921979f, + 0.253635f, 0.996569f, 0.738130f, 0.250611f, 0.991805f, 0.868534f, 0.164998f, 0.185322f, 0.680186f, 0.078280f, + 0.584525f, 0.066603f, 0.221298f, 0.948440f, 0.498572f, 0.573713f, 0.269683f, 0.440062f, 0.133002f, 0.516616f, + 0.053956f, 0.048249f, 0.679648f, 0.054982f, 0.521284f, 0.266026f, 0.187694f, 0.573319f, 0.296463f, 0.456382f, + 0.138974f, 0.126486f, 0.106529f, 0.071560f, 0.553714f, 0.756005f, 0.792367f, 0.957845f, 0.168392f, 0.135619f, + 0.469955f, 0.861008f, 0.767069f, 0.558178f, 0.156783f, 0.391263f, 0.719346f, 0.373413f, 0.039119f, 0.583884f, + 0.720135f, 0.714771f, 0.164866f, 0.335992f, 0.409172f, 0.420481f, 0.114158f, 0.385532f, 0.506632f, 0.710561f, + 0.569448f, 0.404931f, 0.927597f, 0.598084f, 0.974791f, 0.867376f, 0.673626f, 0.899313f, 0.991240f, 0.220877f, + 0.691057f, 0.918779f, 0.017400f, 0.799489f, 0.089403f, 0.916554f, 0.612013f, 0.162069f}; + + std::vector beta_data = { + 0.039410f, 0.827821f, 0.139492f, 0.939541f, 0.090865f, 0.837978f, 0.423533f, 0.872735f, 0.768574f, 0.852882f, 0.470242f, 0.713768f, 0.318668f, 0.047173f, + 0.232400f, 0.001362f, 0.363028f, 0.493829f, 0.019407f, 0.007730f, 0.686464f, 0.100436f, 0.073846f, 0.495598f, 0.718159f, 0.977165f, 0.295397f, 0.117518f, + 0.068537f, 0.207511f, 0.100055f, 0.003384f, 0.285074f, 0.164207f, 0.018250f, 0.354632f, 0.825916f, 0.303662f, 0.710100f, 0.728735f, 0.025556f, 0.961785f, + 0.139009f, 0.717465f, 0.379443f, 0.868223f, 0.994961f, 0.193323f, 0.819456f, 0.505503f, 0.965431f, 0.658089f, 0.593238f, 0.229523f, 0.718700f, 0.288201f, + 0.845759f, 0.977264f, 0.007793f, 0.954633f, 0.358460f, 0.488316f, 0.924086f, 0.775958f, 
0.243222f, 0.096853f, 0.841226f, 0.747060f, 0.858339f, 0.384041f, + 0.492114f, 0.465019f, 0.314722f, 0.335672f, 0.718649f, 0.753071f, 0.863854f, 0.844902f, 0.753938f, 0.332778f, 0.710046f, 0.972624f, 0.916240f, 0.971488f, + 0.036208f, 0.611599f, 0.215343f, 0.246560f, 0.844061f, 0.750192f, 0.328802f, 0.519915f, 0.188330f, 0.003827f, 0.899958f, 0.709642f, 0.528818f, 0.054099f, + 0.420840f, 0.380042f, 0.171547f, 0.156188f, 0.173178f, 0.596836f, 0.124704f, 0.238549f, 0.946272f, 0.219462f, 0.763857f, 0.598040f, 0.413157f, 0.595286f, + 0.133620f, 0.484188f, 0.972134f, 0.427721f, 0.242881f, 0.927507f, 0.610774f, 0.727857f, 0.543405f, 0.011202f, 0.755700f, 0.978697f, 0.716188f, 0.808757f, + 0.851587f, 0.999201f}; + + std::vector norm_data = { + 0.406306f, -0.397960f, -0.514167f, 0.121796f, 1.632045f, 0.498094f, 2.631821f, 1.499508f, 0.095849f, + -0.040874f, -0.116213f, 0.477808f, 0.919355f, 0.811189f, 0.907785f, 1.004834f, -0.458834f, -0.472885f, + 0.237840f, 0.238391f, 1.632483f, 2.600490f, 2.037882f, 1.527244f, 0.876482f, 0.192458f, 0.258945f, + 0.030314f, 0.729815f, 1.023374f, 0.554331f, 0.851662f, 0.750835f, 0.762038f, 0.749937f, 0.729685f, + 0.782631f, 1.098150f, 1.132450f, 0.833627f, 0.590117f, -0.060817f, 0.197422f, 0.322326f, 1.476163f, + 0.078648f, 0.606424f, 1.746771f, 0.183714f, 0.518953f, -1.247748f, 1.284994f, 0.057787f, 0.044398f, + -0.002311f, -0.011233f, -0.474648f, 0.859423f, 1.839518f, -0.003167f, 0.143954f, 0.038016f, 0.087527f, + 0.150784f, 0.561618f, 0.176986f, 0.521764f, 0.258291f, 0.031635f, 0.714081f, -0.146106f, 1.278590f, + 0.426744f, 0.648229f, -0.980738f, 0.351162f, 0.118848f, -0.296623f, -0.302773f, 0.263747f, 0.054676f, + 1.040141f, 0.676835f, 0.240001f, 0.526575f, 0.429690f, -0.132461f, -0.496733f, -0.827396f, -0.494966f, + 0.596699f, 1.630098f, -0.206514f, 1.206059f, 0.617694f, 0.959952f, 0.631899f, 0.709146f, 0.659876f, + 0.691867f, 1.033381f, 1.134488f, 0.765150f, 0.781609f, -0.028056f, 1.010104f, 1.575500f, 0.678993f, + 0.117742f, 0.117495f, 0.117460f, 0.117479f, -0.928939f, 0.857719f, -0.473908f, 0.106319f, 0.254703f, + 0.187595f, -0.423069f, 0.737017f, 1.002641f, -0.713799f, -0.505049f, 0.054679f, 0.056505f, 0.068448f, + -0.037672f, 0.007170f, 0.502409f, 0.201653f, 0.447086f, 0.031594f, 0.186869f, 0.252268f, 0.281287f, + 0.058290f, -0.032152f, -0.172338f, -0.018190f, 0.042648f, -0.201724f, 0.070818f, 0.915389f, 0.435231f, + 0.683548f, 1.228964f, 1.207481f, -0.069486f, 0.900928f, 1.349056f, -0.962214f, 1.149115f, 0.126877f, + 0.173722f, 1.016921f, -0.284731f, 1.073324f, 1.663625f, 1.156551f, 0.478892f, -0.017409f, 0.077027f, + 0.019404f, -0.119480f, 0.957481f, 1.191751f, 0.709657f, 1.305503f, 0.710492f, 0.094092f, 0.713726f, + -1.632824f, 1.254686f, 1.179984f, 1.354227f, -0.186219f, -0.620889f, -0.462022f, 0.270004f, 0.339929f, + 0.882544f, 0.831658f, 0.840813f, 0.911391f, 1.003820f, 1.105588f, 0.865947f, 1.028848f, 0.385277f, + 0.249245f, 0.102975f, 0.301977f, 0.814893f, 0.829719f, 0.796979f, 0.828055f, -0.841305f, 1.360636f, + 0.520542f, -0.564568f, 1.028838f, 0.624319f, 1.122967f, 1.414307f, 1.664626f, 1.011229f, -0.562413f, + 1.432279f, 0.982238f, -0.634975f, 1.328713f, 0.605853f, 0.150513f, 0.475544f, 0.137686f, 0.199995f, + -0.461095f, 0.034839f, 1.895931f, -0.442368f, -0.012286f, 1.765260f, -0.574054f, 1.540784f, 1.094831f, + 0.660444f, 0.856002f, 0.876256f, 0.900296f, 0.743193f, 0.857834f, 0.771619f, -0.437987f, 0.795097f, + 0.983861f, -0.860229f, 0.919201f, 1.088295f, 0.978393f, 1.000022f, -0.604762f, 0.300263f, 1.250703f, + 0.093107f, 0.398245f, 
0.476736f, 0.584533f, 0.450905f, 1.126501f, 1.126446f, 0.704302f, 0.872359f, + 1.388226f, 0.453643f, -0.218810f, 2.159872f, 0.740287f, 1.137416f, -0.416660f, 0.030324f, 0.352386f, + -0.572652f, 1.397336f, -0.212928f, 0.833504f, 0.673148f, 0.564530f, 0.691624f, 0.614170f, 1.159168f, + 0.582539f, 0.714844f, 0.687727f, 0.829472f, 0.895726f, 0.749217f, 0.626510f, 0.160861f, 0.679485f, + -0.247668f, 0.563813f, 0.424527f, 0.442242f, 0.546163f, 0.408836f, 0.503895f, 0.541062f, 0.526861f, + 0.651389f, 1.131327f, 0.109609f, 0.965844f, 0.307533f, 0.397239f, 0.275143f, 0.398844f, 1.158524f, + 1.178295f, 0.107930f, 0.808378f, 0.360064f, 0.893187f, 0.353517f, 0.411826f, 0.588918f, 1.147333f, + 0.707609f, 0.859227f, 0.904664f, 0.005007f, 0.915281f, 1.148453f, 0.418446f, 0.581892f, 0.628682f, + 1.279391f, 0.420879f, 1.174909f, 0.355126f, 0.239180f, 0.495571f, 0.703488f, 0.897993f, 0.580433f, + 0.796672f, 0.937277f, 0.923647f, 1.115814f, 0.759542f, 1.057870f, 0.977992f, 1.052553f, 0.996513f, + 1.042361f, 0.935513f, 0.938658f, -0.328335f, 0.414783f, -0.370250f, -0.629015f, 1.636925f, 1.554468f, + -0.000332f, 0.794400f, -0.644444f, -0.841804f, -0.462323f, -0.489248f, 1.350502f, 1.139242f, 0.742310f, + 1.621988f, 0.891792f, 0.742398f, 0.634979f, 0.789545f, 0.600690f, 0.564714f, 0.910902f, 0.749079f, + 0.795602f, -0.081046f, 1.059454f, -0.024277f, 0.142066f, 2.137630f, 1.354346f, 0.386545f, -0.015730f, + 0.467942f, 1.166715f, 0.105109f, -0.867947f, 0.330163f, 0.402587f, -0.943201f, 1.039989f, 0.807147f, + 1.013271f, 0.658228f, 0.261774f, 1.276604f, 0.793169f, 0.981167f, 1.182381f, -0.094400f, 0.608214f, + 1.500447f, 0.375100f, -0.540889f, -0.429466f, -0.074319f, 0.493101f, 0.428099f, 0.396397f, 0.409342f, + -0.112225f, 0.338536f, -0.096419f, 1.247461f, 0.136779f, -0.296175f, 1.306138f, -0.211410f, 1.225890f, + -0.883684f, 0.732527f, 0.188935f, 0.158450f, -0.070659f, -0.068210f, 0.095841f, 1.142486f, 0.765356f, + 0.480573f, 0.758850f, -0.296101f, -0.351806f, -0.084915f, 0.595416f, 0.228868f, -0.067355f, 0.843406f, + 0.656214f, 0.873088f, 1.118756f, 1.124528f, 0.905517f, 0.397857f, 0.077982f, -0.111570f, -0.334851f, + 0.432766f, 0.446440f, 0.667385f, 0.295979f, 1.815673f, -0.258010f, 1.014872f, 0.567667f, 0.353312f, + 0.252682f, 1.221989f, 0.073956f, -0.006854f, 1.239576f, 1.165116f, 0.349117f, 0.251850f, -0.979634f, + -1.026174f, 1.184909f, 0.343477f, 0.825275f, 1.364619f, 0.027066f, -0.497336f, -0.463020f, 1.676924f, + 2.348872f, 0.382225f, 0.125961f, 0.592108f, 1.470366f, 0.758787f, -0.208515f, 1.041303f, -0.435509f, + 0.117172f, 1.494655f, 0.342757f, 1.778383f, 0.342274f, 0.097464f, 2.547432f, -0.706661f, 0.892228f, + 0.432844f, 0.978781f, 0.577661f, -0.293386f, 0.867343f, 1.042198f, 0.928943f, -1.206122f, -0.536458f, + -0.103338f, -0.556358f, 0.772336f, 0.736790f, 0.761959f, 0.781633f, 1.964310f, 0.328702f, -0.205143f, + 2.151912f, 0.807267f, 0.819557f, 0.651057f, 0.761094f, -0.553660f, 0.061518f, 1.635670f, -0.845767f, + 1.500599f, 0.591131f, 0.429972f, 0.154289f, 1.184999f, 0.943027f, 1.116617f, 1.149119f, 0.798352f, + -0.237060f, -0.176123f, 0.250859f, 0.738550f, 2.343516f, 0.595660f, 0.857584f, 0.334614f, 0.055512f, + 0.827656f, -0.346350f, 0.879107f, 0.903969f, 0.861351f, 0.894605f, 0.544361f, 0.112821f, -0.710248f, + 0.886723f, 1.241048f, -0.874084f, 1.412525f, 0.338762f, -0.116848f, 0.501252f, 0.737254f, 0.656447f, + 0.680143f, 0.883760f, 0.893155f, 1.024669f, 0.749525f, 0.825862f, 0.796258f, 0.693469f, 0.903967f, + 1.112298f, 0.917900f, 0.659168f, 0.521876f, 0.830550f, 0.020787f, 0.905854f, 
0.044571f, 0.857847f, + 0.528776f, 0.224581f, 0.636013f, -0.774066f, 0.896313f, 0.357502f, 0.101543f, 0.048746f, -0.023476f, + -0.007332f, 1.160492f, 0.173347f, 0.010474f, -0.390864f, -0.183245f, 0.374310f, -0.061789f, 0.307303f, + 0.374511f, 0.508790f, 0.504972f, 0.571301f, 0.647929f, 0.892303f, 0.727948f, 0.437075f, 0.272462f, + 0.267807f, -1.691226f, -0.311736f, 0.221596f, -0.501987f, -0.209513f, -0.249217f, 0.477392f, -0.221902f, + 0.783358f, 0.585570f, 0.293685f, 0.168966f, -0.402073f, -0.397286f, 0.793616f, 0.814484f, 1.660988f, + 1.381788f, 0.434287f, 0.951160f, 0.398667f, -0.368342f, 0.685965f, 0.628689f, 0.746822f, 0.647196f, + 0.952972f, 1.171188f, 0.756122f, 0.809376f, -0.181046f, 1.143145f, 1.075280f, -0.462215f, 0.117678f, + 0.117596f, 0.117522f, 0.117660f, 1.207595f, -0.374746f, 0.482337f, 0.453367f, -0.074850f, -0.281733f, + 0.121187f, -0.164130f, -0.407813f, 1.347597f, -0.097000f, 0.558638f, -0.030066f, 0.084762f, 0.026081f, + -0.054476f, 0.048566f, 0.563618f, 0.564591f, 0.367439f, 0.067439f, 0.110448f, 0.229187f, 0.244487f, + 0.001379f, -0.044959f, -0.092778f, -0.175144f, -0.060172f, 0.876871f, 0.715658f, -0.005267f, 0.280818f, + 1.021856f, 1.202137f, 1.277564f, -0.846823f, 1.680601f, -0.648320f, 0.465179f, 0.816884f, 1.617434f, + 0.964561f, 0.811168f, 0.685541f, 1.269441f, -0.294534f, -0.541415f, 0.148579f, 0.006120f, -0.047344f, + -0.034877f, 1.228496f, 0.766407f, 1.191577f, 0.830097f, 1.213856f, -1.697397f, -0.162200f, -0.216335f, + 0.082768f, 1.538109f, 1.455440f, 0.466843f, -0.675884f, -0.396112f, -0.230969f, 0.311936f, 0.850093f, + 0.895946f, 0.864577f, 0.906072f, 1.127087f, 0.915749f, 1.022470f, 1.086701f, 0.347097f, 0.115267f, + 0.269888f, 0.017932f, 0.837999f, 0.798699f, 0.830973f, 0.843566f, 0.524987f, -0.323668f, 0.796731f, + 0.882529f, 1.104285f, 0.707952f, 1.288781f, 1.066624f, -0.759169f, 1.253857f, -0.279808f, -0.810174f, + 0.635460f, 1.336810f, 1.461457f, -0.560630f, 0.345593f, 0.388281f, 0.011112f, 0.625432f, -0.202532f, + -0.952190f, 0.661665f, 1.290380f, -0.625566f, -0.330132f, 0.377751f, 1.393908f, 0.947332f, 0.567214f, + 0.597034f, 0.789381f, 1.108524f, 0.989273f, 0.896032f, 0.972095f, 0.451968f, -0.186156f, 0.864871f, + 1.008577f, 1.059174f, 1.005235f, 0.834800f, 0.881400f, -0.345810f, 0.538783f, -0.242229f, 0.765357f, + 0.363634f, 0.540277f, 0.489711f, 0.556296f, 0.791247f, 0.963361f, 0.900796f, 1.274361f, 1.440297f, + 0.639664f, -0.769517f, 2.005213f, -0.205800f, 0.462482f, 0.893398f, -0.179109f, -0.385072f, 0.698468f, + 0.656636f, -0.167324f, 0.646567f, 0.534505f, 1.234794f, 1.110618f, 1.271695f, 0.759512f, 0.229293f, + 0.147224f, 0.794720f, 1.099447f, 1.113528f, 1.058541f, -0.208087f, 0.316237f, -0.032344f, -0.114418f, + 0.540560f, 0.498906f, 0.465116f, 0.418016f, 0.482087f, 0.445022f, 0.453282f, 0.438177f, -0.006379f, + 0.377702f, -0.855888f, 1.042157f, 0.408202f, 0.339785f, 0.287742f, 0.420788f, 0.465379f, 1.007626f, + 1.001159f, 0.554656f, 0.459783f, 1.143811f, 0.339036f, 0.714696f, 0.691498f, 0.735108f, 1.053392f, + 0.778748f, 0.068571f, 0.274017f, 1.481772f, 1.693937f, 0.526139f, 0.909311f, 0.350476f, 0.954506f, + 0.197028f, 0.923411f, 0.045156f, 0.957155f, 0.714096f, 0.633157f, 0.789485f, 0.581167f, 0.845790f, + 0.829842f, 1.194247f, 0.971378f, 1.019175f, 0.907585f, 0.953225f, 0.951858f, 1.102269f, 1.018174f, + 0.902432f, 0.841796f, -0.858393f, -0.330711f, -0.469070f, 0.464267f, 1.114611f, -1.004036f, 1.620967f, + 0.329466f, 0.139467f, -0.470611f, 0.308757f, 1.016010f, 0.453660f, 1.595124f, 0.558440f, 1.023249f, + 0.601039f, 1.007291f, 
0.995676f, 0.637742f, 0.970108f, 0.851145f, 0.582246f, 0.840873f, 0.433405f, + -0.009376f, -0.395102f, 0.229559f, 1.179632f, 0.217997f, 0.145108f, 1.614064f, 1.010146f, 0.887566f, + -1.011727f, 0.264498f, 0.152422f, 0.570916f, 0.925334f, -0.269998f, 0.860524f, 1.051678f, 1.007595f, + 0.941741f, 0.488055f, 0.245246f, 0.227135f, 0.066780f, -0.402708f, 1.265329f, 0.257161f, -0.447346f, + 0.493756f, -0.268568f, -0.217773f, -0.301152f, 0.475332f, 0.373900f, 0.446225f, 0.471130f, 0.663021f, + 1.000752f, -0.090537f, 0.673516f, 0.781955f, 0.128213f, 1.239298f, 0.764475f, 1.281084f, 0.902059f, + 0.278935f, 0.221142f, 0.160415f, -0.106214f, 0.210654f, 0.141437f, 0.198334f, 0.149962f, 0.565323f, + 0.050416f, 0.888878f, 0.074347f, 0.079686f, -0.363394f, 0.253592f, -0.311712f, -0.291973f, 0.133119f, + 1.097622f, 0.962363f, 0.796541f, 0.851959f, 0.628367f, 0.626313f, 0.646783f, 0.138650f, 0.510147f, + 1.394106f, 0.600274f, 1.246940f, 0.872970f, 0.275462f, -0.508244f, -0.408690f, 1.314789f, 0.349021f, + 1.545499f, 0.153658f, 0.231785f, 0.389777f, 0.378070f, 0.840290f, -1.853665f, 1.786896f, 0.104429f, + 0.181189f, 0.667719f, 0.567943f, 0.718873f, 0.244843f, 1.129714f, 0.881495f, 1.460520f, 1.995885f, + -0.395025f, 0.817815f, 1.208726f, -1.411448f, 0.606279f, -0.143777f, 0.296987f, 1.422581f, 0.720905f, + 1.279913f, -0.352711f, 0.658642f, 1.613478f, 0.339589f, -0.089663f, 0.243404f, 1.226488f, 0.467706f, + 0.797042f, 0.442854f, 1.121590f, -0.153407f, 1.431477f, 0.230959f, 1.437285f, -0.046937f, -1.527740f, + -0.272532f, 0.732910f, 0.766692f, 0.749836f, 0.778544f, 1.502128f, -0.240678f, 0.820989f, 1.461264f, + 0.744201f, 0.593997f, 0.769196f, 0.670758f, -0.186752f, 1.864102f, -0.563369f, 2.274148f, 1.338321f, + 0.830787f, -0.191057f, 0.642745f, 1.092864f, 1.217034f, 1.076530f, 0.948315f}; + + std::vector swish_data = { + 0.243866f, -0.159901f, -0.192410f, 0.064602f, 1.365124f, 0.309820f, 2.455177f, 1.225849f, 0.050220f, + -0.020019f, -0.054734f, 0.294918f, 0.657257f, 0.561637f, 0.646839f, 0.735547f, -0.177689f, -0.181556f, + 0.132996f, 0.133336f, 1.365588f, 2.420778f, 1.802949f, 1.254788f, 0.618877f, 0.105460f, 0.146142f, + 0.015386f, 0.492453f, 0.752824f, 0.352078f, 0.596943f, 0.510088f, 0.519554f, 0.509331f, 0.492345f, + 0.537078f, 0.823517f, 0.856461f, 0.581139f, 0.379677f, -0.029484f, 0.108424f, 0.186914f, 1.201586f, + 0.040870f, 0.392432f, 1.487454f, 0.100271f, 0.325333f, -0.278360f, 1.006534f, 0.029728f, 0.022692f, + -0.001154f, -0.005585f, -0.182035f, 0.603779f, 1.587304f, -0.001581f, 0.077149f, 0.019369f, 0.045678f, + 0.081065f, 0.357653f, 0.096304f, 0.327438f, 0.145732f, 0.016068f, 0.479364f, -0.067726f, 1.000125f, + 0.258221f, 0.425634f, -0.267492f, 0.206097f, 0.062951f, -0.126475f, -0.128642f, 0.149164f, 0.028085f, + 0.768537f, 0.448763f, 0.134332f, 0.331049f, 0.260306f, -0.061851f, -0.187918f, -0.251691f, -0.187456f, + 0.384811f, 1.363060f, -0.092633f, 0.928184f, 0.401312f, 0.694153f, 0.412580f, 0.475279f, 0.435012f, + 0.461047f, 0.762192f, 0.858428f, 0.522193f, 0.536204f, -0.013831f, 0.740447f, 1.305406f, 0.450521f, + 0.062333f, 0.062195f, 0.062175f, 0.062186f, -0.263020f, 0.602277f, -0.181835f, 0.055983f, 0.143483f, + 0.102570f, -0.167443f, 0.498477f, 0.733510f, -0.234669f, -0.190078f, 0.028087f, 0.029050f, 0.035395f, + -0.018481f, 0.003598f, 0.313013f, 0.110958f, 0.272698f, 0.016046f, 0.102139f, 0.141960f, 0.160295f, + 0.029994f, -0.015817f, -0.078762f, -0.009012f, 0.021779f, -0.090723f, 0.036662f, 0.653681f, 0.264239f, + 0.454239f, 0.950773f, 0.929582f, -0.033536f, 0.640686f, 1.071117f, 
-0.265990f, 0.872580f, 0.067457f, + 0.094387f, 0.746799f, -0.122234f, 0.799871f, 1.398649f, 0.879794f, 0.295709f, -0.008629f, 0.039996f, + 0.009796f, -0.056175f, 0.691892f, 0.914138f, 0.475701f, 1.027117f, 0.476392f, 0.049258f, 0.479070f, + -0.266875f, 0.976283f, 0.902623f, 1.076367f, -0.084465f, -0.217050f, -0.178574f, 0.153117f, 0.198578f, + 0.624266f, 0.579420f, 0.587422f, 0.650082f, 0.734605f, 0.830635f, 0.609541f, 0.757945f, 0.229296f, + 0.140073f, 0.054136f, 0.173615f, 0.564844f, 0.577730f, 0.549380f, 0.576280f, -0.253452f, 1.082880f, + 0.326523f, -0.204651f, 0.757935f, 0.406556f, 0.847322f, 1.137731f, 1.399714f, 0.741494f, -0.204150f, + 1.156216f, 0.714629f, -0.219945f, 1.050518f, 0.391983f, 0.080910f, 0.293266f, 0.073575f, 0.109964f, + -0.178318f, 0.017723f, 1.648380f, -0.173044f, -0.006105f, 1.507298f, -0.206833f, 1.268957f, 0.820346f, + 0.435470f, 0.600764f, 0.618676f, 0.640120f, 0.503657f, 0.602378f, 0.527688f, -0.171787f, 0.547762f, + 0.716126f, -0.255739f, 0.657118f, 0.814111f, 0.711085f, 0.731079f, -0.213635f, 0.172503f, 0.972323f, + 0.048719f, 0.238256f, 0.294135f, 0.375335f, 0.275437f, 0.850725f, 0.850672f, 0.471277f, 0.615220f, + 1.111010f, 0.277405f, -0.097483f, 1.936515f, 0.501217f, 0.861257f, -0.165546f, 0.015392f, 0.206920f, + -0.206513f, 1.120330f, -0.095172f, 0.581032f, 0.445763f, 0.359888f, 0.460849f, 0.398530f, 0.882337f, + 0.373787f, 0.479997f, 0.457656f, 0.577514f, 0.636029f, 0.508724f, 0.408295f, 0.086886f, 0.450923f, + -0.108577f, 0.359337f, 0.256654f, 0.269234f, 0.345855f, 0.245632f, 0.314115f, 0.341984f, 0.331264f, + 0.428173f, 0.855378f, 0.057805f, 0.699551f, 0.177226f, 0.237559f, 0.156379f, 0.238672f, 0.881711f, + 0.900973f, 0.056874f, 0.559207f, 0.212098f, 0.633758f, 0.207681f, 0.247724f, 0.378743f, 0.870852f, + 0.474008f, 0.603606f, 0.644037f, 0.002510f, 0.653584f, 0.871938f, 0.252370f, 0.373285f, 0.410021f, + 1.000926f, 0.254082f, 0.897667f, 0.208764f, 0.133824f, 0.307957f, 0.470606f, 0.638058f, 0.372154f, + 0.549116f, 0.673480f, 0.661132f, 0.840444f, 0.517441f, 0.785239f, 0.710716f, 0.780221f, 0.727826f, + 0.770623f, 0.671878f, 0.674734f, -0.137456f, 0.249797f, -0.151240f, -0.218730f, 1.370297f, 1.283304f, + -0.000166f, 0.547163f, -0.221845f, -0.253514f, -0.178658f, -0.185949f, 1.072585f, 0.863022f, 0.502916f, + 1.354473f, 0.632512f, 0.502989f, 0.415034f, 0.542996f, 0.387934f, 0.360029f, 0.649641f, 0.508608f, + 0.548196f, -0.038882f, 0.786735f, -0.011991f, 0.076070f, 1.912126f, 1.076488f, 0.230168f, -0.007803f, + 0.287736f, 0.889679f, 0.055314f, -0.256636f, 0.192089f, 0.241274f, -0.264336f, 0.768393f, 0.558143f, + 0.743397f, 0.433681f, 0.147922f, 0.998140f, 0.546106f, 0.713642f, 0.904965f, -0.044974f, 0.393839f, + 1.226827f, 0.222318f, -0.199037f, -0.169319f, -0.035779f, 0.306135f, 0.259179f, 0.236975f, 0.245986f, + -0.052967f, 0.197649f, -0.045887f, 0.969103f, 0.073060f, -0.126316f, 1.027756f, -0.094573f, 0.947734f, + -0.258402f, 0.494719f, 0.103365f, 0.085489f, -0.034082f, -0.032942f, 0.050215f, 0.866159f, 0.522367f, + 0.296938f, 0.516856f, -0.126290f, -0.145276f, -0.040656f, 0.383809f, 0.127472f, -0.032544f, 0.589695f, + 0.432058f, 0.615866f, 0.843271f, 0.848825f, 0.644802f, 0.237987f, 0.040511f, -0.052676f, -0.139653f, + 0.262488f, 0.272236f, 0.441086f, 0.169732f, 1.561563f, -0.112454f, 0.744888f, 0.362299f, 0.207543f, + 0.142219f, 0.943881f, 0.038345f, -0.003415f, 0.961279f, 0.888123f, 0.204724f, 0.141699f, -0.267405f, + -0.270732f, 0.907438f, 0.200946f, 0.573859f, 1.086931f, 0.013716f, -0.188076f, -0.178851f, 1.412803f, + 2.144154f, 0.227198f, 
0.066942f, 0.381228f, 1.195574f, 0.516802f, -0.093427f, 0.769628f, -0.171073f, + 0.062015f, 1.220799f, 0.200465f, 1.521402f, 0.200143f, 0.051105f, 2.362491f, -0.233436f, 0.632902f, + 0.262543f, 0.711443f, 0.370009f, -0.125327f, 0.610777f, 0.770470f, 0.665922f, -0.277876f, -0.197959f, + -0.049002f, -0.202732f, 0.528298f, 0.498287f, 0.519488f, 0.536225f, 1.722697f, 0.191122f, -0.092087f, + 1.927784f, 0.558246f, 0.568889f, 0.427906f, 0.518755f, -0.202095f, 0.031705f, 1.368965f, -0.254002f, + 1.226985f, 0.380467f, 0.260506f, 0.083084f, 0.907526f, 0.678707f, 0.841215f, 0.872584f, 0.550561f, + -0.104546f, -0.080327f, 0.141080f, 0.499761f, 2.138265f, 0.384000f, 0.602158f, 0.195041f, 0.028526f, + 0.575932f, -0.143482f, 0.621209f, 0.643413f, 0.605480f, 0.635026f, 0.344486f, 0.059589f, -0.234058f, + 0.627989f, 0.962738f, -0.257335f, 1.135901f, 0.197799f, -0.055014f, 0.312156f, 0.498675f, 0.432245f, + 0.451459f, 0.625349f, 0.633730f, 0.754035f, 0.508984f, 0.574370f, 0.548760f, 0.462362f, 0.643412f, + 0.837068f, 0.655944f, 0.434440f, 0.327522f, 0.578454f, 0.010501f, 0.645105f, 0.022782f, 0.602389f, + 0.332705f, 0.124847f, 0.415858f, -0.244295f, 0.636554f, 0.210367f, 0.053347f, 0.024967f, -0.011600f, + -0.003652f, 0.883625f, 0.094167f, 0.005264f, -0.157717f, -0.083251f, 0.221779f, -0.029940f, 0.177076f, + 0.221916f, 0.317751f, 0.314914f, 0.365097f, 0.425394f, 0.632969f, 0.490896f, 0.265550f, 0.154676f, + 0.151727f, -0.263180f, -0.131768f, 0.123024f, -0.189286f, -0.093822f, -0.109161f, 0.294614f, -0.098691f, + 0.537700f, 0.376140f, 0.168252f, 0.091604f, -0.161157f, -0.159695f, 0.546489f, 0.564490f, 1.395845f, + 1.104433f, 0.263568f, 0.686118f, 0.238550f, -0.150630f, 0.456214f, 0.410026f, 0.506708f, 0.424806f, + 0.687772f, 0.894037f, 0.514549f, 0.560069f, -0.082351f, 0.866797f, 0.801729f, -0.178628f, 0.062297f, + 0.062251f, 0.062210f, 0.062287f, 0.929695f, -0.152669f, 0.298229f, 0.277207f, -0.036025f, -0.121153f, + 0.064261f, -0.075345f, -0.162895f, 1.069637f, -0.046150f, 0.355371f, -0.014807f, 0.044176f, 0.013210f, + -0.026496f, 0.024873f, 0.359187f, 0.359935f, 0.217098f, 0.034856f, 0.058271f, 0.127668f, 0.137113f, + 0.000690f, -0.021974f, -0.044238f, -0.079923f, -0.029181f, 0.619223f, 0.480672f, -0.002627f, 0.159995f, + 0.751405f, 0.924329f, 0.999099f, -0.254131f, 1.416720f, -0.222613f, 0.285733f, 0.566570f, 1.349653f, + 0.698375f, 0.561619f, 0.455867f, 0.990985f, -0.125735f, -0.199164f, 0.079798f, 0.003070f, -0.023112f, + -0.017134f, 0.950309f, 0.523259f, 0.913967f, 0.578059f, 0.935859f, -0.262766f, -0.074537f, -0.096513f, + 0.043096f, 1.266156f, 1.180120f, 0.286938f, -0.227895f, -0.159335f, -0.102207f, 0.180099f, 0.595564f, + 0.636225f, 0.608330f, 0.645301f, 0.851290f, 0.654005f, 0.751979f, 0.812592f, 0.203369f, 0.060952f, + 0.153044f, 0.009046f, 0.584960f, 0.550860f, 0.578823f, 0.589835f, 0.329856f, -0.135870f, 0.549166f, + 0.624253f, 0.829387f, 0.474291f, 1.010328f, 0.793519f, -0.242043f, 0.975459f, -0.120458f, -0.249415f, + 0.415417f, 1.058707f, 1.186345f, -0.203734f, 0.202362f, 0.231365f, 0.005587f, 0.407439f, -0.091046f, + -0.265132f, 0.436457f, 1.011931f, -0.218020f, -0.138064f, 0.224131f, 1.116821f, 0.682626f, 0.361950f, + 0.385073f, 0.542856f, 0.833448f, 0.721125f, 0.636303f, 0.705290f, 0.276201f, -0.084439f, 0.608590f, + 0.739027f, 0.786472f, 0.735919f, 0.582164f, 0.623249f, -0.143303f, 0.340258f, -0.106517f, 0.522368f, + 0.214515f, 0.341388f, 0.303639f, 0.353579f, 0.544456f, 0.697275f, 0.640568f, 0.995898f, 1.164481f, + 0.418773f, -0.243616f, 1.767281f, -0.092349f, 0.283780f, 0.633947f, 
-0.081556f, -0.155917f, 0.466470f,
+      0.432398f, -0.076679f, 0.424301f, 0.337023f, 0.956541f, 0.835456f, 0.993236f, 0.517416f, 0.127733f,
+      0.079021f, 0.547438f, 0.824758f, 0.838249f, 0.785873f, -0.093257f, 0.182914f, -0.015911f, -0.053940f,
+      0.341603f, 0.310421f, 0.285687f, 0.252067f, 0.298046f, 0.271221f, 0.277146f, 0.266335f, -0.003180f,
+      0.224097f, -0.255225f, 0.770431f, 0.245189f, 0.198482f, 0.164428f, 0.254018f, 0.285878f, 0.738142f,
+      0.732134f, 0.352326f, 0.281830f, 0.867442f, 0.197982f, 0.479874f, 0.460745f, 0.496878f, 0.781012f,
+      0.533761f, 0.035461f, 0.155663f, 1.207408f, 1.430939f, 0.330722f, 0.648210f, 0.205636f, 0.689173f,
+      0.108188f, 0.660919f, 0.023088f, 0.691594f, 0.479376f, 0.413581f, 0.542946f, 0.372724f, 0.591785f,
+      0.577837f, 0.916584f, 0.704632f, 0.748902f, 0.646659f, 0.688003f, 0.686755f, 0.827456f, 0.747968f,
+      0.642034f, 0.588283f, -0.255522f, -0.138260f, -0.180515f, 0.285072f, 0.839288f, -0.269231f, 1.353391f,
+      0.191627f, 0.074588f, -0.180937f, 0.178024f, 0.745949f, 0.277417f, 1.326083f, 0.355219f, 0.752707f,
+      0.388207f, 0.737830f, 0.727050f, 0.417238f, 0.703465f, 0.596488f, 0.373560f, 0.587475f, 0.262941f,
+      -0.004666f, -0.159025f, 0.127896f, 0.902279f, 0.120832f, 0.077809f, 1.346089f, 0.740486f, 0.628741f,
+      -0.269769f, 0.149638f, 0.082008f, 0.364801f, 0.662657f, -0.116884f, 0.604751f, 0.779395f, 0.738113f,
+      0.677536f, 0.302422f, 0.137584f, 0.126410f, 0.034505f, -0.161350f, 0.986884f, 0.145023f, -0.174461f,
+      0.306618f, -0.116360f, -0.097077f, -0.128073f, 0.293111f, 0.221499f, 0.272082f, 0.290052f, 0.437553f,
+      0.731756f, -0.043221f, 0.446063f, 0.536501f, 0.068210f, 0.961003f, 0.521620f, 1.002620f, 0.641700f,
+      0.158794f, 0.122747f, 0.086627f, -0.050289f, 0.116380f, 0.075711f, 0.108969f, 0.080593f, 0.360497f,
+      0.025843f, 0.629911f, 0.038555f, 0.041430f, -0.149042f, 0.142787f, -0.131760f, -0.124825f, 0.070983f,
+      0.823012f, 0.696361f, 0.549003f, 0.597205f, 0.409770f, 0.408138f, 0.424474f, 0.074123f, 0.318761f,
+      1.117023f, 0.387609f, 0.968585f, 0.615761f, 0.156582f, -0.190899f, -0.163160f, 1.036466f, 0.204659f,
+      1.273897f, 0.082720f, 0.129264f, 0.232396f, 0.224349f, 0.586964f, -0.251066f, 1.530559f, 0.054939f,
+      0.098779f, 0.441357f, 0.362511f, 0.483341f, 0.137334f, 0.853822f, 0.623333f, 1.185376f, 1.757106f,
+      -0.159001f, 0.567378f, 0.930808f, -0.276652f, 0.392318f, -0.066730f, 0.170383f, 1.146234f, 0.485029f,
+      1.001448f, -0.145573f, 0.434016f, 1.345469f, 0.198351f, -0.042823f, 0.136440f, 0.948325f, 0.287564f,
+      0.549434f, 0.269671f, 0.845997f, -0.070832f, 1.155390f, 0.128756f, 1.161375f, -0.022918f, -0.272434f,
+      -0.117812f, 0.495040f, 0.523501f, 0.509246f, 0.533588f, 1.228578f, -0.105927f, 0.570132f, 1.186146f,
+      0.504504f, 0.382702f, 0.525628f, 0.443822f, -0.084682f, 1.613891f, -0.204372f, 2.062000f, 1.060236f,
+      0.578661f, -0.086430f, 0.421238f, 0.818468f, 0.938992f, 0.802915f, 0.683523f};
+
+  // Test float16, without activation
+  int min_cuda_architecture = 530;
+  if (HasCudaEnvironment(min_cuda_architecture)) {
+    OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
+    test.AddAttribute("epsilon", 1e-05f);
+    test.AddAttribute<int64_t>("groups", 32);
+    test.AddAttribute<int64_t>("activation", 0);
+
+    test.AddInput<MLFloat16>("X", dims, ToFloat16(input_data));
+    test.AddInput<float>("gamma", {C}, gamma_data);
+    test.AddInput<float>("beta", {C}, beta_data);
+
+    constexpr float rel_error = 0.0f;
+    constexpr float abs_error = 0.02f;
+    test.AddOutput<MLFloat16>("Y", dims, ToFloat16(norm_data), false, rel_error, abs_error);
+
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
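+    // GroupNorm here is a CUDA-only contrib kernel, so the test is pinned to the CUDA
+    // execution provider; the SM >= 5.3 guard above gates the FP16 path, and the absolute
+    // tolerance of 0.02 leaves headroom for FP16 rounding relative to the FP32 reference.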
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+
+  // Test float16, with activation
+  if (HasCudaEnvironment(min_cuda_architecture)) {
+    OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
+    test.AddAttribute("epsilon", 1e-05f);
+    test.AddAttribute<int64_t>("groups", 32);
+    test.AddAttribute<int64_t>("activation", 1);
+
+    test.AddInput<MLFloat16>("X", dims, ToFloat16(input_data));
+    test.AddInput<float>("gamma", {C}, gamma_data);
+    test.AddInput<float>("beta", {C}, beta_data);
+
+    constexpr float rel_error = 0.0f;
+    constexpr float abs_error = 0.02f;
+    test.AddOutput<MLFloat16>("Y", dims, ToFloat16(swish_data), false, rel_error, abs_error);
+
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+}  // namespace test
+}  // namespace onnxruntime

From 4a7bf0d8ff0d0599ff6b5826a34d29fd2b248523 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 27 Jan 2023 23:48:43 +0000
Subject: [PATCH 07/27] Add SplitGelu operator

---
 .../cpu/transformers/generation_shared.h      | 29 ++++++-
 .../contrib_ops/cuda/cuda_contrib_kernels.cc  |  6 +-
 .../cuda/diffusion/group_norm_impl.cu         | 11 ++-
 .../contrib_ops/cuda/diffusion/split_gelu.cc  | 59 ++++++++++++++
 .../contrib_ops/cuda/diffusion/split_gelu.h   | 23 ++++++
 .../cuda/diffusion/split_gelu_impl.cu         | 81 +++++++++++++++++++
 .../cuda/diffusion/split_gelu_impl.h          | 20 +++++
 .../cuda/transformers/dump_cuda_tensor.cc     |  2 +-
 .../core/graph/contrib_ops/diffusion_defs.cc  | 41 +++++++++-
 onnxruntime/core/graph/contrib_ops/ms_opset.h |  2 +
 10 files changed, 263 insertions(+), 11 deletions(-)
 create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/split_gelu.cc
 create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/split_gelu.h
 create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.cu
 create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.h

diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index cf1d99688546a..37f4513e137d5 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -9,10 +9,6 @@
 #include "core/framework/allocator.h"
 #include "core/framework/ort_value.h"
 
-#ifndef NDEBUG
-//#define DEBUG_GENERATION 1  // uncomment it for debugging beam search
-#endif
-
 namespace onnxruntime {
 
 namespace concurrency {
@@ -167,6 +163,31 @@ struct IGenerationParameters {
   bool custom_sampling = false;
 };
 
+
+#ifndef NDEBUG
+//#define DEBUG_GENERATION 1  // uncomment to debug generation (e.g. beam search)
+#endif
+
+#ifdef DEBUG_GENERATION
+#define DUMP_TENSOR_LEVEL 2
+#else
+#define DUMP_TENSOR_LEVEL 1  // change to 0 to disable dumping for code outside generation.
+#endif
+
+#if DUMP_TENSOR_LEVEL > 0
+#define DUMP_TENSOR_INIT() transformers::CudaTensorConsoleDumper dumper
+#define DUMP_TENSOR(...) dumper.Print(__VA_ARGS__)
+#else
+#define DUMP_TENSOR_INIT()
+#define DUMP_TENSOR(...)
+#endif
+#if DUMP_TENSOR_LEVEL > 1
+#define DUMP_TENSOR_D(...) dumper.Print(__VA_ARGS__)
+#else
+#define DUMP_TENSOR_D(...)
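+// Usage sketch (illustrative): a kernel that includes dump_cuda_tensor.h can write
+//   DUMP_TENSOR_INIT();
+//   DUMP_TENSOR("output", output, batch_size, num_channels, height * width);
+// DUMP_TENSOR prints when DUMP_TENSOR_LEVEL > 0; DUMP_TENSOR_D only when the level exceeds 1.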
+#endif + + class IConsoleDumper { public: IConsoleDumper() : is_enabled_(true) {} diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index ad78e1bd8e960..9808c1862a7fb 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -89,6 +89,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SplitGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); @@ -196,7 +198,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -212,6 +214,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 7108d6e03d60f..01ba078b4be77 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -21,6 +21,7 @@ #include #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" namespace onnxruntime { namespace contrib { @@ -39,7 +40,7 @@ struct GroupSums { int32_t flag; // The sum. float sum; - // The sum of squares. + // The sum of squares. 
float sumSq;
 };
 
@@ -446,13 +447,17 @@ Status LaunchGroupNormKernel(
   params.invHWC = 1.F / (float)(params.hw * params.cPerGroup);
   params.groupsPerBlock = cPerBlock / params.cPerGroup;
 
+  DUMP_TENSOR_INIT();
+  DUMP_TENSOR("input", input, batch_size, num_channels, height * width);
+  DUMP_TENSOR("gamma", gamma, 1, num_channels);
+  DUMP_TENSOR("beta", beta, 1, num_channels);
   cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream);
   groupNormNHWCSum(params, stream);
+  DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2);
   CUDA_RETURN_IF_ERROR(cudaGetLastError());
-
   groupNormNHWCScale(params, stream);
   CUDA_RETURN_IF_ERROR(cudaGetLastError());
-
+  DUMP_TENSOR("output", output, batch_size, num_channels, height * width);
   return Status::OK();
 }
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.cc
new file mode 100644
index 0000000000000..d5b36b2fe990a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.cc
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/diffusion/split_gelu.h"
+#include "contrib_ops/cuda/diffusion/split_gelu_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T)                                  \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
+      SplitGelu,                                                  \
+      kMSDomain,                                                  \
+      1,                                                          \
+      T,                                                          \
+      kCudaExecutionProvider,                                     \
+      (*KernelDefBuilder::Create())                               \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      SplitGelu<T>);
+
+REGISTER_KERNEL_TYPED(MLFloat16);
+REGISTER_KERNEL_TYPED(float);
+
+using namespace ONNX_NAMESPACE;
+
+template <typename T>
+SplitGelu<T>::SplitGelu(const OpKernelInfo& op_info) : CudaKernel(op_info) {
+}
+
+template <typename T>
+Status SplitGelu<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+
+  const auto& input_dims = input->Shape().GetDims();
+  if (input_dims.size() != 3) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "input is expected to have 3 dimensions, got ", input_dims.size());
+  }
+
+  TensorShapeVector output_shape = input->Shape().AsShapeVector();
+  output_shape[2] = input_dims[2] / 2;
+  Tensor* output = context->Output(0, output_shape);
+
+  typedef typename ToCudaType<T>::MappedType CudaT;
+  const int32_t grid_size = static_cast<int32_t>(input_dims[0] * input_dims[1]);
+  const int32_t half_hidden_size = static_cast<int32_t>(input_dims[2] / 2);
+  LaunchSplitGeluKernel<CudaT>(Stream(context), grid_size, half_hidden_size,
+                               reinterpret_cast<const CudaT*>(input->Data<T>()),
+                               reinterpret_cast<CudaT*>(output->MutableData<T>()));
+
+  CUDA_RETURN_IF_ERROR(cudaPeekAtLastError());
+  return Status::OK();
+}
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
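For reference, the kernel registered above halves the last dimension: the second half of the hidden state gates the first half through Gelu. A minimal NumPy sketch of the same math (the name split_gelu_reference is invented here for illustration; this is a reference, not the CUDA implementation):

    import math

    import numpy as np


    def split_gelu_reference(x: np.ndarray) -> np.ndarray:
        # x: (N, H*W, D) -> (N, H*W, D/2), matching SplitGelu's output shape.
        left, right = np.split(x, 2, axis=-1)
        erf = np.vectorize(math.erf)
        # Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1), as in split_gelu_impl.cu below.
        return left * (right * 0.5 * (erf(right / math.sqrt(2.0)) + 1.0))


    x = np.random.randn(2, 64, 2560).astype(np.float32)
    assert split_gelu_reference(x).shape == (2, 64, 1280)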
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.h b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.h
new file mode 100644
index 0000000000000..547eb65b3ad01
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class SplitGelu final : public CudaKernel {
+ public:
+  SplitGelu(const OpKernelInfo& op_kernel_info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.cu
new file mode 100644
index 0000000000000..dad3ff4c243ea
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.cu
@@ -0,0 +1,81 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// The CUDA kernel is modified from SplitGelu plugin of TensorRT 8.5
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/diffusion/split_gelu_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T, int32_t HHS, int32_t TPB>
+__global__ void splitGeluKernel(T const* input, T* output) {
+  int32_t index_input = blockIdx.x * HHS * 2 + threadIdx.x;
+  int32_t index_output = blockIdx.x * HHS + threadIdx.x;
+
+#pragma unroll
+  for (int32_t i = 0; i < HHS / TPB; ++i) {
+    auto value_left = static_cast<float>(input[index_input]);
+    auto value_right = static_cast<float>(input[index_input + HHS]);
+
+    // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / 1.41421356237) + 1.0)
+    float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f);
+    float result = value_left * gelu_right;
+    output[index_output] = static_cast<T>(result);
+    index_input += TPB;
+    index_output += TPB;
+  }
+  return;
+}
+
+template <typename T>
+void LaunchSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, T const* input, T* output) {
+  constexpr int32_t TPB = 256;  // thread per block
+  switch (half_hidden_size) {
+    case 1280:
+      (splitGeluKernel<T, 1280, TPB>)<<<grid_size, TPB, 0, stream>>>(input, output);
+      break;
+    case 2560:
+      (splitGeluKernel<T, 2560, TPB>)<<<grid_size, TPB, 0, stream>>>(input, output);
+      break;
+    case 5120:
+      (splitGeluKernel<T, 5120, TPB>)<<<grid_size, TPB, 0, stream>>>(input, output);
+      break;
+    default:
+      ORT_NOT_IMPLEMENTED("Not implemented");
+  }
+}
+
+template __global__ void splitGeluKernel<float, 1280, 256>(float const*, float*);
+template __global__ void splitGeluKernel<float, 2560, 256>(float const*, float*);
+template __global__ void splitGeluKernel<float, 5120, 256>(float const*, float*);
+template __global__ void splitGeluKernel<half, 1280, 256>(half const*, half*);
+template __global__ void splitGeluKernel<half, 2560, 256>(half const*, half*);
+template __global__ void splitGeluKernel<half, 5120, 256>(half const*, half*);
+
+template void LaunchSplitGeluKernel<float>(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
+                                           float const* input, float* output);
+
+template void LaunchSplitGeluKernel<half>(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
+                                          half const* input, half* output);
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
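The launcher above is specialized for half_hidden_size in {1280, 2560, 5120} only. A plausible reading (an assumption; the patch itself does not say) is that these are the GEGLU half-widths of the Stable Diffusion UNet transformer blocks, whose channel counts of 320/640/1280 are expanded 4x and doubled for the gate; each value also divides evenly by the 256 threads per block:

    # Assumption: SD UNet channel counts are (320, 640, 1280) with a 4x GEGLU
    # feed-forward, so the fused MatMul emits channels * 4 * 2 values per token
    # and SplitGelu sees half of them. TPB = 256 threads stride each row.
    TPB = 256
    for channels in (320, 640, 1280):
        half_hidden_size = channels * 4  # 1280, 2560, 5120
        print(channels, half_hidden_size, half_hidden_size // TPB)  # 5, 10, 20 loop iterations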
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.h
new file mode 100644
index 0000000000000..d83e59b595b64
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.h
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/common/status.h"
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+void LaunchSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, T const* input, T* output);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
index 6c0f7f69c58a1..741f9ac259da1 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -11,7 +11,7 @@ namespace contrib {
 namespace cuda {
 namespace transformers {
 
-#ifdef DEBUG_GENERATION
+#if DUMP_TENSOR_LEVEL > 0
 template <typename T>
 class PinnedHostBuffer {
  public:
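The schema changes that follow extend GroupNorm's "T" constraint to float and register the SplitGelu schema. For the GroupNorm side, a NumPy sketch of what the contrib op computes, including the optional Swish activation selected by activation=1 (a reference in the NCHW layout the tests use at this point in the series, not the CUDA implementation):

    import numpy as np


    def group_norm_reference(x, gamma, beta, groups=32, epsilon=1e-5, swish=False):
        # x: (N, C, H, W); gamma, beta: (C,). Each (sample, group) is normalized over
        # its channels and all spatial positions, then scaled/shifted per channel.
        n, c, h, w = x.shape
        g = x.reshape(n, groups, c // groups, h * w)
        y = (g - g.mean(axis=(2, 3), keepdims=True)) / np.sqrt(g.var(axis=(2, 3), keepdims=True) + epsilon)
        y = y.reshape(n, c, h, w) * gamma[:, None, None] + beta[:, None, None]
        return y / (1.0 + np.exp(-y)) if swish else y  # Swish(y) = y * sigmoid(y)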
diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
index 60f7538ef7056..c96caf24d9df2 100644
--- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
@@ -15,6 +15,7 @@ namespace onnxruntime {
 namespace contrib {
 
 using ONNX_NAMESPACE::AttributeProto;
+using ONNX_NAMESPACE::TensorShapeProto;
 using ONNX_NAMESPACE::OpSchema;
 #ifndef NDEBUG
 using ONNX_NAMESPACE::DbgOperatorSetTracker;
 #endif
@@ -59,10 +60,46 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
             "Y",
             "The output tensor of the same shape as X",
             "T")
-        .TypeConstraint("T", {"tensor(float16)"}, "Constrain input X and output Y types to half tensors.")
-        //.TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.")
+        .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.")
         .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.")
         .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
+
+constexpr const char* SplitGelu_ver1_doc = R"DOC(
+A fusion used in diffusion models: the hidden state is sliced into two halves along the last dimension; Gelu activation
+is applied to the second half, and the two halves are then multiplied elementwise.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+    SplitGelu, 1,
+    OpSchema()
+        .SetDoc(SplitGelu_ver1_doc)
+        .Input(0,
+               "X",
+               "Input data tensor. Dimensions are (N, H*W, D), where N is the batch size, H and W are the height and width of the data, and D is the hidden dimension",
+               "T")
+        .Output(0,
+                "Y",
+                "The output tensor with dimensions (N, H*W, D/2)",
+                "T")
+        .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.")
+        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+          propagateElemTypeFromInputToOutput(ctx, 0, 0);
+          if (hasInputShape(ctx, 0)) {
+            auto& input_shape = getInputShape(ctx, 0);
+            if (input_shape.dim().size() != 3) {
+              fail_shape_inference("input shall be 3 dimensions");
+            }
+            if (input_shape.dim(2).has_dim_value()) {
+              fail_shape_inference("input dim 2 shall have dim value");
+            }
+
+            TensorShapeProto output_shape;
+            *output_shape.add_dim() = input_shape.dim(0);
+            *output_shape.add_dim() = input_shape.dim(1);
+            output_shape.add_dim()->set_dim_value(input_shape.dim(2).dim_value() / 2);
+            updateOutputShape(ctx, 0, output_shape);
+          }
+        }));
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index d82da5e65db71..4ac58823c1050 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -88,6 +88,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SplitGelu);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TransposeMatMul);
@@ -177,6 +178,7 @@ class OpSet_Microsoft_ver1 {
       fn(GetOpSchema<class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization)>());
       fn(GetOpSchema<class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization)>());
       fn(GetOpSchema<class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul)>());
+      fn(GetOpSchema<class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SplitGelu)>());
       fn(GetOpSchema<class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer)>());
       fn(GetOpSchema<class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding)>());
       fn(GetOpSchema<class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TransposeMatMul)>());

From 98b90ca58bc78d46cd0b82905435da9ab4b43bd1 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Sat, 28 Jan 2023 00:06:20 +0000
Subject: [PATCH 08/27] format

---
 .../contrib_ops/cuda/cuda_contrib_kernels.cc  |  2 +-
 .../contrib_ops/cuda/diffusion/group_norm.cc  |  8 ++---
 .../cuda/diffusion/group_norm_impl.h          | 33 +++++++++----------
 .../core/graph/contrib_ops/diffusion_defs.cc  |  3 +-
 4 files changed, 22 insertions(+), 24 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 9808c1862a7fb..4e39ad0efd44f 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -198,7 +198,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
index 594d78907a769..da1e8f70c0409 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -37,7 +37,7 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
   int64_t activation;
   ORT_ENFORCE(op_info.GetAttr("activation",
&activation).IsOK()); - ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish + ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish use_swish_activation_ = (activation == 1); } @@ -67,11 +67,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const auto& beta_dims = beta->Shape().GetDims(); if (beta_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); + "beta is expected to have 1 dimension, got ", beta_dims.size()); } if (beta_dims[0] != input_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in beta and input does not match"); + "Number of channels in beta and input does not match"); } int batch_size = static_cast(input_dims[0]); @@ -81,7 +81,7 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { if (num_channels % num_groups_ != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisiable by num_groups"); + "number of channels should be divisiable by num_groups"); } auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index 347c4624ac9f1..c7e9245050ee6 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -15,28 +15,27 @@ namespace cuda { constexpr size_t kMaxGroupNormBatchSize = 32; constexpr size_t kGroupNormNumberOfGroups = 32; -constexpr size_t GetGroupNormWorkspaceSizeInBytes() -{ - // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; +constexpr size_t GetGroupNormWorkspaceSizeInBytes() { + // Two buffers for sum and squared sum + return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; } template Status LaunchGroupNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization - ); + T* output, // normalized output tensor + const T* input, // input tensor + const float* gamma, // gamma (also known as weight or scale) + const float* beta, // beta (also known as bias) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_swish_activation // Whether there is Swish activation after group normalization +); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c96caf24d9df2..6bd1330d793ae 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -15,8 +15,8 @@ namespace onnxruntime { namespace contrib { using ONNX_NAMESPACE::AttributeProto; -using ONNX_NAMESPACE::TensorShapeProto; using ONNX_NAMESPACE::OpSchema; +using 
ONNX_NAMESPACE::TensorShapeProto;
 #ifndef NDEBUG
 using ONNX_NAMESPACE::DbgOperatorSetTracker;
 #endif
@@ -64,7 +64,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
         .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.")
         .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
 
-
 constexpr const char* SplitGelu_ver1_doc = R"DOC(
 A fusion used in diffusion models: the hidden state is sliced into two halves along the last dimension; Gelu activation
 is applied to the second half, and the two halves are then multiplied elementwise.

From ea69aec90f9a6f309cb03a54be39ca210861bf8e Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Sat, 28 Jan 2023 00:07:29 +0000
Subject: [PATCH 09/27] format

---
 onnxruntime/python/tools/transformers/fusion_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py
index b5e390d835b18..d84c433687b5e 100644
--- a/onnxruntime/python/tools/transformers/fusion_utils.py
+++ b/onnxruntime/python/tools/transformers/fusion_utils.py
@@ -28,7 +28,7 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]:
         logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}")
         return False, input_name
 
-    def cast_input(self, input_name: str, target_type = "int32"):
+    def cast_input(self, input_name: str, target_type="int32"):
         cast_output = input_name + "_" + target_type
 
         # Avoid consequent Cast nodes.
@@ -56,7 +56,6 @@ def cast_input(self, input_name: str, target_type = "int32"):
 
         return cast_output, cast_node
 
-
     def cast_input_to_int32(self, input_name: str):
         return self.cast_input(input_name, "int32")

From 9eacd8430723f725c02b66318bb2abf8492f2a4e Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Sun, 29 Jan 2023 10:43:07 +0000
Subject: [PATCH 10/27] misc

---
 .../cpu/transformers/generation_shared.h      |  2 +-
 .../core/graph/contrib_ops/diffusion_defs.cc  | 10 ++++----
 .../tools/transformers/fusion_group_norm.py   | 24 +++++++++++++++++--
 .../tools/transformers/fusion_options.py      |  2 +-
 4 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 37f4513e137d5..7b641cbef046a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -171,7 +171,7 @@ struct IGenerationParameters {
 #ifdef DEBUG_GENERATION
 #define DUMP_TENSOR_LEVEL 2
 #else
-#define DUMP_TENSOR_LEVEL 1  // change to 0 to disable dumping for code outside generation.
+#define DUMP_TENSOR_LEVEL 0  // change to 1 to enable dumping for code outside generation.
#endif #if DUMP_TENSOR_LEVEL > 0 diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index 6bd1330d793ae..ce2ba8ab42f95 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -89,14 +89,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( if (input_shape.dim().size() != 3) { fail_shape_inference("input shall be 3 dimensions"); } - if (input_shape.dim(2).has_dim_value()) { - fail_shape_inference("input dim 2 shall have dim value"); - } TensorShapeProto output_shape; *output_shape.add_dim() = input_shape.dim(0); *output_shape.add_dim() = input_shape.dim(1); - output_shape.add_dim()->set_dim_value(input_shape.dim(2).dim_value() / 2); + if (input_shape.dim(2).has_dim_value()) { + output_shape.add_dim()->set_dim_value(input_shape.dim(2).dim_value() / 2); + } else { + output_shape.add_dim(); + } + updateOutputShape(ctx, 0, output_shape); } })); diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index 6f0874bc1c450..d676a53492af2 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -154,10 +154,17 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.FLOAT))]) self.model.add_node(cast_node) + # NCHW to NHWC + transpose_input = helper.make_node( + "Transpose", [input], [input + "_NHWC"], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"), + perm=[0, 2, 3, 1] + ) + new_node = helper.make_node( "GroupNorm", - inputs=[input, group_norm_name + "_gamma", group_norm_name + "_beta"], - outputs=[output], + inputs=[input + "_NHWC", group_norm_name + "_gamma", group_norm_name + "_beta"], + outputs=[output + "_NHWC"], name=group_norm_name, ) @@ -165,6 +172,19 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): new_node.attribute.extend([helper.make_attribute("groups", 32)]) new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)]) new_node.domain = "com.microsoft" + + # NHWC to NCHW + transpose_output = helper.make_node( + "Transpose", [output + "_NHWC"], [output], + name = self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"), + perm=[0, 3, 1, 2] + ) + self.nodes_to_add.append(new_node) + self.nodes_to_add.append(transpose_input) + self.nodes_to_add.append(transpose_output) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name + self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name return True diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 17306762c118e..18765c6021f68 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -36,7 +36,7 @@ def __init__(self, model_type): self.enable_shape_inference = True self.enable_gemm_fast_gelu = False - self.enable_group_norm = model_type == "unet" + self.enable_group_norm = False self.enable_splitgelu = model_type == "unet" self.attention_mask_format = AttentionMaskFormat.AttentionMask From c566679adad186661319101eb0cf3e2df5559070 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 
29 Jan 2023 11:41:04 +0000 Subject: [PATCH 11/27] update group norm test data to NHWC --- .../contrib_ops/cuda/diffusion/group_norm.cc | 11 +- .../tools/transformers/fusion_options.py | 2 +- .../test/contrib_ops/group_norm_op_test.cc | 697 +++++++++--------- 3 files changed, 357 insertions(+), 353 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index da1e8f70c0409..088b0c9be5a05 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -59,7 +59,7 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[1]) { + if (gamma_dims[0] != input_dims[3]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in gamma and input does not match"); } @@ -69,15 +69,16 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[1]) { + if (beta_dims[0] != input_dims[3]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in beta and input does not match"); } + // Input and output format is NHWC int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[1]); - int height = static_cast(input_dims[2]); - int width = static_cast(input_dims[3]); + int num_channels = static_cast(input_dims[3]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); if (num_channels % num_groups_ != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 18765c6021f68..17306762c118e 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -36,7 +36,7 @@ def __init__(self, model_type): self.enable_shape_inference = True self.enable_gemm_fast_gelu = False - self.enable_group_norm = False + self.enable_group_norm = model_type == "unet" self.enable_splitgelu = model_type == "unet" self.attention_mask_format = AttentionMaskFormat.AttentionMask diff --git a/onnxruntime/test/contrib_ops/group_norm_op_test.cc b/onnxruntime/test/contrib_ops/group_norm_op_test.cc index c3b43f708ccfa..4af51e24159ef 100644 --- a/onnxruntime/test/contrib_ops/group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_norm_op_test.cc @@ -21,111 +21,111 @@ TEST(GroupNormTest, GroupNorm_128) { constexpr int64_t H = 2; constexpr int64_t W = 2; - std::vector dims{B, C, H, W}; + std::vector dims{B, H, W, C}; std::vector input_data = { - 0.696469f, 0.286139f, 0.226851f, 0.551315f, 0.719469f, 0.423106f, 0.980764f, 0.684830f, 0.480932f, 0.392118f, - 0.343178f, 0.729050f, 0.438572f, 0.059678f, 0.398044f, 0.737995f, 0.182492f, 0.175452f, 0.531551f, 0.531828f, - 0.634401f, 0.849432f, 0.724455f, 0.611024f, 0.722443f, 0.322959f, 0.361789f, 0.228263f, 0.293714f, 0.630976f, - 0.092105f, 0.433701f, 0.430863f, 0.493685f, 0.425830f, 0.312261f, 0.426351f, 0.893389f, 0.944160f, 0.501837f, - 0.623953f, 0.115618f, 0.317285f, 0.414826f, 0.866309f, 0.250455f, 0.483034f, 0.985560f, 0.519485f, 0.612895f, - 0.120629f, 0.826341f, 0.603060f, 0.545068f, 0.342764f, 0.304121f, 0.417022f, 
0.681301f, 0.875457f, 0.510422f, - 0.669314f, 0.585937f, 0.624904f, 0.674689f, 0.842342f, 0.083195f, 0.763683f, 0.243666f, 0.194223f, 0.572457f, - 0.095713f, 0.885327f, 0.627249f, 0.723416f, 0.016129f, 0.594432f, 0.556785f, 0.158960f, 0.153071f, 0.695530f, - 0.318766f, 0.691970f, 0.554383f, 0.388951f, 0.925132f, 0.841670f, 0.357398f, 0.043591f, 0.304768f, 0.398186f, - 0.704959f, 0.995358f, 0.355915f, 0.762548f, 0.593177f, 0.691702f, 0.151127f, 0.398876f, 0.240856f, 0.343456f, - 0.513128f, 0.666625f, 0.105908f, 0.130895f, 0.321981f, 0.661564f, 0.846506f, 0.553257f, 0.854452f, 0.384838f, - 0.316788f, 0.354265f, 0.171082f, 0.829113f, 0.338671f, 0.552370f, 0.578551f, 0.521533f, 0.002688f, 0.988345f, - 0.905342f, 0.207636f, 0.292489f, 0.520010f, 0.901911f, 0.983631f, 0.257542f, 0.564359f, 0.806969f, 0.394370f, - 0.731073f, 0.161069f, 0.600699f, 0.865864f, 0.983522f, 0.079366f, 0.428347f, 0.204543f, 0.450636f, 0.547764f, - 0.093327f, 0.296861f, 0.927584f, 0.569004f, 0.457412f, 0.753526f, 0.741862f, 0.048579f, 0.708697f, 0.839243f, - 0.165938f, 0.780998f, 0.286537f, 0.306470f, 0.665261f, 0.111392f, 0.664872f, 0.887857f, 0.696311f, 0.440328f, - 0.438214f, 0.765096f, 0.565642f, 0.084904f, 0.582671f, 0.814844f, 0.337066f, 0.927577f, 0.750717f, 0.574064f, - 0.751644f, 0.079149f, 0.859389f, 0.821504f, 0.909872f, 0.128631f, 0.081780f, 0.138416f, 0.399379f, 0.424307f, - 0.562218f, 0.122244f, 0.201400f, 0.811644f, 0.467988f, 0.807938f, 0.007426f, 0.551593f, 0.931932f, 0.582175f, - 0.206096f, 0.717758f, 0.378986f, 0.668384f, 0.029320f, 0.635900f, 0.032198f, 0.744781f, 0.472913f, 0.121754f, - 0.542636f, 0.066774f, 0.653365f, 0.996086f, 0.769397f, 0.573774f, 0.102635f, 0.699834f, 0.661168f, 0.049097f, - 0.792299f, 0.518717f, 0.425868f, 0.788187f, 0.411569f, 0.481026f, 0.181629f, 0.321319f, 0.845533f, 0.186904f, - 0.417291f, 0.989035f, 0.236600f, 0.916832f, 0.918397f, 0.091296f, 0.463653f, 0.502216f, 0.313669f, 0.047340f, - 0.241686f, 0.095530f, 0.238250f, 0.807791f, 0.894978f, 0.043223f, 0.301947f, 0.980582f, 0.539505f, 0.626309f, - 0.005545f, 0.484909f, 0.988329f, 0.375186f, 0.097038f, 0.461909f, 0.963004f, 0.341831f, 0.798923f, 0.798846f, - 0.208248f, 0.443368f, 0.715601f, 0.410520f, 0.191007f, 0.967494f, 0.650750f, 0.865460f, 0.025242f, 0.266906f, - 0.502071f, 0.067449f, 0.993033f, 0.236462f, 0.374292f, 0.214012f, 0.105446f, 0.232480f, 0.300610f, 0.634442f, - 0.281235f, 0.362277f, 0.005943f, 0.365719f, 0.533886f, 0.162016f, 0.597433f, 0.293152f, 0.632050f, 0.026197f, - 0.887593f, 0.016119f, 0.126958f, 0.777162f, 0.045895f, 0.710999f, 0.971046f, 0.871683f, 0.710162f, 0.958510f, - 0.429813f, 0.872879f, 0.355958f, 0.929764f, 0.148778f, 0.940029f, 0.832716f, 0.846055f, 0.123923f, 0.596487f, - 0.016392f, 0.721184f, 0.007738f, 0.084822f, 0.225498f, 0.875125f, 0.363576f, 0.539960f, 0.568103f, 0.225463f, - 0.572147f, 0.660952f, 0.298245f, 0.418627f, 0.453089f, 0.932351f, 0.587494f, 0.948252f, 0.556035f, 0.500561f, - 0.003532f, 0.480889f, 0.927455f, 0.198366f, 0.052091f, 0.406779f, 0.372396f, 0.857153f, 0.026611f, 0.920149f, - 0.680903f, 0.904226f, 0.607529f, 0.811953f, 0.335544f, 0.349566f, 0.389874f, 0.754797f, 0.369291f, 0.242220f, - 0.937668f, 0.908011f, 0.348797f, 0.634638f, 0.273842f, 0.206115f, 0.336340f, 0.327100f, 0.882276f, 0.822304f, - 0.709623f, 0.959345f, 0.422543f, 0.245033f, 0.117398f, 0.301053f, 0.145264f, 0.092186f, 0.602932f, 0.364187f, - 0.564570f, 0.191336f, 0.676906f, 0.215505f, 0.278024f, 0.741760f, 0.559738f, 0.334836f, 0.542989f, 0.693985f, - 0.912132f, 0.580713f, 0.232686f, 0.746698f, 
0.777769f, 0.200401f, 0.820574f, 0.464935f, 0.779767f, 0.237478f, - 0.332580f, 0.953697f, 0.657815f, 0.772878f, 0.688374f, 0.204304f, 0.470689f, 0.808964f, 0.675035f, 0.006028f, - 0.087408f, 0.346795f, 0.944366f, 0.491190f, 0.270176f, 0.360424f, 0.210653f, 0.421200f, 0.218035f, 0.845753f, - 0.456271f, 0.279802f, 0.932892f, 0.314351f, 0.909715f, 0.043418f, 0.707115f, 0.483889f, 0.444221f, 0.036323f, - 0.040683f, 0.332754f, 0.947120f, 0.617660f, 0.368875f, 0.611977f, 0.206132f, 0.165066f, 0.361817f, 0.863353f, - 0.509402f, 0.296902f, 0.950252f, 0.815966f, 0.322974f, 0.972098f, 0.987351f, 0.408660f, 0.655923f, 0.405653f, - 0.257348f, 0.082653f, 0.263610f, 0.271480f, 0.398639f, 0.184886f, 0.953818f, 0.102880f, 0.625209f, 0.441697f, - 0.423518f, 0.371992f, 0.868315f, 0.280477f, 0.020576f, 0.918097f, 0.864480f, 0.276902f, 0.523488f, 0.109088f, - 0.093427f, 0.837466f, 0.410266f, 0.661717f, 0.943201f, 0.245131f, 0.013160f, 0.024148f, 0.709386f, 0.924552f, - 0.467330f, 0.375109f, 0.542860f, 0.858917f, 0.652154f, 0.232980f, 0.774580f, 0.134613f, 0.165560f, 0.612682f, - 0.238783f, 0.704779f, 0.349519f, 0.277424f, 0.998918f, 0.040616f, 0.645823f, 0.038700f, 0.760210f, 0.230090f, - 0.089832f, 0.648450f, 0.732601f, 0.678095f, 0.051901f, 0.294307f, 0.451088f, 0.287103f, 0.810513f, 0.131115f, - 0.612179f, 0.988215f, 0.902557f, 0.222157f, 0.000082f, 0.980597f, 0.882713f, 0.919472f, 0.415504f, 0.744615f, - 0.212831f, 0.392304f, 0.851548f, 0.127612f, 0.893865f, 0.496508f, 0.426096f, 0.305646f, 0.916849f, 0.517623f, - 0.804026f, 0.857652f, 0.922382f, 0.303381f, 0.339811f, 0.595074f, 0.441324f, 0.932843f, 0.397564f, 0.477778f, - 0.617186f, 0.404739f, 0.992478f, 0.098851f, 0.220603f, 0.322655f, 0.147723f, 0.284219f, 0.779245f, 0.522892f, - 0.033954f, 0.982623f, 0.616006f, 0.058939f, 0.661169f, 0.378369f, 0.135673f, 0.563665f, 0.727080f, 0.671127f, - 0.247513f, 0.524866f, 0.537663f, 0.716803f, 0.359867f, 0.797733f, 0.627922f, 0.038332f, 0.546479f, 0.861912f, - 0.567574f, 0.175828f, 0.510376f, 0.756946f, 0.110105f, 0.817099f, 0.167482f, 0.534076f, 0.385743f, 0.248624f, - 0.647433f, 0.037392f, 0.760046f, 0.526941f, 0.875771f, 0.520718f, 0.035033f, 0.143601f, 0.795605f, 0.491976f, - 0.441879f, 0.318435f, 0.284549f, 0.965886f, 0.432969f, 0.884003f, 0.648163f, 0.858428f, 0.852450f, 0.956312f, - 0.697942f, 0.805397f, 0.733128f, 0.605227f, 0.717354f, 0.715750f, 0.040908f, 0.516111f, 0.792651f, 0.242962f, - 0.465148f, 0.434986f, 0.402787f, 0.121840f, 0.525712f, 0.446248f, 0.663393f, 0.549413f, 0.027543f, 0.031918f, - 0.701360f, 0.707581f, 0.959939f, 0.876705f, 0.468060f, 0.625907f, 0.457182f, 0.222946f, 0.376677f, 0.103884f, - 0.666527f, 0.192030f, 0.475468f, 0.967437f, 0.031669f, 0.151730f, 0.298579f, 0.941807f, 0.908842f, 0.162001f, - 0.981118f, 0.750748f, 0.539977f, 0.931703f, 0.880607f, 0.391316f, 0.656343f, 0.647385f, 0.326968f, 0.179390f, - 0.466810f, 0.263281f, 0.355065f, 0.954144f, 0.461138f, 0.684891f, 0.336230f, 0.995861f, 0.658768f, 0.196009f, - 0.098184f, 0.943181f, 0.944778f, 0.621328f, 0.016991f, 0.225535f, 0.801277f, 0.875460f, 0.453990f, 0.365521f, - 0.274225f, 0.116971f, 0.115745f, 0.952603f, 0.808626f, 0.164779f, 0.207050f, 0.655552f, 0.764664f, 0.810315f, - 0.163338f, 0.984128f, 0.227802f, 0.589415f, 0.587616f, 0.967362f, 0.657667f, 0.584904f, 0.518773f, 0.764658f, - 0.106055f, 0.002092f, 0.952489f, 0.498658f, 0.328335f, 0.368053f, 0.803843f, 0.382370f, 0.770169f, 0.440462f, - 0.844077f, 0.076204f, 0.481128f, 0.466850f, 0.264328f, 0.943615f, 0.905028f, 0.443596f, 0.097160f, 0.206783f, - 0.271492f, 
0.484220f, 0.338377f, 0.774136f, 0.476027f, 0.870371f, 0.995782f, 0.219836f, 0.611671f, 0.847502f, - 0.945237f, 0.290086f, 0.727043f, 0.015016f, 0.879142f, 0.063939f, 0.733395f, 0.994610f, 0.501190f, 0.209334f, - 0.594644f, 0.624150f, 0.668073f, 0.172612f, 0.898713f, 0.620991f, 0.043569f, 0.684041f, 0.196084f, 0.027341f, - 0.550953f, 0.813314f, 0.859941f, 0.103521f, 0.663043f, 0.710075f, 0.294517f, 0.971364f, 0.278687f, 0.069982f, - 0.519280f, 0.694315f, 0.244660f, 0.338582f, 0.563628f, 0.886678f, 0.747326f, 0.209592f, 0.251777f, 0.523881f, - 0.768959f, 0.618762f, 0.501324f, 0.597125f, 0.756060f, 0.537080f, 0.897753f, 0.947067f, 0.915355f, 0.754518f, - 0.246321f, 0.385271f, 0.280000f, 0.657660f, 0.324222f, 0.754392f, 0.113509f, 0.775365f, 0.585902f, 0.835389f, - 0.430876f, 0.624964f, 0.554412f, 0.975671f, 0.755474f, 0.544813f, 0.174032f, 0.904114f, 0.205838f, 0.650043f, - 0.936472f, 0.223580f, 0.225924f, 0.851819f, 0.827655f, 0.351703f, 0.265096f, 0.127388f, 0.987936f, 0.835343f, - 0.899392f, 0.513679f, 0.114385f, 0.052580f, 0.330582f, 0.920330f, 0.947582f, 0.841164f, 0.158679f, 0.419923f, - 0.246243f, 0.205350f, 0.684826f, 0.486112f, 0.324910f, 0.100214f, 0.544763f, 0.347025f, 0.391096f, 0.310509f, - 0.387195f, 0.555860f, 0.014144f, 0.847647f, 0.921920f, 0.550530f, 0.268021f, 0.990239f, 0.383194f, 0.693655f, - 0.689953f, 0.434309f, 0.199158f, 0.966579f, 0.063691f, 0.485149f, 0.220731f, 0.293974f, 0.828527f, 0.367266f, - 0.083348f, 0.196309f, 0.860373f, 0.977029f, 0.267982f, 0.675409f, 0.081199f, 0.723466f, 0.416437f, 0.918160f, - 0.311536f, 0.941467f, 0.503247f, 0.348893f, 0.647020f, 0.249746f, 0.229764f, 0.196346f, 0.959900f, 0.492914f, - 0.751615f, 0.473992f, 0.587540f, 0.584139f, 0.979886f, 0.668433f, 0.239769f, 0.015198f, 0.218682f, 0.455520f, - 0.393420f, 0.812326f, 0.785557f, 0.089096f, 0.952011f, 0.527457f, 0.596404f, 0.405057f, 0.649501f, 0.871326f, - 0.673936f, 0.970099f, 0.701122f, 0.821721f, 0.045040f, 0.672699f, 0.654753f, 0.101746f, 0.842387f, 0.614172f, - 0.098328f, 0.594467f, 0.478416f, 0.233294f, 0.019756f, 0.365567f, 0.619851f, 0.329279f, 0.307255f, 0.751121f, - 0.758625f, 0.718766f, 0.101182f, 0.516166f, 0.557799f, 0.744805f, 0.903178f, 0.369039f, 0.428663f, 0.732767f, - 0.662636f, 0.557870f, 0.350140f, 0.195352f, 0.183807f, 0.081583f, 0.081201f, 0.845798f, 0.383673f, 0.060740f, - 0.896426f, 0.223270f, 0.268124f, 0.194498f, 0.967501f, 0.112540f, 0.722163f, 0.932089f, 0.668001f, 0.858727f, - 0.242447f, 0.673928f, 0.700871f, 0.458333f, 0.870546f, 0.694386f, 0.894878f, 0.753204f, 0.520290f, 0.498688f, - 0.453728f, 0.021647f, 0.535141f, 0.422973f, 0.157534f, 0.119070f, 0.449352f, 0.039913f, 0.986580f, 0.378121f, - 0.382109f, 0.051126f, 0.426672f, 0.015745f, 0.030094f, 0.339099f, 0.820969f, 0.458821f, 0.014841f, 0.163220f, - 0.739923f, 0.738294f, 0.754523f, 0.351669f, 0.352277f, 0.802076f, 0.398138f, 0.727191f, 0.581123f, 0.364342f, - 0.080007f, 0.116125f, 0.889559f, 0.452341f, 0.994005f, 0.363897f, 0.249954f, 0.350539f, 0.343086f, 0.637357f, - 0.012738f, 0.763269f, 0.416415f, 0.432239f, 0.481115f, 0.449212f, 0.497471f, 0.345904f, 0.453346f, 0.404651f, - 0.518243f, 0.623269f, 0.241041f, 0.508437f, 0.594622f, 0.016948f, 0.520494f, 0.239293f, 0.404539f, 0.826530f, - 0.326236f, 0.483217f, 0.024741f, 0.308751f, 0.639721f, 0.315162f, 0.205798f, 0.290656f, 0.954378f, 0.086802f, - 0.463358f, 0.058387f, 0.538658f, 0.146036f, 0.634085f, 0.264397f, 0.690915f, 0.347146f, 0.004168f, 0.294895f, - 0.081894f, 0.495040f, 0.288890f, 0.639992f, 0.499936f, 0.036045f, 0.318634f, 0.489059f, 
0.572204f, 0.104871f, - 0.649971f, 0.343698f, 0.182921f, 0.805327f, 0.068623f, 0.929770f, 0.706266f, 0.475591f, 0.011161f, 0.390125f, - 0.645798f, 0.858913f, 0.617764f, 0.397707f}; + 0.696469f, 0.719469f, 0.480932f, 0.438572f, 0.182492f, 0.634401f, 0.722443f, 0.293714f, 0.430863f, 0.426351f, + 0.623953f, 0.866309f, 0.519485f, 0.603060f, 0.417022f, 0.669314f, 0.842342f, 0.194223f, 0.627249f, 0.556785f, + 0.318766f, 0.925132f, 0.304768f, 0.355915f, 0.151127f, 0.513128f, 0.321981f, 0.854452f, 0.171082f, 0.578551f, + 0.905342f, 0.901911f, 0.806969f, 0.600699f, 0.428347f, 0.093327f, 0.457412f, 0.708697f, 0.286537f, 0.664872f, + 0.438214f, 0.582671f, 0.750717f, 0.859389f, 0.081780f, 0.562218f, 0.467988f, 0.931932f, 0.378986f, 0.032198f, + 0.542636f, 0.769397f, 0.661168f, 0.425868f, 0.181629f, 0.417291f, 0.918397f, 0.313669f, 0.238250f, 0.301947f, + 0.005545f, 0.097038f, 0.798923f, 0.715601f, 0.650750f, 0.502071f, 0.374292f, 0.300610f, 0.005943f, 0.597433f, + 0.887593f, 0.045895f, 0.710162f, 0.355958f, 0.832716f, 0.016392f, 0.225498f, 0.568103f, 0.298245f, 0.587494f, + 0.003532f, 0.052091f, 0.026611f, 0.607529f, 0.389874f, 0.937668f, 0.273842f, 0.882276f, 0.422543f, 0.145264f, + 0.564570f, 0.278024f, 0.542989f, 0.232686f, 0.820574f, 0.332580f, 0.688374f, 0.675035f, 0.944366f, 0.210653f, + 0.456271f, 0.909715f, 0.444221f, 0.947120f, 0.206132f, 0.509402f, 0.322974f, 0.655923f, 0.263610f, 0.953818f, + 0.423518f, 0.020576f, 0.523488f, 0.410266f, 0.013160f, 0.467330f, 0.652154f, 0.165560f, 0.349519f, 0.645823f, + 0.089832f, 0.051901f, 0.810513f, 0.902557f, 0.882713f, 0.212831f, 0.893865f, 0.916849f, 0.286139f, 0.423106f, + 0.392118f, 0.059678f, 0.175452f, 0.849432f, 0.322959f, 0.630976f, 0.493685f, 0.893389f, 0.115618f, 0.250455f, + 0.612895f, 0.545068f, 0.681301f, 0.585937f, 0.083195f, 0.572457f, 0.723416f, 0.158960f, 0.691970f, 0.841670f, + 0.398186f, 0.762548f, 0.398876f, 0.666625f, 0.661564f, 0.384838f, 0.829113f, 0.521533f, 0.207636f, 0.983631f, + 0.394370f, 0.865864f, 0.204543f, 0.296861f, 0.753526f, 0.839243f, 0.306470f, 0.887857f, 0.765096f, 0.814844f, + 0.574064f, 0.821504f, 0.138416f, 0.122244f, 0.807938f, 0.582175f, 0.668384f, 0.744781f, 0.066774f, 0.573774f, + 0.049097f, 0.788187f, 0.321319f, 0.989035f, 0.091296f, 0.047340f, 0.807791f, 0.980582f, 0.484909f, 0.461909f, + 0.798846f, 0.410520f, 0.865460f, 0.067449f, 0.214012f, 0.634442f, 0.365719f, 0.293152f, 0.016119f, 0.710999f, + 0.958510f, 0.929764f, 0.846055f, 0.721184f, 0.875125f, 0.225463f, 0.418627f, 0.948252f, 0.480889f, 0.406779f, + 0.920149f, 0.811953f, 0.754797f, 0.908011f, 0.206115f, 0.822304f, 0.245033f, 0.092186f, 0.191336f, 0.741760f, + 0.693985f, 0.746698f, 0.464935f, 0.953697f, 0.204304f, 0.006028f, 0.491190f, 0.421200f, 0.279802f, 0.043418f, + 0.036323f, 0.617660f, 0.165066f, 0.296902f, 0.972098f, 0.405653f, 0.271480f, 0.102880f, 0.371992f, 0.918097f, + 0.109088f, 0.661717f, 0.024148f, 0.375109f, 0.232980f, 0.612682f, 0.277424f, 0.038700f, 0.648450f, 0.294307f, + 0.131115f, 0.222157f, 0.919472f, 0.392304f, 0.496508f, 0.517623f, 0.226851f, 0.980764f, 0.343178f, 0.398044f, + 0.531551f, 0.724455f, 0.361789f, 0.092105f, 0.425830f, 0.944160f, 0.317285f, 0.483034f, 0.120629f, 0.342764f, + 0.875457f, 0.624904f, 0.763683f, 0.095713f, 0.016129f, 0.153071f, 0.554383f, 0.357398f, 0.704959f, 0.593177f, + 0.240856f, 0.105908f, 0.846506f, 0.316788f, 0.338671f, 0.002688f, 0.292489f, 0.257542f, 0.731073f, 0.983522f, + 0.450636f, 0.927584f, 0.741862f, 0.165938f, 0.665261f, 0.696311f, 0.565642f, 0.337066f, 0.751644f, 0.909872f, + 0.399379f, 
0.201400f, 0.007426f, 0.206096f, 0.029320f, 0.472913f, 0.653365f, 0.102635f, 0.792299f, 0.411569f, + 0.845533f, 0.236600f, 0.463653f, 0.241686f, 0.894978f, 0.539505f, 0.988329f, 0.963004f, 0.208248f, 0.191007f, + 0.025242f, 0.993033f, 0.105446f, 0.281235f, 0.533886f, 0.632050f, 0.126958f, 0.971046f, 0.429813f, 0.148778f, + 0.123923f, 0.007738f, 0.363576f, 0.572147f, 0.453089f, 0.556035f, 0.927455f, 0.372396f, 0.680903f, 0.335544f, + 0.369291f, 0.348797f, 0.336340f, 0.709623f, 0.117398f, 0.602932f, 0.676906f, 0.559738f, 0.912132f, 0.777769f, + 0.779767f, 0.657815f, 0.470689f, 0.087408f, 0.270176f, 0.218035f, 0.932892f, 0.707115f, 0.040683f, 0.368875f, + 0.361817f, 0.950252f, 0.987351f, 0.257348f, 0.398639f, 0.625209f, 0.868315f, 0.864480f, 0.093427f, 0.943201f, + 0.709386f, 0.542860f, 0.774580f, 0.238783f, 0.998918f, 0.760210f, 0.732601f, 0.451088f, 0.612179f, 0.000082f, + 0.415504f, 0.851548f, 0.426096f, 0.804026f, 0.551315f, 0.684830f, 0.729050f, 0.737995f, 0.531828f, 0.611024f, + 0.228263f, 0.433701f, 0.312261f, 0.501837f, 0.414826f, 0.985560f, 0.826341f, 0.304121f, 0.510422f, 0.674689f, + 0.243666f, 0.885327f, 0.594432f, 0.695530f, 0.388951f, 0.043591f, 0.995358f, 0.691702f, 0.343456f, 0.130895f, + 0.553257f, 0.354265f, 0.552370f, 0.988345f, 0.520010f, 0.564359f, 0.161069f, 0.079366f, 0.547764f, 0.569004f, + 0.048579f, 0.780998f, 0.111392f, 0.440328f, 0.084904f, 0.927577f, 0.079149f, 0.128631f, 0.424307f, 0.811644f, + 0.551593f, 0.717758f, 0.635900f, 0.121754f, 0.996086f, 0.699834f, 0.518717f, 0.481026f, 0.186904f, 0.916832f, + 0.502216f, 0.095530f, 0.043223f, 0.626309f, 0.375186f, 0.341831f, 0.443368f, 0.967494f, 0.266906f, 0.236462f, + 0.232480f, 0.362277f, 0.162016f, 0.026197f, 0.777162f, 0.871683f, 0.872879f, 0.940029f, 0.596487f, 0.084822f, + 0.539960f, 0.660952f, 0.932351f, 0.500561f, 0.198366f, 0.857153f, 0.904226f, 0.349566f, 0.242220f, 0.634638f, + 0.327100f, 0.959345f, 0.301053f, 0.364187f, 0.215505f, 0.334836f, 0.580713f, 0.200401f, 0.237478f, 0.772878f, + 0.808964f, 0.346795f, 0.360424f, 0.845753f, 0.314351f, 0.483889f, 0.332754f, 0.611977f, 0.863353f, 0.815966f, + 0.408660f, 0.082653f, 0.184886f, 0.441697f, 0.280477f, 0.276902f, 0.837466f, 0.245131f, 0.924552f, 0.858917f, + 0.134613f, 0.704779f, 0.040616f, 0.230090f, 0.678095f, 0.287103f, 0.988215f, 0.980597f, 0.744615f, 0.127612f, + 0.305646f, 0.857652f, 0.922382f, 0.441324f, 0.617186f, 0.220603f, 0.779245f, 0.616006f, 0.135673f, 0.247513f, + 0.359867f, 0.546479f, 0.510376f, 0.167482f, 0.647433f, 0.875771f, 0.795605f, 0.284549f, 0.648163f, 0.697942f, + 0.717354f, 0.792651f, 0.402787f, 0.663393f, 0.701360f, 0.468060f, 0.376677f, 0.475468f, 0.298579f, 0.981118f, + 0.880607f, 0.326968f, 0.355065f, 0.336230f, 0.098184f, 0.016991f, 0.453990f, 0.115745f, 0.207050f, 0.163338f, + 0.587616f, 0.518773f, 0.952489f, 0.803843f, 0.844077f, 0.264328f, 0.097160f, 0.338377f, 0.995782f, 0.945237f, + 0.879142f, 0.501190f, 0.668073f, 0.043569f, 0.550953f, 0.663043f, 0.278687f, 0.244660f, 0.747326f, 0.768959f, + 0.756060f, 0.915355f, 0.280000f, 0.113509f, 0.430876f, 0.755474f, 0.205838f, 0.225924f, 0.265096f, 0.899392f, + 0.330582f, 0.158679f, 0.684826f, 0.544763f, 0.387195f, 0.921920f, 0.383194f, 0.199158f, 0.220731f, 0.083348f, + 0.267982f, 0.416437f, 0.503247f, 0.229764f, 0.751615f, 0.979886f, 0.218682f, 0.785557f, 0.596404f, 0.673936f, + 0.045040f, 0.842387f, 0.478416f, 0.619851f, 0.758625f, 0.557799f, 0.428663f, 0.350140f, 0.081201f, 0.896426f, + 0.967501f, 0.668001f, 0.700871f, 0.894878f, 0.453728f, 0.157534f, 0.986580f, 0.426672f, 
0.820969f, 0.739923f, + 0.352277f, 0.581123f, 0.889559f, 0.249954f, 0.012738f, 0.481115f, 0.453346f, 0.241041f, 0.520494f, 0.326236f, + 0.639721f, 0.954378f, 0.538658f, 0.690915f, 0.081894f, 0.499936f, 0.572204f, 0.182921f, 0.706266f, 0.645798f, + 0.303381f, 0.932843f, 0.404739f, 0.322655f, 0.522892f, 0.058939f, 0.563665f, 0.524866f, 0.797733f, 0.861912f, + 0.756946f, 0.534076f, 0.037392f, 0.520718f, 0.491976f, 0.965886f, 0.858428f, 0.805397f, 0.715750f, 0.242962f, + 0.121840f, 0.549413f, 0.707581f, 0.625907f, 0.103884f, 0.967437f, 0.941807f, 0.750748f, 0.391316f, 0.179390f, + 0.954144f, 0.995861f, 0.943181f, 0.225535f, 0.365521f, 0.952603f, 0.655552f, 0.984128f, 0.967362f, 0.764658f, + 0.498658f, 0.382370f, 0.076204f, 0.943615f, 0.206783f, 0.774136f, 0.219836f, 0.290086f, 0.063939f, 0.209334f, + 0.172612f, 0.684041f, 0.813314f, 0.710075f, 0.069982f, 0.338582f, 0.209592f, 0.618762f, 0.537080f, 0.754518f, + 0.657660f, 0.775365f, 0.624964f, 0.544813f, 0.650043f, 0.851819f, 0.127388f, 0.513679f, 0.920330f, 0.419923f, + 0.486112f, 0.347025f, 0.555860f, 0.550530f, 0.693655f, 0.966579f, 0.293974f, 0.196309f, 0.675409f, 0.918160f, + 0.348893f, 0.196346f, 0.473992f, 0.668433f, 0.455520f, 0.089096f, 0.405057f, 0.970099f, 0.672699f, 0.614172f, + 0.233294f, 0.329279f, 0.718766f, 0.744805f, 0.732767f, 0.195352f, 0.845798f, 0.223270f, 0.112540f, 0.858727f, + 0.458333f, 0.753204f, 0.021647f, 0.119070f, 0.378121f, 0.015745f, 0.458821f, 0.738294f, 0.802076f, 0.364342f, + 0.452341f, 0.350539f, 0.763269f, 0.449212f, 0.404651f, 0.508437f, 0.239293f, 0.483217f, 0.315162f, 0.086802f, + 0.146036f, 0.347146f, 0.495040f, 0.036045f, 0.104871f, 0.805327f, 0.475591f, 0.858913f, 0.339811f, 0.397564f, + 0.992478f, 0.147723f, 0.033954f, 0.661169f, 0.727080f, 0.537663f, 0.627922f, 0.567574f, 0.110105f, 0.385743f, + 0.760046f, 0.035033f, 0.441879f, 0.432969f, 0.852450f, 0.733128f, 0.040908f, 0.465148f, 0.525712f, 0.027543f, + 0.959939f, 0.457182f, 0.666527f, 0.031669f, 0.908842f, 0.539977f, 0.656343f, 0.466810f, 0.461138f, 0.658768f, + 0.944778f, 0.801277f, 0.274225f, 0.808626f, 0.764664f, 0.227802f, 0.657667f, 0.106055f, 0.328335f, 0.770169f, + 0.481128f, 0.905028f, 0.271492f, 0.476027f, 0.611671f, 0.727043f, 0.733395f, 0.594644f, 0.898713f, 0.196084f, + 0.859941f, 0.294517f, 0.519280f, 0.563628f, 0.251777f, 0.501324f, 0.897753f, 0.246321f, 0.324222f, 0.585902f, + 0.554412f, 0.174032f, 0.936472f, 0.827655f, 0.987936f, 0.114385f, 0.947582f, 0.246243f, 0.324910f, 0.391096f, + 0.014144f, 0.268021f, 0.689953f, 0.063691f, 0.828527f, 0.860373f, 0.081199f, 0.311536f, 0.647020f, 0.959900f, + 0.587540f, 0.239769f, 0.393420f, 0.952011f, 0.649501f, 0.701122f, 0.654753f, 0.098328f, 0.019756f, 0.307255f, + 0.101182f, 0.903178f, 0.662636f, 0.183807f, 0.383673f, 0.268124f, 0.722163f, 0.242447f, 0.870546f, 0.520290f, + 0.535141f, 0.449352f, 0.382109f, 0.030094f, 0.014841f, 0.754523f, 0.398138f, 0.080007f, 0.994005f, 0.343086f, + 0.416415f, 0.497471f, 0.518243f, 0.594622f, 0.404539f, 0.024741f, 0.205798f, 0.463358f, 0.634085f, 0.004168f, + 0.288890f, 0.318634f, 0.649971f, 0.068623f, 0.011161f, 0.617764f, 0.595074f, 0.477778f, 0.098851f, 0.284219f, + 0.982623f, 0.378369f, 0.671127f, 0.716803f, 0.038332f, 0.175828f, 0.817099f, 0.248624f, 0.526941f, 0.143601f, + 0.318435f, 0.884003f, 0.956312f, 0.605227f, 0.516111f, 0.434986f, 0.446248f, 0.031918f, 0.876705f, 0.222946f, + 0.192030f, 0.151730f, 0.162001f, 0.931703f, 0.647385f, 0.263281f, 0.684891f, 0.196009f, 0.621328f, 0.875460f, + 0.116971f, 0.164779f, 0.810315f, 0.589415f, 0.584904f, 
0.002092f, 0.368053f, 0.440462f, 0.466850f, 0.443596f, + 0.484220f, 0.870371f, 0.847502f, 0.015016f, 0.994610f, 0.624150f, 0.620991f, 0.027341f, 0.103521f, 0.971364f, + 0.694315f, 0.886678f, 0.523881f, 0.597125f, 0.947067f, 0.385271f, 0.754392f, 0.835389f, 0.975671f, 0.904114f, + 0.223580f, 0.351703f, 0.835343f, 0.052580f, 0.841164f, 0.205350f, 0.100214f, 0.310509f, 0.847647f, 0.990239f, + 0.434309f, 0.485149f, 0.367266f, 0.977029f, 0.723466f, 0.941467f, 0.249746f, 0.492914f, 0.584139f, 0.015198f, + 0.812326f, 0.527457f, 0.871326f, 0.821721f, 0.101746f, 0.594467f, 0.365567f, 0.751121f, 0.516166f, 0.369039f, + 0.557870f, 0.081583f, 0.060740f, 0.194498f, 0.932089f, 0.673928f, 0.694386f, 0.498688f, 0.422973f, 0.039913f, + 0.051126f, 0.339099f, 0.163220f, 0.351669f, 0.727191f, 0.116125f, 0.363897f, 0.637357f, 0.432239f, 0.345904f, + 0.623269f, 0.016948f, 0.826530f, 0.308751f, 0.290656f, 0.058387f, 0.264397f, 0.294895f, 0.639992f, 0.489059f, + 0.343698f, 0.929770f, 0.390125f, 0.397707f}; std::vector gamma_data = { 0.447359f, 0.873295f, 0.351357f, 0.065158f, 0.442673f, 0.998459f, 0.379773f, 0.193055f, 0.045130f, 0.170969f, @@ -143,248 +143,251 @@ TEST(GroupNormTest, GroupNorm_128) { 0.691057f, 0.918779f, 0.017400f, 0.799489f, 0.089403f, 0.916554f, 0.612013f, 0.162069f}; std::vector beta_data = { - 0.039410f, 0.827821f, 0.139492f, 0.939541f, 0.090865f, 0.837978f, 0.423533f, 0.872735f, 0.768574f, 0.852882f, 0.470242f, 0.713768f, 0.318668f, 0.047173f, - 0.232400f, 0.001362f, 0.363028f, 0.493829f, 0.019407f, 0.007730f, 0.686464f, 0.100436f, 0.073846f, 0.495598f, 0.718159f, 0.977165f, 0.295397f, 0.117518f, - 0.068537f, 0.207511f, 0.100055f, 0.003384f, 0.285074f, 0.164207f, 0.018250f, 0.354632f, 0.825916f, 0.303662f, 0.710100f, 0.728735f, 0.025556f, 0.961785f, - 0.139009f, 0.717465f, 0.379443f, 0.868223f, 0.994961f, 0.193323f, 0.819456f, 0.505503f, 0.965431f, 0.658089f, 0.593238f, 0.229523f, 0.718700f, 0.288201f, - 0.845759f, 0.977264f, 0.007793f, 0.954633f, 0.358460f, 0.488316f, 0.924086f, 0.775958f, 0.243222f, 0.096853f, 0.841226f, 0.747060f, 0.858339f, 0.384041f, - 0.492114f, 0.465019f, 0.314722f, 0.335672f, 0.718649f, 0.753071f, 0.863854f, 0.844902f, 0.753938f, 0.332778f, 0.710046f, 0.972624f, 0.916240f, 0.971488f, - 0.036208f, 0.611599f, 0.215343f, 0.246560f, 0.844061f, 0.750192f, 0.328802f, 0.519915f, 0.188330f, 0.003827f, 0.899958f, 0.709642f, 0.528818f, 0.054099f, - 0.420840f, 0.380042f, 0.171547f, 0.156188f, 0.173178f, 0.596836f, 0.124704f, 0.238549f, 0.946272f, 0.219462f, 0.763857f, 0.598040f, 0.413157f, 0.595286f, - 0.133620f, 0.484188f, 0.972134f, 0.427721f, 0.242881f, 0.927507f, 0.610774f, 0.727857f, 0.543405f, 0.011202f, 0.755700f, 0.978697f, 0.716188f, 0.808757f, - 0.851587f, 0.999201f}; + 0.039410f, 0.827821f, 0.139492f, 0.939541f, 0.090865f, 0.837978f, 0.423533f, 0.872735f, 0.768574f, 0.852882f, + 0.470242f, 0.713768f, 0.318668f, 0.047173f, 0.232400f, 0.001362f, 0.363028f, 0.493829f, 0.019407f, 0.007730f, + 0.686464f, 0.100436f, 0.073846f, 0.495598f, 0.718159f, 0.977165f, 0.295397f, 0.117518f, 0.068537f, 0.207511f, + 0.100055f, 0.003384f, 0.285074f, 0.164207f, 0.018250f, 0.354632f, 0.825916f, 0.303662f, 0.710100f, 0.728735f, + 0.025556f, 0.961785f, 0.139009f, 0.717465f, 0.379443f, 0.868223f, 0.994961f, 0.193323f, 0.819456f, 0.505503f, + 0.965431f, 0.658089f, 0.593238f, 0.229523f, 0.718700f, 0.288201f, 0.845759f, 0.977264f, 0.007793f, 0.954633f, + 0.358460f, 0.488316f, 0.924086f, 0.775958f, 0.243222f, 0.096853f, 0.841226f, 0.747060f, 0.858339f, 0.384041f, + 0.492114f, 0.465019f, 
0.314722f, 0.335672f, 0.718649f, 0.753071f, 0.863854f, 0.844902f, 0.753938f, 0.332778f, + 0.710046f, 0.972624f, 0.916240f, 0.971488f, 0.036208f, 0.611599f, 0.215343f, 0.246560f, 0.844061f, 0.750192f, + 0.328802f, 0.519915f, 0.188330f, 0.003827f, 0.899958f, 0.709642f, 0.528818f, 0.054099f, 0.420840f, 0.380042f, + 0.171547f, 0.156188f, 0.173178f, 0.596836f, 0.124704f, 0.238549f, 0.946272f, 0.219462f, 0.763857f, 0.598040f, + 0.413157f, 0.595286f, 0.133620f, 0.484188f, 0.972134f, 0.427721f, 0.242881f, 0.927507f, 0.610774f, 0.727857f, + 0.543405f, 0.011202f, 0.755700f, 0.978697f, 0.716188f, 0.808757f, 0.851587f, 0.999201f}; std::vector norm_data = { - 0.406306f, -0.397960f, -0.514167f, 0.121796f, 1.632045f, 0.498094f, 2.631821f, 1.499508f, 0.095849f, - -0.040874f, -0.116213f, 0.477808f, 0.919355f, 0.811189f, 0.907785f, 1.004834f, -0.458834f, -0.472885f, - 0.237840f, 0.238391f, 1.632483f, 2.600490f, 2.037882f, 1.527244f, 0.876482f, 0.192458f, 0.258945f, - 0.030314f, 0.729815f, 1.023374f, 0.554331f, 0.851662f, 0.750835f, 0.762038f, 0.749937f, 0.729685f, - 0.782631f, 1.098150f, 1.132450f, 0.833627f, 0.590117f, -0.060817f, 0.197422f, 0.322326f, 1.476163f, - 0.078648f, 0.606424f, 1.746771f, 0.183714f, 0.518953f, -1.247748f, 1.284994f, 0.057787f, 0.044398f, - -0.002311f, -0.011233f, -0.474648f, 0.859423f, 1.839518f, -0.003167f, 0.143954f, 0.038016f, 0.087527f, - 0.150784f, 0.561618f, 0.176986f, 0.521764f, 0.258291f, 0.031635f, 0.714081f, -0.146106f, 1.278590f, - 0.426744f, 0.648229f, -0.980738f, 0.351162f, 0.118848f, -0.296623f, -0.302773f, 0.263747f, 0.054676f, - 1.040141f, 0.676835f, 0.240001f, 0.526575f, 0.429690f, -0.132461f, -0.496733f, -0.827396f, -0.494966f, - 0.596699f, 1.630098f, -0.206514f, 1.206059f, 0.617694f, 0.959952f, 0.631899f, 0.709146f, 0.659876f, - 0.691867f, 1.033381f, 1.134488f, 0.765150f, 0.781609f, -0.028056f, 1.010104f, 1.575500f, 0.678993f, - 0.117742f, 0.117495f, 0.117460f, 0.117479f, -0.928939f, 0.857719f, -0.473908f, 0.106319f, 0.254703f, - 0.187595f, -0.423069f, 0.737017f, 1.002641f, -0.713799f, -0.505049f, 0.054679f, 0.056505f, 0.068448f, - -0.037672f, 0.007170f, 0.502409f, 0.201653f, 0.447086f, 0.031594f, 0.186869f, 0.252268f, 0.281287f, - 0.058290f, -0.032152f, -0.172338f, -0.018190f, 0.042648f, -0.201724f, 0.070818f, 0.915389f, 0.435231f, - 0.683548f, 1.228964f, 1.207481f, -0.069486f, 0.900928f, 1.349056f, -0.962214f, 1.149115f, 0.126877f, - 0.173722f, 1.016921f, -0.284731f, 1.073324f, 1.663625f, 1.156551f, 0.478892f, -0.017409f, 0.077027f, - 0.019404f, -0.119480f, 0.957481f, 1.191751f, 0.709657f, 1.305503f, 0.710492f, 0.094092f, 0.713726f, - -1.632824f, 1.254686f, 1.179984f, 1.354227f, -0.186219f, -0.620889f, -0.462022f, 0.270004f, 0.339929f, - 0.882544f, 0.831658f, 0.840813f, 0.911391f, 1.003820f, 1.105588f, 0.865947f, 1.028848f, 0.385277f, - 0.249245f, 0.102975f, 0.301977f, 0.814893f, 0.829719f, 0.796979f, 0.828055f, -0.841305f, 1.360636f, - 0.520542f, -0.564568f, 1.028838f, 0.624319f, 1.122967f, 1.414307f, 1.664626f, 1.011229f, -0.562413f, - 1.432279f, 0.982238f, -0.634975f, 1.328713f, 0.605853f, 0.150513f, 0.475544f, 0.137686f, 0.199995f, - -0.461095f, 0.034839f, 1.895931f, -0.442368f, -0.012286f, 1.765260f, -0.574054f, 1.540784f, 1.094831f, - 0.660444f, 0.856002f, 0.876256f, 0.900296f, 0.743193f, 0.857834f, 0.771619f, -0.437987f, 0.795097f, - 0.983861f, -0.860229f, 0.919201f, 1.088295f, 0.978393f, 1.000022f, -0.604762f, 0.300263f, 1.250703f, - 0.093107f, 0.398245f, 0.476736f, 0.584533f, 0.450905f, 1.126501f, 1.126446f, 0.704302f, 0.872359f, - 1.388226f, 
0.453643f, -0.218810f, 2.159872f, 0.740287f, 1.137416f, -0.416660f, 0.030324f, 0.352386f, - -0.572652f, 1.397336f, -0.212928f, 0.833504f, 0.673148f, 0.564530f, 0.691624f, 0.614170f, 1.159168f, - 0.582539f, 0.714844f, 0.687727f, 0.829472f, 0.895726f, 0.749217f, 0.626510f, 0.160861f, 0.679485f, - -0.247668f, 0.563813f, 0.424527f, 0.442242f, 0.546163f, 0.408836f, 0.503895f, 0.541062f, 0.526861f, - 0.651389f, 1.131327f, 0.109609f, 0.965844f, 0.307533f, 0.397239f, 0.275143f, 0.398844f, 1.158524f, - 1.178295f, 0.107930f, 0.808378f, 0.360064f, 0.893187f, 0.353517f, 0.411826f, 0.588918f, 1.147333f, - 0.707609f, 0.859227f, 0.904664f, 0.005007f, 0.915281f, 1.148453f, 0.418446f, 0.581892f, 0.628682f, - 1.279391f, 0.420879f, 1.174909f, 0.355126f, 0.239180f, 0.495571f, 0.703488f, 0.897993f, 0.580433f, - 0.796672f, 0.937277f, 0.923647f, 1.115814f, 0.759542f, 1.057870f, 0.977992f, 1.052553f, 0.996513f, - 1.042361f, 0.935513f, 0.938658f, -0.328335f, 0.414783f, -0.370250f, -0.629015f, 1.636925f, 1.554468f, - -0.000332f, 0.794400f, -0.644444f, -0.841804f, -0.462323f, -0.489248f, 1.350502f, 1.139242f, 0.742310f, - 1.621988f, 0.891792f, 0.742398f, 0.634979f, 0.789545f, 0.600690f, 0.564714f, 0.910902f, 0.749079f, - 0.795602f, -0.081046f, 1.059454f, -0.024277f, 0.142066f, 2.137630f, 1.354346f, 0.386545f, -0.015730f, - 0.467942f, 1.166715f, 0.105109f, -0.867947f, 0.330163f, 0.402587f, -0.943201f, 1.039989f, 0.807147f, - 1.013271f, 0.658228f, 0.261774f, 1.276604f, 0.793169f, 0.981167f, 1.182381f, -0.094400f, 0.608214f, - 1.500447f, 0.375100f, -0.540889f, -0.429466f, -0.074319f, 0.493101f, 0.428099f, 0.396397f, 0.409342f, - -0.112225f, 0.338536f, -0.096419f, 1.247461f, 0.136779f, -0.296175f, 1.306138f, -0.211410f, 1.225890f, - -0.883684f, 0.732527f, 0.188935f, 0.158450f, -0.070659f, -0.068210f, 0.095841f, 1.142486f, 0.765356f, - 0.480573f, 0.758850f, -0.296101f, -0.351806f, -0.084915f, 0.595416f, 0.228868f, -0.067355f, 0.843406f, - 0.656214f, 0.873088f, 1.118756f, 1.124528f, 0.905517f, 0.397857f, 0.077982f, -0.111570f, -0.334851f, - 0.432766f, 0.446440f, 0.667385f, 0.295979f, 1.815673f, -0.258010f, 1.014872f, 0.567667f, 0.353312f, - 0.252682f, 1.221989f, 0.073956f, -0.006854f, 1.239576f, 1.165116f, 0.349117f, 0.251850f, -0.979634f, - -1.026174f, 1.184909f, 0.343477f, 0.825275f, 1.364619f, 0.027066f, -0.497336f, -0.463020f, 1.676924f, - 2.348872f, 0.382225f, 0.125961f, 0.592108f, 1.470366f, 0.758787f, -0.208515f, 1.041303f, -0.435509f, - 0.117172f, 1.494655f, 0.342757f, 1.778383f, 0.342274f, 0.097464f, 2.547432f, -0.706661f, 0.892228f, - 0.432844f, 0.978781f, 0.577661f, -0.293386f, 0.867343f, 1.042198f, 0.928943f, -1.206122f, -0.536458f, - -0.103338f, -0.556358f, 0.772336f, 0.736790f, 0.761959f, 0.781633f, 1.964310f, 0.328702f, -0.205143f, - 2.151912f, 0.807267f, 0.819557f, 0.651057f, 0.761094f, -0.553660f, 0.061518f, 1.635670f, -0.845767f, - 1.500599f, 0.591131f, 0.429972f, 0.154289f, 1.184999f, 0.943027f, 1.116617f, 1.149119f, 0.798352f, - -0.237060f, -0.176123f, 0.250859f, 0.738550f, 2.343516f, 0.595660f, 0.857584f, 0.334614f, 0.055512f, - 0.827656f, -0.346350f, 0.879107f, 0.903969f, 0.861351f, 0.894605f, 0.544361f, 0.112821f, -0.710248f, - 0.886723f, 1.241048f, -0.874084f, 1.412525f, 0.338762f, -0.116848f, 0.501252f, 0.737254f, 0.656447f, - 0.680143f, 0.883760f, 0.893155f, 1.024669f, 0.749525f, 0.825862f, 0.796258f, 0.693469f, 0.903967f, - 1.112298f, 0.917900f, 0.659168f, 0.521876f, 0.830550f, 0.020787f, 0.905854f, 0.044571f, 0.857847f, - 0.528776f, 0.224581f, 0.636013f, -0.774066f, 0.896313f, 0.357502f, 
0.101543f, 0.048746f, -0.023476f, - -0.007332f, 1.160492f, 0.173347f, 0.010474f, -0.390864f, -0.183245f, 0.374310f, -0.061789f, 0.307303f, - 0.374511f, 0.508790f, 0.504972f, 0.571301f, 0.647929f, 0.892303f, 0.727948f, 0.437075f, 0.272462f, - 0.267807f, -1.691226f, -0.311736f, 0.221596f, -0.501987f, -0.209513f, -0.249217f, 0.477392f, -0.221902f, - 0.783358f, 0.585570f, 0.293685f, 0.168966f, -0.402073f, -0.397286f, 0.793616f, 0.814484f, 1.660988f, - 1.381788f, 0.434287f, 0.951160f, 0.398667f, -0.368342f, 0.685965f, 0.628689f, 0.746822f, 0.647196f, - 0.952972f, 1.171188f, 0.756122f, 0.809376f, -0.181046f, 1.143145f, 1.075280f, -0.462215f, 0.117678f, - 0.117596f, 0.117522f, 0.117660f, 1.207595f, -0.374746f, 0.482337f, 0.453367f, -0.074850f, -0.281733f, - 0.121187f, -0.164130f, -0.407813f, 1.347597f, -0.097000f, 0.558638f, -0.030066f, 0.084762f, 0.026081f, - -0.054476f, 0.048566f, 0.563618f, 0.564591f, 0.367439f, 0.067439f, 0.110448f, 0.229187f, 0.244487f, - 0.001379f, -0.044959f, -0.092778f, -0.175144f, -0.060172f, 0.876871f, 0.715658f, -0.005267f, 0.280818f, - 1.021856f, 1.202137f, 1.277564f, -0.846823f, 1.680601f, -0.648320f, 0.465179f, 0.816884f, 1.617434f, - 0.964561f, 0.811168f, 0.685541f, 1.269441f, -0.294534f, -0.541415f, 0.148579f, 0.006120f, -0.047344f, - -0.034877f, 1.228496f, 0.766407f, 1.191577f, 0.830097f, 1.213856f, -1.697397f, -0.162200f, -0.216335f, - 0.082768f, 1.538109f, 1.455440f, 0.466843f, -0.675884f, -0.396112f, -0.230969f, 0.311936f, 0.850093f, - 0.895946f, 0.864577f, 0.906072f, 1.127087f, 0.915749f, 1.022470f, 1.086701f, 0.347097f, 0.115267f, - 0.269888f, 0.017932f, 0.837999f, 0.798699f, 0.830973f, 0.843566f, 0.524987f, -0.323668f, 0.796731f, - 0.882529f, 1.104285f, 0.707952f, 1.288781f, 1.066624f, -0.759169f, 1.253857f, -0.279808f, -0.810174f, - 0.635460f, 1.336810f, 1.461457f, -0.560630f, 0.345593f, 0.388281f, 0.011112f, 0.625432f, -0.202532f, - -0.952190f, 0.661665f, 1.290380f, -0.625566f, -0.330132f, 0.377751f, 1.393908f, 0.947332f, 0.567214f, - 0.597034f, 0.789381f, 1.108524f, 0.989273f, 0.896032f, 0.972095f, 0.451968f, -0.186156f, 0.864871f, - 1.008577f, 1.059174f, 1.005235f, 0.834800f, 0.881400f, -0.345810f, 0.538783f, -0.242229f, 0.765357f, - 0.363634f, 0.540277f, 0.489711f, 0.556296f, 0.791247f, 0.963361f, 0.900796f, 1.274361f, 1.440297f, - 0.639664f, -0.769517f, 2.005213f, -0.205800f, 0.462482f, 0.893398f, -0.179109f, -0.385072f, 0.698468f, - 0.656636f, -0.167324f, 0.646567f, 0.534505f, 1.234794f, 1.110618f, 1.271695f, 0.759512f, 0.229293f, - 0.147224f, 0.794720f, 1.099447f, 1.113528f, 1.058541f, -0.208087f, 0.316237f, -0.032344f, -0.114418f, - 0.540560f, 0.498906f, 0.465116f, 0.418016f, 0.482087f, 0.445022f, 0.453282f, 0.438177f, -0.006379f, - 0.377702f, -0.855888f, 1.042157f, 0.408202f, 0.339785f, 0.287742f, 0.420788f, 0.465379f, 1.007626f, - 1.001159f, 0.554656f, 0.459783f, 1.143811f, 0.339036f, 0.714696f, 0.691498f, 0.735108f, 1.053392f, - 0.778748f, 0.068571f, 0.274017f, 1.481772f, 1.693937f, 0.526139f, 0.909311f, 0.350476f, 0.954506f, - 0.197028f, 0.923411f, 0.045156f, 0.957155f, 0.714096f, 0.633157f, 0.789485f, 0.581167f, 0.845790f, - 0.829842f, 1.194247f, 0.971378f, 1.019175f, 0.907585f, 0.953225f, 0.951858f, 1.102269f, 1.018174f, - 0.902432f, 0.841796f, -0.858393f, -0.330711f, -0.469070f, 0.464267f, 1.114611f, -1.004036f, 1.620967f, - 0.329466f, 0.139467f, -0.470611f, 0.308757f, 1.016010f, 0.453660f, 1.595124f, 0.558440f, 1.023249f, - 0.601039f, 1.007291f, 0.995676f, 0.637742f, 0.970108f, 0.851145f, 0.582246f, 0.840873f, 0.433405f, - -0.009376f, 
-0.395102f, 0.229559f, 1.179632f, 0.217997f, 0.145108f, 1.614064f, 1.010146f, 0.887566f, - -1.011727f, 0.264498f, 0.152422f, 0.570916f, 0.925334f, -0.269998f, 0.860524f, 1.051678f, 1.007595f, - 0.941741f, 0.488055f, 0.245246f, 0.227135f, 0.066780f, -0.402708f, 1.265329f, 0.257161f, -0.447346f, - 0.493756f, -0.268568f, -0.217773f, -0.301152f, 0.475332f, 0.373900f, 0.446225f, 0.471130f, 0.663021f, - 1.000752f, -0.090537f, 0.673516f, 0.781955f, 0.128213f, 1.239298f, 0.764475f, 1.281084f, 0.902059f, - 0.278935f, 0.221142f, 0.160415f, -0.106214f, 0.210654f, 0.141437f, 0.198334f, 0.149962f, 0.565323f, - 0.050416f, 0.888878f, 0.074347f, 0.079686f, -0.363394f, 0.253592f, -0.311712f, -0.291973f, 0.133119f, - 1.097622f, 0.962363f, 0.796541f, 0.851959f, 0.628367f, 0.626313f, 0.646783f, 0.138650f, 0.510147f, - 1.394106f, 0.600274f, 1.246940f, 0.872970f, 0.275462f, -0.508244f, -0.408690f, 1.314789f, 0.349021f, - 1.545499f, 0.153658f, 0.231785f, 0.389777f, 0.378070f, 0.840290f, -1.853665f, 1.786896f, 0.104429f, - 0.181189f, 0.667719f, 0.567943f, 0.718873f, 0.244843f, 1.129714f, 0.881495f, 1.460520f, 1.995885f, - -0.395025f, 0.817815f, 1.208726f, -1.411448f, 0.606279f, -0.143777f, 0.296987f, 1.422581f, 0.720905f, - 1.279913f, -0.352711f, 0.658642f, 1.613478f, 0.339589f, -0.089663f, 0.243404f, 1.226488f, 0.467706f, - 0.797042f, 0.442854f, 1.121590f, -0.153407f, 1.431477f, 0.230959f, 1.437285f, -0.046937f, -1.527740f, - -0.272532f, 0.732910f, 0.766692f, 0.749836f, 0.778544f, 1.502128f, -0.240678f, 0.820989f, 1.461264f, - 0.744201f, 0.593997f, 0.769196f, 0.670758f, -0.186752f, 1.864102f, -0.563369f, 2.274148f, 1.338321f, - 0.830787f, -0.191057f, 0.642745f, 1.092864f, 1.217034f, 1.076530f, 0.948315f}; + 0.406306f, 1.632045f, 0.095849f, 0.919355f, -0.458834f, 1.632483f, 0.876482f, 0.729815f, 0.750835f, + 0.782631f, 0.590117f, 1.476163f, 0.183714f, 0.057787f, -0.474648f, 0.143954f, 0.561618f, 0.031635f, + 0.426744f, 0.118848f, 0.054676f, 0.526575f, -0.827396f, -0.206514f, 0.631899f, 1.033381f, -0.028056f, + 0.117742f, -0.928939f, 0.254703f, 1.002641f, 0.056505f, 0.502409f, 0.186869f, -0.032152f, -0.201724f, + 0.683548f, 0.900928f, 0.126877f, 1.073324f, -0.017409f, 0.957481f, 0.710492f, 1.254686f, -0.620889f, + 0.882544f, 1.003820f, 0.385277f, 0.814893f, -0.841305f, 1.028838f, 1.664626f, 0.982238f, 0.150513f, + -0.461095f, -0.012286f, 1.094831f, 0.900296f, -0.437987f, 0.919201f, -0.604762f, 0.398245f, 1.126501f, + 1.388226f, 0.740287f, 0.352386f, 0.833504f, 0.614170f, 0.687727f, 0.626510f, 0.563813f, 0.408836f, + 0.651389f, 0.307533f, 1.158524f, 0.360064f, 0.588918f, 0.904664f, 0.418446f, 0.420879f, 0.495571f, + 0.796672f, 0.759542f, 0.996513f, -0.328335f, 1.636925f, -0.644444f, 1.350502f, 0.891792f, 0.600690f, + 0.795602f, 0.142066f, -0.015730f, -0.867947f, 1.039989f, 0.261774f, 1.182381f, 0.375100f, 0.493101f, + -0.112225f, 0.136779f, 1.225890f, 0.158450f, 1.142486f, -0.296101f, 0.228868f, 0.873088f, 0.397857f, + 0.432766f, 1.815673f, 0.353312f, -0.006854f, 0.251850f, 0.343477f, -0.497336f, 0.382225f, 0.758787f, + 0.117172f, 0.342274f, 0.892228f, -0.293386f, -1.206122f, 0.772336f, 1.964310f, 0.807267f, -0.553660f, + 1.500599f, 1.184999f, -0.397960f, 0.498094f, -0.040874f, 0.811189f, -0.472885f, 2.600490f, 0.192458f, + 1.023374f, 0.762038f, 1.098150f, -0.060817f, 0.078648f, 0.518953f, 0.044398f, 0.859423f, 0.038016f, + 0.176986f, 0.714081f, 0.648229f, -0.296623f, 1.040141f, 0.429690f, -0.494966f, 1.206059f, 0.709146f, + 1.134488f, 1.010104f, 0.117495f, 0.857719f, 0.187595f, -0.713799f, 0.068448f, 0.201653f, 
0.252268f, + -0.172338f, 0.070818f, 1.228964f, 1.349056f, 0.173722f, 1.663625f, 0.077027f, 1.191751f, 0.094092f, + 1.179984f, -0.462022f, 0.831658f, 1.105588f, 0.249245f, 0.829719f, 1.360636f, 0.624319f, 1.011229f, + -0.634975f, 0.475544f, 0.034839f, 1.765260f, 0.660444f, 0.743193f, 0.795097f, 1.088295f, 0.300263f, + 0.476736f, 1.126446f, 0.453643f, 1.137416f, -0.572652f, 0.673148f, 1.159168f, 0.829472f, 0.160861f, + 0.424527f, 0.503895f, 1.131327f, 0.397239f, 1.178295f, 0.893187f, 1.147333f, 0.005007f, 0.581892f, + 1.174909f, 0.703488f, 0.937277f, 1.057870f, 1.042361f, 0.414783f, 1.554468f, -0.841804f, 1.139242f, + 0.742398f, 0.564714f, -0.081046f, 2.137630f, 0.467942f, 0.330163f, 0.807147f, 1.276604f, -0.094400f, + -0.540889f, 0.428099f, 0.338536f, -0.296175f, -0.883684f, -0.070659f, 0.765356f, -0.351806f, -0.067355f, + 1.118756f, 0.077982f, 0.446440f, -0.258010f, 0.252682f, 1.239576f, -0.979634f, 0.825275f, -0.463020f, + 0.125961f, -0.208515f, 1.494655f, 0.097464f, 0.432844f, 0.867343f, -0.536458f, 0.736790f, 0.328702f, + 0.819557f, 0.061518f, 0.591131f, 0.943027f, -0.514167f, 2.631821f, -0.116213f, 0.907785f, 0.237840f, + 2.037882f, 0.258945f, 0.554331f, 0.749937f, 1.132450f, 0.197422f, 0.606424f, -1.247748f, -0.002311f, + 1.839518f, 0.087527f, 0.521764f, -0.146106f, -0.980738f, -0.302773f, 0.676835f, -0.132461f, 0.596699f, + 0.617694f, 0.659876f, 0.765150f, 1.575500f, 0.117460f, -0.473908f, -0.423069f, -0.505049f, -0.037672f, + 0.447086f, 0.281287f, -0.018190f, 0.915389f, 1.207481f, -0.962214f, 1.016921f, 1.156551f, 0.019404f, + 0.709657f, 0.713726f, 1.354227f, 0.270004f, 0.840813f, 0.865947f, 0.102975f, 0.796979f, 0.520542f, + 1.122967f, -0.562413f, 1.328713f, 0.137686f, 1.895931f, -0.574054f, 0.856002f, 0.857834f, 0.983861f, + 0.978393f, 1.250703f, 0.584533f, 0.704302f, -0.218810f, -0.416660f, 1.397336f, 0.564530f, 0.582539f, + 0.895726f, 0.679485f, 0.442242f, 0.541062f, 0.109609f, 0.275143f, 0.107930f, 0.353517f, 0.707609f, + 0.915281f, 0.628682f, 0.355126f, 0.897993f, 0.923647f, 0.977992f, 0.935513f, -0.370250f, -0.000332f, + -0.462323f, 0.742310f, 0.634979f, 0.910902f, 1.059454f, 1.354346f, 1.166715f, 0.402587f, 1.013271f, + 0.793169f, 0.608214f, -0.429466f, 0.396397f, -0.096419f, 1.306138f, 0.732527f, -0.068210f, 0.480573f, + -0.084915f, 0.843406f, 1.124528f, -0.111570f, 0.667385f, 1.014872f, 1.221989f, 1.165116f, -1.026174f, + 1.364619f, 1.676924f, 0.592108f, 1.041303f, 0.342757f, 2.547432f, 0.978781f, 1.042198f, -0.103338f, + 0.761959f, -0.205143f, 0.651057f, 1.635670f, 0.429972f, 1.116617f, 0.121796f, 1.499508f, 0.477808f, + 1.004834f, 0.238391f, 1.527244f, 0.030314f, 0.851662f, 0.729685f, 0.833627f, 0.322326f, 1.746771f, + 1.284994f, -0.011233f, -0.003167f, 0.150784f, 0.258291f, 1.278590f, 0.351162f, 0.263747f, 0.240001f, + -0.496733f, 1.630098f, 0.959952f, 0.691867f, 0.781609f, 0.678993f, 0.117479f, 0.106319f, 0.737017f, + 0.054679f, 0.007170f, 0.031594f, 0.058290f, 0.042648f, 0.435231f, -0.069486f, 1.149115f, -0.284731f, + 0.478892f, -0.119480f, 1.305503f, -1.632824f, -0.186219f, 0.339929f, 0.911391f, 1.028848f, 0.301977f, + 0.828055f, -0.564568f, 1.414307f, 1.432279f, 0.605853f, 0.199995f, -0.442368f, 1.540784f, 0.876256f, + 0.771619f, -0.860229f, 1.000022f, 0.093107f, 0.450905f, 0.872359f, 2.159872f, 0.030324f, -0.212928f, + 0.691624f, 0.714844f, 0.749217f, -0.247668f, 0.546163f, 0.526861f, 0.965844f, 0.398844f, 0.808378f, + 0.411826f, 0.859227f, 1.148453f, 1.279391f, 0.239180f, 0.580433f, 1.115814f, 1.052553f, 0.938658f, + -0.629015f, 0.794400f, -0.489248f, 1.621988f, 
0.789545f, 0.749079f, -0.024277f, 0.386545f, 0.105109f, + -0.943201f, 0.658228f, 0.981167f, 1.500447f, -0.074319f, 0.409342f, 1.247461f, -0.211410f, 0.188935f, + 0.095841f, 0.758850f, 0.595416f, 0.656214f, 0.905517f, -0.334851f, 0.295979f, 0.567667f, 0.073956f, + 0.349117f, 1.184909f, 0.027066f, 2.348872f, 1.470366f, -0.435509f, 1.778383f, -0.706661f, 0.577661f, + 0.928943f, -0.556358f, 0.781633f, 2.151912f, 0.761094f, -0.845767f, 0.154289f, 1.149119f, 0.798352f, + 0.738550f, 0.334614f, 0.879107f, 0.544361f, 1.241048f, -0.116848f, 0.680143f, 0.749525f, 0.903967f, + 0.521876f, 0.044571f, 0.636013f, 0.101543f, 1.160492f, -0.183245f, 0.374511f, 0.647929f, 0.272462f, + 0.221596f, 0.477392f, 0.293685f, 0.793616f, 0.434287f, 0.685965f, 0.952972f, -0.181046f, 0.117678f, + 1.207595f, -0.074850f, -0.407813f, -0.030066f, 0.048566f, 0.067439f, 0.001379f, -0.060172f, 0.280818f, + -0.846823f, 0.816884f, 0.685541f, 0.148579f, 1.228496f, 1.213856f, 0.082768f, -0.675884f, 0.850093f, + 1.127087f, 0.347097f, 0.837999f, 0.524987f, 1.104285f, -0.759169f, 0.635460f, 0.345593f, -0.202532f, + -0.625566f, 0.947332f, 1.108524f, 0.451968f, 1.059174f, -0.345810f, 0.363634f, 0.791247f, 1.440297f, + -0.205800f, -0.385072f, 0.646567f, 1.271695f, 0.794720f, -0.208087f, 0.540560f, 0.482087f, -0.006379f, + 0.408202f, 0.465379f, 0.459783f, 0.691498f, 0.068571f, 0.526139f, 0.197028f, 0.714096f, 0.845790f, + 1.019175f, 1.102269f, -0.858393f, 1.114611f, 0.139467f, 0.453660f, 0.601039f, 0.970108f, 0.433405f, + 1.179632f, 1.010146f, 0.152422f, 0.860524f, 0.488055f, -0.402708f, 0.493756f, 0.475332f, 0.663021f, + 0.781955f, 1.281084f, 0.160415f, 0.198334f, 0.888878f, 0.253592f, 1.097622f, 0.628367f, 0.510147f, + 0.872970f, 1.314789f, 0.231785f, -1.853665f, 0.667719f, 1.129714f, -0.395025f, 0.606279f, 0.720905f, + 1.613478f, 1.226488f, 1.121590f, 1.437285f, 0.732910f, 1.502128f, 0.744201f, -0.186752f, 1.338321f, + 1.092864f, -0.237060f, 2.343516f, 0.055512f, 0.903969f, 0.112821f, -0.874084f, 0.501252f, 0.883760f, + 0.825862f, 1.112298f, 0.830550f, 0.857847f, -0.774066f, 0.048746f, 0.173347f, 0.374310f, 0.508790f, + 0.892303f, 0.267807f, -0.501987f, -0.221902f, 0.168966f, 0.814484f, 0.951160f, 0.628689f, 1.171188f, + 1.143145f, 0.117596f, -0.374746f, -0.281733f, 1.347597f, 0.084762f, 0.563618f, 0.110448f, -0.044959f, + 0.876871f, 1.021856f, 1.680601f, 1.617434f, 1.269441f, 0.006120f, 0.766407f, -1.697397f, 1.538109f, + -0.396112f, 0.895946f, 0.915749f, 0.115267f, 0.798699f, -0.323668f, 0.707952f, 1.253857f, 1.336810f, + 0.388281f, -0.952190f, -0.330132f, 0.567214f, 0.989273f, -0.186156f, 1.005235f, 0.538783f, 0.540277f, + 0.963361f, 0.639664f, 0.462482f, 0.698468f, 0.534505f, 0.759512f, 1.099447f, 0.316237f, 0.498906f, + 0.445022f, 0.377702f, 0.339785f, 1.007626f, 1.143811f, 0.735108f, 0.274017f, 0.909311f, 0.923411f, + 0.633157f, 0.829842f, 0.907585f, 1.018174f, -0.330711f, -1.004036f, -0.470611f, 1.595124f, 1.007291f, + 0.851145f, -0.009376f, 0.217997f, 0.887566f, 0.570916f, 1.051678f, 0.245246f, 1.265329f, -0.268568f, + 0.373900f, 1.000752f, 0.128213f, 0.902059f, -0.106214f, 0.149962f, 0.074347f, -0.311712f, 0.962363f, + 0.626313f, 1.394106f, 0.275462f, 0.349021f, 0.389777f, 1.786896f, 0.567943f, 0.881495f, 0.817815f, + -0.143777f, 1.279913f, 0.339589f, 0.467706f, -0.153407f, -0.046937f, 0.766692f, -0.240678f, 0.593997f, + 1.864102f, 0.830787f, 1.217034f, -0.176123f, 0.595660f, 0.827656f, 0.861351f, -0.710248f, 1.412525f, + 0.737254f, 0.893155f, 0.796258f, 0.917900f, 0.020787f, 0.528776f, 0.896313f, -0.023476f, 0.010474f, + 
-0.061789f, 0.504972f, 0.727948f, -1.691226f, -0.209513f, 0.783358f, -0.402073f, 1.660988f, 0.398667f, + 0.746822f, 0.756122f, 1.075280f, 0.117522f, 0.482337f, 0.121187f, -0.097000f, 0.026081f, 0.564591f, + 0.229187f, -0.092778f, 0.715658f, 1.202137f, -0.648320f, 0.964561f, -0.294534f, -0.047344f, 1.191577f, + -0.162200f, 1.455440f, -0.230969f, 0.864577f, 1.022470f, 0.269888f, 0.830973f, 0.796731f, 1.288781f, + -0.279808f, 1.461457f, 0.011112f, 0.661665f, 0.377751f, 0.597034f, 0.896032f, 0.864871f, 0.834800f, + -0.242229f, 0.489711f, 0.900796f, -0.769517f, 0.893398f, 0.656636f, 1.234794f, 0.229293f, 1.113528f, + -0.032344f, 0.465116f, 0.453282f, -0.855888f, 0.287742f, 1.001159f, 0.339036f, 1.053392f, 1.481772f, + 0.350476f, 0.045156f, 0.789485f, 1.194247f, 0.953225f, 0.902432f, -0.469070f, 1.620967f, 0.308757f, + 0.558440f, 0.995676f, 0.582246f, -0.395102f, 0.145108f, -1.011727f, 0.925334f, 1.007595f, 0.227135f, + 0.257161f, -0.217773f, 0.446225f, -0.090537f, 1.239298f, 0.278935f, 0.210654f, 0.565323f, 0.079686f, + -0.291973f, 0.796541f, 0.646783f, 0.600274f, -0.508244f, 1.545499f, 0.378070f, 0.104429f, 0.718873f, + 1.460520f, 1.208726f, 0.296987f, -0.352711f, -0.089663f, 0.797042f, 1.431477f, -1.527740f, 0.749836f, + 0.820989f, 0.769196f, -0.563369f, -0.191057f, 1.076530f, 0.250859f, 0.857584f, -0.346350f, 0.894605f, + 0.886723f, 0.338762f, 0.656447f, 1.024669f, 0.693469f, 0.659168f, 0.905854f, 0.224581f, 0.357502f, + -0.007332f, -0.390864f, 0.307303f, 0.571301f, 0.437075f, -0.311736f, -0.249217f, 0.585570f, -0.397286f, + 1.381788f, -0.368342f, 0.647196f, 0.809376f, -0.462215f, 0.117660f, 0.453367f, -0.164130f, 0.558638f, + -0.054476f, 0.367439f, 0.244487f, -0.175144f, -0.005267f, 1.277564f, 0.465179f, 0.811168f, -0.541415f, + -0.034877f, 0.830097f, -0.216335f, 0.466843f, 0.311936f, 0.906072f, 1.086701f, 0.017932f, 0.843566f, + 0.882529f, 1.066624f, -0.810174f, -0.560630f, 0.625432f, 1.290380f, 1.393908f, 0.789381f, 0.972095f, + 1.008577f, 0.881400f, 0.765357f, 0.556296f, 1.274361f, 2.005213f, -0.179109f, -0.167324f, 1.110618f, + 0.147224f, 1.058541f, -0.114418f, 0.418016f, 0.438177f, 1.042157f, 0.420788f, 0.554656f, 0.714696f, + 0.778748f, 1.693937f, 0.954506f, 0.957155f, 0.581167f, 0.971378f, 0.951858f, 0.841796f, 0.464267f, + 0.329466f, 1.016010f, 1.023249f, 0.637742f, 0.840873f, 0.229559f, 1.614064f, 0.264498f, -0.269998f, + 0.941741f, 0.066780f, -0.447346f, -0.301152f, 0.471130f, 0.673516f, 0.764475f, 0.221142f, 0.141437f, + 0.050416f, -0.363394f, 0.133119f, 0.851959f, 0.138650f, 1.246940f, -0.408690f, 0.153658f, 0.840290f, + 0.181189f, 0.244843f, 1.995885f, -1.411448f, 1.422581f, 0.658642f, 0.243404f, 0.442854f, 0.230959f, + -0.272532f, 0.778544f, 1.461264f, 0.670758f, 2.274148f, 0.642745f, 0.948315f}; std::vector swish_data = { - 0.243866f, -0.159901f, -0.192410f, 0.064602f, 1.365124f, 0.309820f, 2.455177f, 1.225849f, 0.050220f, - -0.020019f, -0.054734f, 0.294918f, 0.657257f, 0.561637f, 0.646839f, 0.735547f, -0.177689f, -0.181556f, - 0.132996f, 0.133336f, 1.365588f, 2.420778f, 1.802949f, 1.254788f, 0.618877f, 0.105460f, 0.146142f, - 0.015386f, 0.492453f, 0.752824f, 0.352078f, 0.596943f, 0.510088f, 0.519554f, 0.509331f, 0.492345f, - 0.537078f, 0.823517f, 0.856461f, 0.581139f, 0.379677f, -0.029484f, 0.108424f, 0.186914f, 1.201586f, - 0.040870f, 0.392432f, 1.487454f, 0.100271f, 0.325333f, -0.278360f, 1.006534f, 0.029728f, 0.022692f, - -0.001154f, -0.005585f, -0.182035f, 0.603779f, 1.587304f, -0.001581f, 0.077149f, 0.019369f, 0.045678f, - 0.081065f, 0.357653f, 0.096304f, 0.327438f, 
0.145732f, 0.016068f, 0.479364f, -0.067726f, 1.000125f, - 0.258221f, 0.425634f, -0.267492f, 0.206097f, 0.062951f, -0.126475f, -0.128642f, 0.149164f, 0.028085f, - 0.768537f, 0.448763f, 0.134332f, 0.331049f, 0.260306f, -0.061851f, -0.187918f, -0.251691f, -0.187456f, - 0.384811f, 1.363060f, -0.092633f, 0.928184f, 0.401312f, 0.694153f, 0.412580f, 0.475279f, 0.435012f, - 0.461047f, 0.762192f, 0.858428f, 0.522193f, 0.536204f, -0.013831f, 0.740447f, 1.305406f, 0.450521f, - 0.062333f, 0.062195f, 0.062175f, 0.062186f, -0.263020f, 0.602277f, -0.181835f, 0.055983f, 0.143483f, - 0.102570f, -0.167443f, 0.498477f, 0.733510f, -0.234669f, -0.190078f, 0.028087f, 0.029050f, 0.035395f, - -0.018481f, 0.003598f, 0.313013f, 0.110958f, 0.272698f, 0.016046f, 0.102139f, 0.141960f, 0.160295f, - 0.029994f, -0.015817f, -0.078762f, -0.009012f, 0.021779f, -0.090723f, 0.036662f, 0.653681f, 0.264239f, - 0.454239f, 0.950773f, 0.929582f, -0.033536f, 0.640686f, 1.071117f, -0.265990f, 0.872580f, 0.067457f, - 0.094387f, 0.746799f, -0.122234f, 0.799871f, 1.398649f, 0.879794f, 0.295709f, -0.008629f, 0.039996f, - 0.009796f, -0.056175f, 0.691892f, 0.914138f, 0.475701f, 1.027117f, 0.476392f, 0.049258f, 0.479070f, - -0.266875f, 0.976283f, 0.902623f, 1.076367f, -0.084465f, -0.217050f, -0.178574f, 0.153117f, 0.198578f, - 0.624266f, 0.579420f, 0.587422f, 0.650082f, 0.734605f, 0.830635f, 0.609541f, 0.757945f, 0.229296f, - 0.140073f, 0.054136f, 0.173615f, 0.564844f, 0.577730f, 0.549380f, 0.576280f, -0.253452f, 1.082880f, - 0.326523f, -0.204651f, 0.757935f, 0.406556f, 0.847322f, 1.137731f, 1.399714f, 0.741494f, -0.204150f, - 1.156216f, 0.714629f, -0.219945f, 1.050518f, 0.391983f, 0.080910f, 0.293266f, 0.073575f, 0.109964f, - -0.178318f, 0.017723f, 1.648380f, -0.173044f, -0.006105f, 1.507298f, -0.206833f, 1.268957f, 0.820346f, - 0.435470f, 0.600764f, 0.618676f, 0.640120f, 0.503657f, 0.602378f, 0.527688f, -0.171787f, 0.547762f, - 0.716126f, -0.255739f, 0.657118f, 0.814111f, 0.711085f, 0.731079f, -0.213635f, 0.172503f, 0.972323f, - 0.048719f, 0.238256f, 0.294135f, 0.375335f, 0.275437f, 0.850725f, 0.850672f, 0.471277f, 0.615220f, - 1.111010f, 0.277405f, -0.097483f, 1.936515f, 0.501217f, 0.861257f, -0.165546f, 0.015392f, 0.206920f, - -0.206513f, 1.120330f, -0.095172f, 0.581032f, 0.445763f, 0.359888f, 0.460849f, 0.398530f, 0.882337f, - 0.373787f, 0.479997f, 0.457656f, 0.577514f, 0.636029f, 0.508724f, 0.408295f, 0.086886f, 0.450923f, - -0.108577f, 0.359337f, 0.256654f, 0.269234f, 0.345855f, 0.245632f, 0.314115f, 0.341984f, 0.331264f, - 0.428173f, 0.855378f, 0.057805f, 0.699551f, 0.177226f, 0.237559f, 0.156379f, 0.238672f, 0.881711f, - 0.900973f, 0.056874f, 0.559207f, 0.212098f, 0.633758f, 0.207681f, 0.247724f, 0.378743f, 0.870852f, - 0.474008f, 0.603606f, 0.644037f, 0.002510f, 0.653584f, 0.871938f, 0.252370f, 0.373285f, 0.410021f, - 1.000926f, 0.254082f, 0.897667f, 0.208764f, 0.133824f, 0.307957f, 0.470606f, 0.638058f, 0.372154f, - 0.549116f, 0.673480f, 0.661132f, 0.840444f, 0.517441f, 0.785239f, 0.710716f, 0.780221f, 0.727826f, - 0.770623f, 0.671878f, 0.674734f, -0.137456f, 0.249797f, -0.151240f, -0.218730f, 1.370297f, 1.283304f, - -0.000166f, 0.547163f, -0.221845f, -0.253514f, -0.178658f, -0.185949f, 1.072585f, 0.863022f, 0.502916f, - 1.354473f, 0.632512f, 0.502989f, 0.415034f, 0.542996f, 0.387934f, 0.360029f, 0.649641f, 0.508608f, - 0.548196f, -0.038882f, 0.786735f, -0.011991f, 0.076070f, 1.912126f, 1.076488f, 0.230168f, -0.007803f, - 0.287736f, 0.889679f, 0.055314f, -0.256636f, 0.192089f, 0.241274f, -0.264336f, 0.768393f, 0.558143f, - 
0.743397f, 0.433681f, 0.147922f, 0.998140f, 0.546106f, 0.713642f, 0.904965f, -0.044974f, 0.393839f, - 1.226827f, 0.222318f, -0.199037f, -0.169319f, -0.035779f, 0.306135f, 0.259179f, 0.236975f, 0.245986f, - -0.052967f, 0.197649f, -0.045887f, 0.969103f, 0.073060f, -0.126316f, 1.027756f, -0.094573f, 0.947734f, - -0.258402f, 0.494719f, 0.103365f, 0.085489f, -0.034082f, -0.032942f, 0.050215f, 0.866159f, 0.522367f, - 0.296938f, 0.516856f, -0.126290f, -0.145276f, -0.040656f, 0.383809f, 0.127472f, -0.032544f, 0.589695f, - 0.432058f, 0.615866f, 0.843271f, 0.848825f, 0.644802f, 0.237987f, 0.040511f, -0.052676f, -0.139653f, - 0.262488f, 0.272236f, 0.441086f, 0.169732f, 1.561563f, -0.112454f, 0.744888f, 0.362299f, 0.207543f, - 0.142219f, 0.943881f, 0.038345f, -0.003415f, 0.961279f, 0.888123f, 0.204724f, 0.141699f, -0.267405f, - -0.270732f, 0.907438f, 0.200946f, 0.573859f, 1.086931f, 0.013716f, -0.188076f, -0.178851f, 1.412803f, - 2.144154f, 0.227198f, 0.066942f, 0.381228f, 1.195574f, 0.516802f, -0.093427f, 0.769628f, -0.171073f, - 0.062015f, 1.220799f, 0.200465f, 1.521402f, 0.200143f, 0.051105f, 2.362491f, -0.233436f, 0.632902f, - 0.262543f, 0.711443f, 0.370009f, -0.125327f, 0.610777f, 0.770470f, 0.665922f, -0.277876f, -0.197959f, - -0.049002f, -0.202732f, 0.528298f, 0.498287f, 0.519488f, 0.536225f, 1.722697f, 0.191122f, -0.092087f, - 1.927784f, 0.558246f, 0.568889f, 0.427906f, 0.518755f, -0.202095f, 0.031705f, 1.368965f, -0.254002f, - 1.226985f, 0.380467f, 0.260506f, 0.083084f, 0.907526f, 0.678707f, 0.841215f, 0.872584f, 0.550561f, - -0.104546f, -0.080327f, 0.141080f, 0.499761f, 2.138265f, 0.384000f, 0.602158f, 0.195041f, 0.028526f, - 0.575932f, -0.143482f, 0.621209f, 0.643413f, 0.605480f, 0.635026f, 0.344486f, 0.059589f, -0.234058f, - 0.627989f, 0.962738f, -0.257335f, 1.135901f, 0.197799f, -0.055014f, 0.312156f, 0.498675f, 0.432245f, - 0.451459f, 0.625349f, 0.633730f, 0.754035f, 0.508984f, 0.574370f, 0.548760f, 0.462362f, 0.643412f, - 0.837068f, 0.655944f, 0.434440f, 0.327522f, 0.578454f, 0.010501f, 0.645105f, 0.022782f, 0.602389f, - 0.332705f, 0.124847f, 0.415858f, -0.244295f, 0.636554f, 0.210367f, 0.053347f, 0.024967f, -0.011600f, - -0.003652f, 0.883625f, 0.094167f, 0.005264f, -0.157717f, -0.083251f, 0.221779f, -0.029940f, 0.177076f, - 0.221916f, 0.317751f, 0.314914f, 0.365097f, 0.425394f, 0.632969f, 0.490896f, 0.265550f, 0.154676f, - 0.151727f, -0.263180f, -0.131768f, 0.123024f, -0.189286f, -0.093822f, -0.109161f, 0.294614f, -0.098691f, - 0.537700f, 0.376140f, 0.168252f, 0.091604f, -0.161157f, -0.159695f, 0.546489f, 0.564490f, 1.395845f, - 1.104433f, 0.263568f, 0.686118f, 0.238550f, -0.150630f, 0.456214f, 0.410026f, 0.506708f, 0.424806f, - 0.687772f, 0.894037f, 0.514549f, 0.560069f, -0.082351f, 0.866797f, 0.801729f, -0.178628f, 0.062297f, - 0.062251f, 0.062210f, 0.062287f, 0.929695f, -0.152669f, 0.298229f, 0.277207f, -0.036025f, -0.121153f, - 0.064261f, -0.075345f, -0.162895f, 1.069637f, -0.046150f, 0.355371f, -0.014807f, 0.044176f, 0.013210f, - -0.026496f, 0.024873f, 0.359187f, 0.359935f, 0.217098f, 0.034856f, 0.058271f, 0.127668f, 0.137113f, - 0.000690f, -0.021974f, -0.044238f, -0.079923f, -0.029181f, 0.619223f, 0.480672f, -0.002627f, 0.159995f, - 0.751405f, 0.924329f, 0.999099f, -0.254131f, 1.416720f, -0.222613f, 0.285733f, 0.566570f, 1.349653f, - 0.698375f, 0.561619f, 0.455867f, 0.990985f, -0.125735f, -0.199164f, 0.079798f, 0.003070f, -0.023112f, - -0.017134f, 0.950309f, 0.523259f, 0.913967f, 0.578059f, 0.935859f, -0.262766f, -0.074537f, -0.096513f, - 0.043096f, 1.266156f, 1.180120f, 
0.286938f, -0.227895f, -0.159335f, -0.102207f, 0.180099f, 0.595564f, - 0.636225f, 0.608330f, 0.645301f, 0.851290f, 0.654005f, 0.751979f, 0.812592f, 0.203369f, 0.060952f, - 0.153044f, 0.009046f, 0.584960f, 0.550860f, 0.578823f, 0.589835f, 0.329856f, -0.135870f, 0.549166f, - 0.624253f, 0.829387f, 0.474291f, 1.010328f, 0.793519f, -0.242043f, 0.975459f, -0.120458f, -0.249415f, - 0.415417f, 1.058707f, 1.186345f, -0.203734f, 0.202362f, 0.231365f, 0.005587f, 0.407439f, -0.091046f, - -0.265132f, 0.436457f, 1.011931f, -0.218020f, -0.138064f, 0.224131f, 1.116821f, 0.682626f, 0.361950f, - 0.385073f, 0.542856f, 0.833448f, 0.721125f, 0.636303f, 0.705290f, 0.276201f, -0.084439f, 0.608590f, - 0.739027f, 0.786472f, 0.735919f, 0.582164f, 0.623249f, -0.143303f, 0.340258f, -0.106517f, 0.522368f, - 0.214515f, 0.341388f, 0.303639f, 0.353579f, 0.544456f, 0.697275f, 0.640568f, 0.995898f, 1.164481f, - 0.418773f, -0.243616f, 1.767281f, -0.092349f, 0.283780f, 0.633947f, -0.081556f, -0.155917f, 0.466470f, - 0.432398f, -0.076679f, 0.424301f, 0.337023f, 0.956541f, 0.835456f, 0.993236f, 0.517416f, 0.127733f, - 0.079021f, 0.547438f, 0.824758f, 0.838249f, 0.785873f, -0.093257f, 0.182914f, -0.015911f, -0.053940f, - 0.341603f, 0.310421f, 0.285687f, 0.252067f, 0.298046f, 0.271221f, 0.277146f, 0.266335f, -0.003180f, - 0.224097f, -0.255225f, 0.770431f, 0.245189f, 0.198482f, 0.164428f, 0.254018f, 0.285878f, 0.738142f, - 0.732134f, 0.352326f, 0.281830f, 0.867442f, 0.197982f, 0.479874f, 0.460745f, 0.496878f, 0.781012f, - 0.533761f, 0.035461f, 0.155663f, 1.207408f, 1.430939f, 0.330722f, 0.648210f, 0.205636f, 0.689173f, - 0.108188f, 0.660919f, 0.023088f, 0.691594f, 0.479376f, 0.413581f, 0.542946f, 0.372724f, 0.591785f, - 0.577837f, 0.916584f, 0.704632f, 0.748902f, 0.646659f, 0.688003f, 0.686755f, 0.827456f, 0.747968f, - 0.642034f, 0.588283f, -0.255522f, -0.138260f, -0.180515f, 0.285072f, 0.839288f, -0.269231f, 1.353391f, - 0.191627f, 0.074588f, -0.180937f, 0.178024f, 0.745949f, 0.277417f, 1.326083f, 0.355219f, 0.752707f, - 0.388207f, 0.737830f, 0.727050f, 0.417238f, 0.703465f, 0.596488f, 0.373560f, 0.587475f, 0.262941f, - -0.004666f, -0.159025f, 0.127896f, 0.902279f, 0.120832f, 0.077809f, 1.346089f, 0.740486f, 0.628741f, - -0.269769f, 0.149638f, 0.082008f, 0.364801f, 0.662657f, -0.116884f, 0.604751f, 0.779395f, 0.738113f, - 0.677536f, 0.302422f, 0.137584f, 0.126410f, 0.034505f, -0.161350f, 0.986884f, 0.145023f, -0.174461f, - 0.306618f, -0.116360f, -0.097077f, -0.128073f, 0.293111f, 0.221499f, 0.272082f, 0.290052f, 0.437553f, - 0.731756f, -0.043221f, 0.446063f, 0.536501f, 0.068210f, 0.961003f, 0.521620f, 1.002620f, 0.641700f, - 0.158794f, 0.122747f, 0.086627f, -0.050289f, 0.116380f, 0.075711f, 0.108969f, 0.080593f, 0.360497f, - 0.025843f, 0.629911f, 0.038555f, 0.041430f, -0.149042f, 0.142787f, -0.131760f, -0.124825f, 0.070983f, - 0.823012f, 0.696361f, 0.549003f, 0.597205f, 0.409770f, 0.408138f, 0.424474f, 0.074123f, 0.318761f, - 1.117023f, 0.387609f, 0.968585f, 0.615761f, 0.156582f, -0.190899f, -0.163160f, 1.036466f, 0.204659f, - 1.273897f, 0.082720f, 0.129264f, 0.232396f, 0.224349f, 0.586964f, -0.251066f, 1.530559f, 0.054939f, - 0.098779f, 0.441357f, 0.362511f, 0.483341f, 0.137334f, 0.853822f, 0.623333f, 1.185376f, 1.757106f, - -0.159001f, 0.567378f, 0.930808f, -0.276652f, 0.392318f, -0.066730f, 0.170383f, 1.146234f, 0.485029f, - 1.001448f, -0.145573f, 0.434016f, 1.345469f, 0.198351f, -0.042823f, 0.136440f, 0.948325f, 0.287564f, - 0.549434f, 0.269671f, 0.845997f, -0.070832f, 1.155390f, 0.128756f, 1.161375f, -0.022918f, -0.272434f, 
- -0.117812f, 0.495040f, 0.523501f, 0.509246f, 0.533588f, 1.228578f, -0.105927f, 0.570132f, 1.186146f, - 0.504504f, 0.382702f, 0.525628f, 0.443822f, -0.084682f, 1.613891f, -0.204372f, 2.062000f, 1.060236f, - 0.578661f, -0.086430f, 0.421238f, 0.818468f, 0.938992f, 0.802915f, 0.683523f}; + 0.243866f, 1.365124f, 0.050220f, 0.657257f, -0.177689f, 1.365588f, 0.618877f, 0.492453f, 0.510088f, + 0.537078f, 0.379677f, 1.201586f, 0.100271f, 0.029728f, -0.182035f, 0.077149f, 0.357653f, 0.016068f, + 0.258221f, 0.062951f, 0.028085f, 0.331049f, -0.251691f, -0.092633f, 0.412580f, 0.762192f, -0.013831f, + 0.062333f, -0.263020f, 0.143483f, 0.733510f, 0.029050f, 0.313013f, 0.102139f, -0.015817f, -0.090723f, + 0.454239f, 0.640686f, 0.067457f, 0.799871f, -0.008629f, 0.691892f, 0.476392f, 0.976283f, -0.217050f, + 0.624266f, 0.734605f, 0.229296f, 0.564844f, -0.253452f, 0.757935f, 1.399714f, 0.714629f, 0.080910f, + -0.178318f, -0.006105f, 0.820346f, 0.640120f, -0.171787f, 0.657118f, -0.213635f, 0.238256f, 0.850725f, + 1.111010f, 0.501217f, 0.206920f, 0.581032f, 0.398530f, 0.457656f, 0.408295f, 0.359337f, 0.245632f, + 0.428173f, 0.177226f, 0.881711f, 0.212098f, 0.378743f, 0.644037f, 0.252370f, 0.254082f, 0.307957f, + 0.549116f, 0.517441f, 0.727826f, -0.137456f, 1.370297f, -0.221845f, 1.072585f, 0.632512f, 0.387934f, + 0.548196f, 0.076070f, -0.007803f, -0.256636f, 0.768393f, 0.147922f, 0.904965f, 0.222318f, 0.306135f, + -0.052967f, 0.073060f, 0.947734f, 0.085489f, 0.866159f, -0.126290f, 0.127472f, 0.615866f, 0.237987f, + 0.262488f, 1.561563f, 0.207543f, -0.003415f, 0.141699f, 0.200946f, -0.188076f, 0.227198f, 0.516802f, + 0.062015f, 0.200143f, 0.632902f, -0.125327f, -0.277876f, 0.528298f, 1.722697f, 0.558246f, -0.202095f, + 1.226985f, 0.907526f, -0.159901f, 0.309820f, -0.020019f, 0.561637f, -0.181556f, 2.420778f, 0.105460f, + 0.752824f, 0.519554f, 0.823517f, -0.029484f, 0.040870f, 0.325333f, 0.022692f, 0.603779f, 0.019369f, + 0.096304f, 0.479364f, 0.425634f, -0.126475f, 0.768537f, 0.260306f, -0.187456f, 0.928184f, 0.475279f, + 0.858428f, 0.740447f, 0.062195f, 0.602277f, 0.102570f, -0.234669f, 0.035395f, 0.110958f, 0.141960f, + -0.078762f, 0.036662f, 0.950773f, 1.071117f, 0.094387f, 1.398649f, 0.039996f, 0.914138f, 0.049258f, + 0.902623f, -0.178574f, 0.579420f, 0.830635f, 0.140073f, 0.577730f, 1.082880f, 0.406556f, 0.741494f, + -0.219945f, 0.293266f, 0.017723f, 1.507298f, 0.435470f, 0.503657f, 0.547762f, 0.814111f, 0.172503f, + 0.294135f, 0.850672f, 0.277405f, 0.861257f, -0.206513f, 0.445763f, 0.882337f, 0.577514f, 0.086886f, + 0.256654f, 0.314115f, 0.855378f, 0.237559f, 0.900973f, 0.633758f, 0.870852f, 0.002510f, 0.373285f, + 0.897667f, 0.470606f, 0.673480f, 0.785239f, 0.770623f, 0.249797f, 1.283304f, -0.253514f, 0.863022f, + 0.502989f, 0.360029f, -0.038882f, 1.912126f, 0.287736f, 0.192089f, 0.558143f, 0.998140f, -0.044974f, + -0.199037f, 0.259179f, 0.197649f, -0.126316f, -0.258402f, -0.034082f, 0.522367f, -0.145276f, -0.032544f, + 0.843271f, 0.040511f, 0.272236f, -0.112454f, 0.142219f, 0.961279f, -0.267405f, 0.573859f, -0.178851f, + 0.066942f, -0.093427f, 1.220799f, 0.051105f, 0.262543f, 0.610777f, -0.197959f, 0.498287f, 0.191122f, + 0.568889f, 0.031705f, 0.380467f, 0.678707f, -0.192410f, 2.455177f, -0.054734f, 0.646839f, 0.132996f, + 1.802949f, 0.146142f, 0.352078f, 0.509331f, 0.856461f, 0.108424f, 0.392432f, -0.278360f, -0.001154f, + 1.587304f, 0.045678f, 0.327438f, -0.067726f, -0.267492f, -0.128642f, 0.448763f, -0.061851f, 0.384811f, + 0.401312f, 0.435012f, 0.522193f, 1.305406f, 0.062175f, -0.181835f, 
-0.167443f, -0.190078f, -0.018481f, + 0.272698f, 0.160295f, -0.009012f, 0.653681f, 0.929582f, -0.265990f, 0.746799f, 0.879794f, 0.009796f, + 0.475701f, 0.479070f, 1.076367f, 0.153117f, 0.587422f, 0.609541f, 0.054136f, 0.549380f, 0.326523f, + 0.847322f, -0.204150f, 1.050518f, 0.073575f, 1.648380f, -0.206833f, 0.600764f, 0.602378f, 0.716126f, + 0.711085f, 0.972323f, 0.375335f, 0.471277f, -0.097483f, -0.165546f, 1.120330f, 0.359888f, 0.373787f, + 0.636029f, 0.450923f, 0.269234f, 0.341984f, 0.057805f, 0.156379f, 0.056874f, 0.207681f, 0.474008f, + 0.653584f, 0.410021f, 0.208764f, 0.638058f, 0.661132f, 0.710716f, 0.671878f, -0.151240f, -0.000166f, + -0.178658f, 0.502916f, 0.415034f, 0.649641f, 0.786735f, 1.076488f, 0.889679f, 0.241274f, 0.743397f, + 0.546106f, 0.393839f, -0.169319f, 0.236975f, -0.045887f, 1.027756f, 0.494719f, -0.032942f, 0.296938f, + -0.040656f, 0.589695f, 0.848825f, -0.052676f, 0.441086f, 0.744888f, 0.943881f, 0.888123f, -0.270732f, + 1.086931f, 1.412803f, 0.381228f, 0.769628f, 0.200465f, 2.362491f, 0.711443f, 0.770470f, -0.049002f, + 0.519488f, -0.092087f, 0.427906f, 1.368965f, 0.260506f, 0.841215f, 0.064602f, 1.225849f, 0.294918f, + 0.735547f, 0.133336f, 1.254788f, 0.015386f, 0.596943f, 0.492345f, 0.581139f, 0.186914f, 1.487454f, + 1.006534f, -0.005585f, -0.001581f, 0.081065f, 0.145732f, 1.000125f, 0.206097f, 0.149164f, 0.134332f, + -0.187918f, 1.363060f, 0.694153f, 0.461047f, 0.536204f, 0.450521f, 0.062186f, 0.055983f, 0.498477f, + 0.028087f, 0.003598f, 0.016046f, 0.029994f, 0.021779f, 0.264239f, -0.033536f, 0.872580f, -0.122234f, + 0.295709f, -0.056175f, 1.027117f, -0.266875f, -0.084465f, 0.198578f, 0.650082f, 0.757945f, 0.173615f, + 0.576280f, -0.204651f, 1.137731f, 1.156216f, 0.391983f, 0.109964f, -0.173044f, 1.268957f, 0.618676f, + 0.527688f, -0.255739f, 0.731079f, 0.048719f, 0.275437f, 0.615220f, 1.936515f, 0.015392f, -0.095172f, + 0.460849f, 0.479997f, 0.508724f, -0.108577f, 0.345855f, 0.331264f, 0.699551f, 0.238672f, 0.559207f, + 0.247724f, 0.603606f, 0.871938f, 1.000926f, 0.133824f, 0.372154f, 0.840444f, 0.780221f, 0.674734f, + -0.218730f, 0.547163f, -0.185949f, 1.354473f, 0.542996f, 0.508608f, -0.011991f, 0.230168f, 0.055314f, + -0.264336f, 0.433681f, 0.713642f, 1.226827f, -0.035779f, 0.245986f, 0.969103f, -0.094573f, 0.103365f, + 0.050215f, 0.516856f, 0.383809f, 0.432058f, 0.644802f, -0.139653f, 0.169732f, 0.362299f, 0.038345f, + 0.204724f, 0.907438f, 0.013716f, 2.144154f, 1.195574f, -0.171073f, 1.521402f, -0.233436f, 0.370009f, + 0.665922f, -0.202732f, 0.536225f, 1.927784f, 0.518755f, -0.254002f, 0.083084f, 0.872584f, 0.550561f, + 0.499761f, 0.195041f, 0.621209f, 0.344486f, 0.962738f, -0.055014f, 0.451459f, 0.508984f, 0.643412f, + 0.327522f, 0.022782f, 0.415858f, 0.053347f, 0.883625f, -0.083251f, 0.221916f, 0.425394f, 0.154676f, + 0.123024f, 0.294614f, 0.168252f, 0.546489f, 0.263568f, 0.456214f, 0.687772f, -0.082351f, 0.062297f, + 0.929695f, -0.036025f, -0.162895f, -0.014807f, 0.024873f, 0.034856f, 0.000690f, -0.029181f, 0.159995f, + -0.254131f, 0.566570f, 0.455867f, 0.079798f, 0.950309f, 0.935859f, 0.043096f, -0.227895f, 0.595564f, + 0.851290f, 0.203369f, 0.584960f, 0.329856f, 0.829387f, -0.242043f, 0.415417f, 0.202362f, -0.091046f, + -0.218020f, 0.682626f, 0.833448f, 0.276201f, 0.786472f, -0.143303f, 0.214515f, 0.544456f, 1.164481f, + -0.092349f, -0.155917f, 0.424301f, 0.993236f, 0.547438f, -0.093257f, 0.341603f, 0.298046f, -0.003180f, + 0.245189f, 0.285878f, 0.281830f, 0.460745f, 0.035461f, 0.330722f, 0.108188f, 0.479376f, 0.591785f, + 0.748902f, 0.827456f, 
-0.255522f, 0.839288f, 0.074588f, 0.277417f, 0.388207f, 0.703465f, 0.262941f, + 0.902279f, 0.740486f, 0.082008f, 0.604751f, 0.302422f, -0.161350f, 0.306618f, 0.293111f, 0.437553f, + 0.536501f, 1.002620f, 0.086627f, 0.108969f, 0.629911f, 0.142787f, 0.823012f, 0.409770f, 0.318761f, + 0.615761f, 1.036466f, 0.129264f, -0.251066f, 0.441357f, 0.853822f, -0.159001f, 0.392318f, 0.485029f, + 1.345469f, 0.948325f, 0.845997f, 1.161375f, 0.495040f, 1.228578f, 0.504504f, -0.084682f, 1.060236f, + 0.818468f, -0.104546f, 2.138265f, 0.028526f, 0.643413f, 0.059589f, -0.257335f, 0.312156f, 0.625349f, + 0.574370f, 0.837068f, 0.578454f, 0.602389f, -0.244295f, 0.024967f, 0.094167f, 0.221779f, 0.317751f, + 0.632969f, 0.151727f, -0.189286f, -0.098691f, 0.091604f, 0.564490f, 0.686118f, 0.410026f, 0.894037f, + 0.866797f, 0.062251f, -0.152669f, -0.121153f, 1.069637f, 0.044176f, 0.359187f, 0.058271f, -0.021974f, + 0.619223f, 0.751405f, 1.416720f, 1.349653f, 0.990985f, 0.003070f, 0.523259f, -0.262766f, 1.266156f, + -0.159335f, 0.636225f, 0.654005f, 0.060952f, 0.550860f, -0.135870f, 0.474291f, 0.975459f, 1.058707f, + 0.231365f, -0.265132f, -0.138064f, 0.361950f, 0.721125f, -0.084439f, 0.735919f, 0.340258f, 0.341388f, + 0.697275f, 0.418773f, 0.283780f, 0.466470f, 0.337023f, 0.517416f, 0.824758f, 0.182914f, 0.310421f, + 0.271221f, 0.224097f, 0.198482f, 0.738142f, 0.867442f, 0.496878f, 0.155663f, 0.648210f, 0.660919f, + 0.413581f, 0.577837f, 0.646659f, 0.747968f, -0.138260f, -0.269231f, -0.180937f, 1.326083f, 0.737830f, + 0.596488f, -0.004666f, 0.120832f, 0.628741f, 0.364801f, 0.779395f, 0.137584f, 0.986884f, -0.116360f, + 0.221499f, 0.731756f, 0.068210f, 0.641700f, -0.050289f, 0.080593f, 0.038555f, -0.131760f, 0.696361f, + 0.408138f, 1.117023f, 0.156582f, 0.204659f, 0.232396f, 1.530559f, 0.362511f, 0.623333f, 0.567378f, + -0.066730f, 1.001448f, 0.198351f, 0.287564f, -0.070832f, -0.022918f, 0.523501f, -0.105927f, 0.382702f, + 1.613891f, 0.578661f, 0.938992f, -0.080327f, 0.384000f, 0.575932f, 0.605480f, -0.234058f, 1.135901f, + 0.498675f, 0.633730f, 0.548760f, 0.655944f, 0.010501f, 0.332705f, 0.636554f, -0.011600f, 0.005264f, + -0.029940f, 0.314914f, 0.490896f, -0.263180f, -0.093822f, 0.537700f, -0.161157f, 1.395845f, 0.238550f, + 0.506708f, 0.514549f, 0.801729f, 0.062210f, 0.298229f, 0.064261f, -0.046150f, 0.013210f, 0.359935f, + 0.127668f, -0.044238f, 0.480672f, 0.924329f, -0.222613f, 0.698375f, -0.125735f, -0.023112f, 0.913967f, + -0.074537f, 1.180120f, -0.102207f, 0.608330f, 0.751979f, 0.153044f, 0.578823f, 0.549166f, 1.010328f, + -0.120458f, 1.186345f, 0.005587f, 0.436457f, 0.224131f, 0.385073f, 0.636303f, 0.608590f, 0.582164f, + -0.106517f, 0.303639f, 0.640568f, -0.243616f, 0.633947f, 0.432398f, 0.956541f, 0.127733f, 0.838249f, + -0.015911f, 0.285687f, 0.277146f, -0.255225f, 0.164428f, 0.732134f, 0.197982f, 0.781012f, 1.207408f, + 0.205636f, 0.023088f, 0.542946f, 0.916584f, 0.688003f, 0.642034f, -0.180515f, 1.353391f, 0.178024f, + 0.355219f, 0.727050f, 0.373560f, -0.159025f, 0.077809f, -0.269769f, 0.662657f, 0.738113f, 0.126410f, + 0.145023f, -0.097077f, 0.272082f, -0.043221f, 0.961003f, 0.158794f, 0.116380f, 0.360497f, 0.041430f, + -0.124825f, 0.549003f, 0.424474f, 0.387609f, -0.190899f, 1.273897f, 0.224349f, 0.054939f, 0.483341f, + 1.185376f, 0.930808f, 0.170383f, -0.145573f, -0.042823f, 0.549434f, 1.155390f, -0.272434f, 0.509246f, + 0.570132f, 0.525628f, -0.204372f, -0.086430f, 0.802915f, 0.141080f, 0.602158f, -0.143482f, 0.635026f, + 0.627989f, 0.197799f, 0.432245f, 0.754035f, 0.462362f, 0.434440f, 0.645105f, 
0.124847f, 0.210367f, + -0.003652f, -0.157717f, 0.177076f, 0.365097f, 0.265550f, -0.131768f, -0.109161f, 0.376140f, -0.159695f, + 1.104433f, -0.150630f, 0.424806f, 0.560069f, -0.178628f, 0.062287f, 0.277207f, -0.075345f, 0.355371f, + -0.026496f, 0.217098f, 0.137113f, -0.079923f, -0.002627f, 0.999099f, 0.285733f, 0.561619f, -0.199164f, + -0.017134f, 0.578059f, -0.096513f, 0.286938f, 0.180099f, 0.645301f, 0.812592f, 0.009046f, 0.589835f, + 0.624253f, 0.793519f, -0.249415f, -0.203734f, 0.407439f, 1.011931f, 1.116821f, 0.542856f, 0.705290f, + 0.739027f, 0.623249f, 0.522368f, 0.353579f, 0.995898f, 1.767281f, -0.081556f, -0.076679f, 0.835456f, + 0.079021f, 0.785873f, -0.053940f, 0.252067f, 0.266335f, 0.770431f, 0.254018f, 0.352326f, 0.479874f, + 0.533761f, 1.430939f, 0.689173f, 0.691594f, 0.372724f, 0.704632f, 0.686755f, 0.588283f, 0.285072f, + 0.191627f, 0.745949f, 0.752707f, 0.417238f, 0.587475f, 0.127896f, 1.346089f, 0.149638f, -0.116884f, + 0.677536f, 0.034505f, -0.174461f, -0.128073f, 0.290052f, 0.446063f, 0.521620f, 0.122747f, 0.075711f, + 0.025843f, -0.149042f, 0.070983f, 0.597205f, 0.074123f, 0.968585f, -0.163160f, 0.082720f, 0.586964f, + 0.098779f, 0.137334f, 1.757106f, -0.276652f, 1.146234f, 0.434016f, 0.136440f, 0.269671f, 0.128756f, + -0.117812f, 0.533588f, 1.186146f, 0.443822f, 2.062000f, 0.421238f, 0.683523f}; // Test float16, without activation int min_cuda_architecture = 530; @@ -408,20 +411,20 @@ TEST(GroupNormTest, GroupNorm_128) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - // Test float16, with activation - if (HasCudaEnvironment(min_cuda_architecture)) { + // Test float32, with activation + if (HasCudaEnvironment(0)) { OpTester test("GroupNorm", 1, onnxruntime::kMSDomain); test.AddAttribute("epsilon", 1e-05f); test.AddAttribute("groups", 32); test.AddAttribute("activation", 1); - test.AddInput("X", dims, ToFloat16(input_data)); + test.AddInput("X", dims, input_data); test.AddInput("gamma", {C}, gamma_data); test.AddInput("beta", {C}, beta_data); constexpr float rel_error = 0.0f; - constexpr float abs_error = 0.02f; - test.AddOutput("Y", dims, ToFloat16(swish_data), false, rel_error, abs_error); + constexpr float abs_error = 0.01f; + test.AddOutput("Y", dims, swish_data, false, rel_error, abs_error); std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); From a9ebeec37c77053b781959f4f43b0723d7a639f5 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 29 Jan 2023 18:22:09 +0000 Subject: [PATCH 12/27] Fuse Bias and SplitGelu --- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 8 +- .../{split_gelu.cc => bias_split_gelu.cc} | 24 +++-- .../{split_gelu.h => bias_split_gelu.h} | 4 +- .../cuda/diffusion/bias_split_gelu_impl.cu | 87 +++++++++++++++++++ ...lit_gelu_impl.h => bias_split_gelu_impl.h} | 3 +- .../cuda/diffusion/split_gelu_impl.cu | 81 ----------------- .../core/graph/contrib_ops/diffusion_defs.cc | 29 ++++--- .../python/tools/symbolic_shape_infer.py | 11 +-- ...n_splitgelu.py => fusion_biassplitgelu.py} | 47 +++++----- .../tools/transformers/fusion_group_norm.py | 14 +-- .../tools/transformers/onnx_model_unet.py | 6 +- 11 files changed, 176 insertions(+), 138 deletions(-) rename onnxruntime/contrib_ops/cuda/diffusion/{split_gelu.cc => bias_split_gelu.cc} (68%) rename onnxruntime/contrib_ops/cuda/diffusion/{split_gelu.h => bias_split_gelu.h} (82%) create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu rename 
onnxruntime/contrib_ops/cuda/diffusion/{split_gelu_impl.h => bias_split_gelu_impl.h} (69%) delete mode 100644 onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.cu rename onnxruntime/python/tools/transformers/{fusion_splitgelu.py => fusion_biassplitgelu.py} (63%) diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 4e39ad0efd44f..91e9043978a9f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -19,6 +19,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); @@ -89,8 +91,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SplitGelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SplitGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu); @@ -148,6 +148,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -214,8 +216,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc similarity index 68% rename from onnxruntime/contrib_ops/cuda/diffusion/split_gelu.cc rename to onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc index d5b36b2fe990a..adfdca9e972f5 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc @@ -2,8 +2,8 @@ // Licensed under the MIT License. 
#include "core/providers/cuda/cuda_common.h" -#include "contrib_ops/cuda/diffusion/split_gelu.h" -#include "contrib_ops/cuda/diffusion/split_gelu_impl.h" +#include "contrib_ops/cuda/diffusion/bias_split_gelu.h" +#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" namespace onnxruntime { namespace contrib { @@ -11,14 +11,14 @@ namespace cuda { #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ - SplitGelu, \ + BiasSplitGelu, \ kMSDomain, \ 1, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SplitGelu); + BiasSplitGelu); REGISTER_KERNEL_TYPED(MLFloat16); REGISTER_KERNEL_TYPED(float); @@ -26,11 +26,11 @@ REGISTER_KERNEL_TYPED(float); using namespace ONNX_NAMESPACE; template -SplitGelu::SplitGelu(const OpKernelInfo& op_info) : CudaKernel(op_info) { +BiasSplitGelu::BiasSplitGelu(const OpKernelInfo& op_info) : CudaKernel(op_info) { } template -Status SplitGelu::ComputeInternal(OpKernelContext* context) const { +Status BiasSplitGelu::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const auto& input_dims = input->Shape().GetDims(); @@ -39,6 +39,17 @@ Status SplitGelu::ComputeInternal(OpKernelContext* context) const { "input is expected to have 3 dimensions, got ", input_dims.size()); } + const Tensor* bias = context->Input(0); + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimensions, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of input and bias are not the same"); + } + TensorShapeVector output_shape = input->Shape().AsShapeVector(); output_shape[2] = input_dims[2] / 2; Tensor* output = context->Output(0, output_shape); @@ -48,6 +59,7 @@ Status SplitGelu::ComputeInternal(OpKernelContext* context) const { const int32_t half_hidden_size = static_cast(input_dims[2] / 2); LaunchSplitGeluKernel(Stream(context), grid_size, half_hidden_size, reinterpret_cast(input->Data()), + reinterpret_cast(bias->Data()), reinterpret_cast(output->MutableData())); CUDA_RETURN_IF_ERROR(cudaPeekAtLastError()); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h similarity index 82% rename from onnxruntime/contrib_ops/cuda/diffusion/split_gelu.h rename to onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h index 547eb65b3ad01..feec45600bbce 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.h @@ -12,9 +12,9 @@ namespace cuda { using namespace onnxruntime::cuda; template -class SplitGelu final : public CudaKernel { +class BiasSplitGelu final : public CudaKernel { public: - SplitGelu(const OpKernelInfo& op_kernel_info); + BiasSplitGelu(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* context) const override; }; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu new file mode 100644 index 0000000000000..9399f98e23dff --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// The CUDA kernel is modified from SplitGelu plugin of TensorRT 8.5. 
+/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) { + int32_t index_input = blockIdx.x * HHS * 2 + threadIdx.x; + int32_t index_output = blockIdx.x * HHS + threadIdx.x; + int32_t index_bias = threadIdx.x; + +#pragma unroll + for (int32_t i = 0; i < HHS / TPB; ++i) { + auto value_left = static_cast(input[index_input] + bias[index_bias]); + auto value_right = static_cast(input[index_input + HHS] + bias[index_bias + HHS]); + + // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0) + float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f); + float result = value_left * gelu_right; + output[index_output] = static_cast(result); + index_input += TPB; + index_output += TPB; + index_bias += TPB; + } + return; +} + +template +void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + T const* input, T const* bias, T* output) { + constexpr int32_t TPB = 256; // thread per block + switch (half_hidden_size) { + case 1280: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 2560: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 5120: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + default: + ORT_NOT_IMPLEMENTED("Not implemented"); + } +} + +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); + +template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + float const* input, float const* bias, float* output); + +template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + half const* input, half const* bias, half* output); +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h similarity index 69% rename from onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.h rename to onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h index d83e59b595b64..aebbbc5a70956 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.h +++ 
b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h @@ -13,7 +13,8 @@ namespace contrib { namespace cuda { template -void LaunchSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, T const* input, T* output); +void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, + T const* input, T const* bias, T* output); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.cu deleted file mode 100644 index dad3ff4c243ea..0000000000000 --- a/onnxruntime/contrib_ops/cuda/diffusion/split_gelu_impl.cu +++ /dev/null @@ -1,81 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// The CUDA kernel is modified from SplitGelu plugin of TensorRT 8.5 -#include -#include -#include -#include "core/providers/cuda/cuda_common.h" -#include "contrib_ops/cuda/diffusion/split_gelu_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -template -__global__ void splitGeluKernel(T const* input, T* output) { - int32_t index_input = blockIdx.x * HHS * 2 + threadIdx.x; - int32_t index_output = blockIdx.x * HHS + threadIdx.x; - -#pragma unroll - for (int32_t i = 0; i < HHS / TPB; ++i) { - auto value_left = static_cast(input[index_input]); - auto value_right = static_cast(input[index_input + HHS]); - - // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / 1.41421356237) + 1.0) - float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f); - float result = value_left * gelu_right; - output[index_output] = static_cast(result); - index_input += TPB; - index_output += TPB; - } - return; -} - -template -void LaunchSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, T const* input, T* output) { - constexpr int32_t TPB = 256; // thread per block - switch (half_hidden_size) { - case 1280: - (splitGeluKernel)<<>>(input, output); - break; - case 2560: - (splitGeluKernel)<<>>(input, output); - break; - case 5120: - (splitGeluKernel)<<>>(input, output); - break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); - } -} - -template __global__ void splitGeluKernel(float const*, float*); -template __global__ void splitGeluKernel(float const*, float*); -template __global__ void splitGeluKernel(float const*, float*); -template __global__ void splitGeluKernel(half const*, half*); -template __global__ void splitGeluKernel(half const*, half*); -template __global__ void splitGeluKernel(half const*, half*); - -template void LaunchSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, - float const* input, float* output); - -template void LaunchSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, - half const* input, half* output); -} // 
namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index ce2ba8ab42f95..ab07c05e17fd4 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -64,37 +64,46 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("M", {"tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); -constexpr const char* SplitGelu_ver1_doc = R"DOC( -A fusion used in diffusion model that hidden state is sliced into two parts, one part applied Gelu actication, then these -two parts are multiplied. +constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( +A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left +tensor multiplies the Gelu activation result of right tensor. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( SplitGelu, 1, OpSchema() - .SetDoc(SplitGelu_ver1_doc) + .SetDoc(BiasSplitGelu_ver1_doc) .Input(0, "X", - "Input data tensor. Dimensions are (N, H*W, D), where N is the batch size, H and W are the height and width of the data, and D is hidden dimension", + "Input tensor. Dimensions are (N, S, D), where N is the batch size, S are image size, and D is hidden dimension", + "T") + .Input(0, + "bias", + "Bias tensor. Dimensions are (D), where D is the same hidden dimension as input tensor", "T") .Output(0, "Y", - "The output tensor with dimensions (N, H*W, D/2)", + "The output tensor with dimensions (N, S, D/2)", "T") - .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to half tensors.") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X and output Y types to float tensors.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (hasInputShape(ctx, 0)) { + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { auto& input_shape = getInputShape(ctx, 0); if (input_shape.dim().size() != 3) { fail_shape_inference("input shall be 3 dimensions"); } + auto& bias_shape = getInputShape(ctx, 1); + if (bias_shape.dim().size() != 1) { + fail_shape_inference("bias shall be 1 dimension"); + } + TensorShapeProto output_shape; *output_shape.add_dim() = input_shape.dim(0); *output_shape.add_dim() = input_shape.dim(1); - if (input_shape.dim(2).has_dim_value()) { - output_shape.add_dim()->set_dim_value(input_shape.dim(2).dim_value() / 2); + if (bias_shape.dim(0).has_dim_value()) { + output_shape.add_dim()->set_dim_value(bias_shape.dim(0).dim_value() / 2); } else { output_shape.add_dim(); } diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index b0c07b762e092..898225212459f 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -201,7 +201,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, "GroupNorm": self._infer_GroupNorm, - "SplitGelu": self._infer_SplitGelu, + "BiasSplitGelu": self._infer_BiasSplitGelu, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -437,7 +437,7 @@ def _onnx_infer_single_node(self, node): "PythonOp", 
"MultiHeadAttention", "GroupNorm", - "SplitGelu", + "BiasSplitGelu", ] if not skip_infer: @@ -2063,11 +2063,12 @@ def _infer_SkipLayerNormalization(self, node): def _infer_GroupNorm(self, node): self._propagate_shape_and_type(node) - def _infer_SplitGelu(self, node): + def _infer_BiasSplitGelu(self, node): input_shape = self._get_shape(node, 0) - if input_shape: + bias_shape = self._get_shape(node, 1) + if input_shape and bias_shape: output_shape = input_shape - output_shape[2] = int(input_shape[2] / 2) + output_shape[2] = int(bias_shape[0] / 2) vi = self.known_vi_[node.output[0]] output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) diff --git a/onnxruntime/python/tools/transformers/fusion_splitgelu.py b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py similarity index 63% rename from onnxruntime/python/tools/transformers/fusion_splitgelu.py rename to onnxruntime/python/tools/transformers/fusion_biassplitgelu.py index 6e2e086717aed..1a93f609e2c1e 100644 --- a/onnxruntime/python/tools/transformers/fusion_splitgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py @@ -12,23 +12,23 @@ logger = getLogger(__name__) -class FusionSplitGelu(Fusion): +class FusionBiasSplitGelu(Fusion): def __init__(self, model: OnnxModel): - super().__init__(model, "SplitGelu", "Gelu") + super().__init__(model, "BiasSplitGelu", "Gelu") def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): """ - [root] --------------------> Slice ---------------> Mul --> - | ^ ^ - | | | - +----------------------------+---Slice --> Gelu---+ - | | ^ - | |-----| - | | | - | Mul Mul - | ^ ^ - v | | - Shape ---> Gather --> Add --> Div --+ + [root] --->Add --------------------> Slice ---------------> Mul --> + | ^ ^ + | | | + +----------------------------+---Slice --> Gelu---+ + | | ^ + | |-----| + | | | + | Mul Mul + | ^ ^ + v | | + Shape ---> Gather --> Add --> Div --+ """ if gelu_node.output[0] not in input_name_to_nodes: return @@ -44,20 +44,23 @@ def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3: return - subgraph_input = slice_before_gelu.input[0] + add_output = slice_before_gelu.input[0] start_index_nodes = self.model.match_parent_path( slice_before_gelu, - ["Div", "Add", "Gather", "Shape"], - [1, 0, 0, 0], + ["Div", "Add", "Gather", "Shape", "Add"], + [1, 0, 0, 0, 0], output_name_to_node, # Mul(1) is optional ) if start_index_nodes is None: start_index_nodes = self.model.match_parent_path( - slice_before_gelu, ["Mul", "Div", "Add", "Gather", "Shape"], [1, 0, 0, 0, 0], output_name_to_node + slice_before_gelu, + ["Mul", "Div", "Add", "Gather", "Shape", "Add"], + [1, 0, 0, 0, 0, 0], + output_name_to_node, ) - if start_index_nodes is None or start_index_nodes[-1].input[0] != subgraph_input: + if start_index_nodes is None or start_index_nodes[-2].input[0] != add_output: return end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node) @@ -87,12 +90,14 @@ def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): if not self.model.is_safe_to_fuse_nodes( subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node ): - logger.info("Skip fuse SplitGelu since it is not safe to fuse the subgraph.") + logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.") return 
self.nodes_to_remove.extend(subgraph_nodes) - node_name = self.model.create_node_name("SplitGelu", name_prefix="SplitGelu") - fused_node = helper.make_node("SplitGelu", inputs=[subgraph_input], outputs=[subgraph_output], name=node_name) + node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu") + fused_node = helper.make_node( + "BiasSplitGelu", inputs=[start_index_nodes[-1].input[0]], outputs=[subgraph_output], name=node_name + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[node_name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index d676a53492af2..3feb3b8a00289 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -156,9 +156,11 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): # NCHW to NHWC transpose_input = helper.make_node( - "Transpose", [input], [input + "_NHWC"], + "Transpose", + [input], + [input + "_NHWC"], name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"), - perm=[0, 2, 3, 1] + perm=[0, 2, 3, 1], ) new_node = helper.make_node( @@ -175,9 +177,11 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): # NHWC to NCHW transpose_output = helper.make_node( - "Transpose", [output + "_NHWC"], [output], - name = self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"), - perm=[0, 3, 1, 2] + "Transpose", + [output + "_NHWC"], + [output], + name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"), + perm=[0, 3, 1, 2], ) self.nodes_to_add.append(new_node) diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 19657b4345255..3f45efc5cde12 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -7,9 +7,9 @@ from typing import Optional from fusion_attention_unet import FusionAttentionUnet +from fusion_biassplitgelu import FusionBiasSplitGelu from fusion_group_norm import FusionGroupNorm from fusion_options import FusionOptions -from fusion_splitgelu import FusionSplitGelu from onnx import ModelProto from onnx_model_bert import BertOnnxModel @@ -59,8 +59,8 @@ def optimize(self, options: Optional[FusionOptions] = None): group_norm_fusion.apply() if (options is None) or options.enable_splitgelu: - split_gelu_fusion = FusionSplitGelu(self) - split_gelu_fusion.apply() + bias_split_gelu_fusion = FusionBiasSplitGelu(self) + bias_split_gelu_fusion.apply() if (options is None) or options.enable_attention: self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False) From 53a539f1ac3b3cc213de47e14c432612a886d715 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 29 Jan 2023 20:32:47 +0000 Subject: [PATCH 13/27] update bias split gelu --- .../contrib_ops/cuda/diffusion/bias_split_gelu.cc | 10 +++++----- onnxruntime/core/graph/contrib_ops/diffusion_defs.cc | 4 ++-- onnxruntime/core/graph/contrib_ops/ms_opset.h | 4 ++-- .../python/tools/transformers/fusion_biassplitgelu.py | 7 ++++++- .../python/tools/transformers/fusion_options.py | 2 +- .../python/tools/transformers/onnx_model_unet.py | 2 +- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git 
a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc index adfdca9e972f5..265cbb79e1801 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc @@ -39,7 +39,7 @@ Status BiasSplitGelu::ComputeInternal(OpKernelContext* context) const { "input is expected to have 3 dimensions, got ", input_dims.size()); } - const Tensor* bias = context->Input(0); + const Tensor* bias = context->Input(1); const auto& bias_dims = bias->Shape().GetDims(); if (bias_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -57,10 +57,10 @@ Status BiasSplitGelu::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; const int32_t grid_size = static_cast(input_dims[0] * input_dims[1]); const int32_t half_hidden_size = static_cast(input_dims[2] / 2); - LaunchSplitGeluKernel(Stream(context), grid_size, half_hidden_size, - reinterpret_cast(input->Data()), - reinterpret_cast(bias->Data()), - reinterpret_cast(output->MutableData())); + LaunchBiasSplitGeluKernel(Stream(context), grid_size, half_hidden_size, + reinterpret_cast(input->Data()), + reinterpret_cast(bias->Data()), + reinterpret_cast(output->MutableData())); CUDA_RETURN_IF_ERROR(cudaPeekAtLastError()); return Status::OK(); diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index ab07c05e17fd4..f02076a771408 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -70,14 +70,14 @@ tensor multiplies the Gelu activation result of right tensor. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( - SplitGelu, 1, + BiasSplitGelu, 1, OpSchema() .SetDoc(BiasSplitGelu_ver1_doc) .Input(0, "X", "Input tensor. Dimensions are (N, S, D), where N is the batch size, S are image size, and D is hidden dimension", "T") - .Input(0, + .Input(1, "bias", "Bias tensor. 
Dimensions are (D), where D is the same hidden dimension as input tensor", "T") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 4ac58823c1050..a511d01fe1624 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -49,6 +49,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BeamSearch); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasDropout); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BitmaskBiasDropout); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasGelu); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSplitGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BiasSoftmax); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, BifurcationDetector); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CDist); @@ -88,7 +89,6 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SplitGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TransposeMatMul); @@ -137,6 +137,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -178,7 +179,6 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); - fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py index 1a93f609e2c1e..17e57223b96d3 100644 --- a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py @@ -93,10 +93,15 @@ def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.") return + add_node = start_index_nodes[-1] + bias_index, _bias = self.model.get_constant_input(add_node) self.nodes_to_remove.extend(subgraph_nodes) node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu") fused_node = helper.make_node( - "BiasSplitGelu", inputs=[start_index_nodes[-1].input[0]], outputs=[subgraph_output], name=node_name + "BiasSplitGelu", + inputs=[add_node.input[1 - bias_index], add_node.input[bias_index]], + outputs=[subgraph_output], + name=node_name, ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 17306762c118e..16802a59f9a8c 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -37,7 +37,7 @@ def __init__(self, model_type): self.enable_shape_inference = True self.enable_gemm_fast_gelu = False self.enable_group_norm = model_type == "unet" - self.enable_splitgelu = model_type == "unet" + self.enable_bias_splitgelu = model_type == "unet" 
self.attention_mask_format = AttentionMaskFormat.AttentionMask def use_raw_attention_mask(self, use_raw_mask=True): diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 3f45efc5cde12..5464665b18d0d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -58,7 +58,7 @@ def optimize(self, options: Optional[FusionOptions] = None): group_norm_fusion = FusionGroupNorm(self) group_norm_fusion.apply() - if (options is None) or options.enable_splitgelu: + if (options is None) or options.enable_bias_splitgelu: bias_split_gelu_fusion = FusionBiasSplitGelu(self) bias_split_gelu_fusion.apply() From a0c4957bef7a2a5d5b6e599019c139c156aead2f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 30 Jan 2023 15:40:43 +0000 Subject: [PATCH 14/27] update GroupNorm doc --- onnxruntime/core/graph/contrib_ops/diffusion_defs.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index f02076a771408..14a267357371d 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -46,7 +46,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::INT) .Input(0, "X", - "Input data tensor. Dimensions are (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and width of the data", + "Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data", "T") .Input(1, "gamma", From 82383dcb417f64273fef22d64e44a4912cd38e42 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 31 Jan 2023 07:01:14 +0000 Subject: [PATCH 15/27] packed kv in cross attention --- docs/ContribOperators.md | 98 +++++++++++++- docs/OperatorKernels.md | 5 +- .../cpu/bert/multihead_attention_helper.h | 117 ++++++++++------ .../cpu/transformers/generation_shared.h | 22 ++- .../cuda/bert/add_bias_transpose.cu | 128 +++++++++++------- .../cuda/bert/add_bias_transpose.h | 4 + .../contrib_ops/cuda/bert/attention_impl.cu | 79 ++++++++--- .../cuda/bert/multihead_attention.cc | 6 +- .../core/graph/contrib_ops/bert_defs.cc | 41 ++++-- .../python/tools/symbolic_shape_infer.py | 18 ++- .../transformers/fusion_attention_unet.py | 90 ++++++++++-- .../tools/transformers/fusion_options.py | 49 ++++++- .../stable_diffusion/optimize_pipeline.py | 32 +++-- .../tools/transformers/onnx_model_unet.py | 5 +- .../transformers/test_attention_fusion.py | 21 ++- 15 files changed, 532 insertions(+), 183 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1e6d46963cd21..8cd6d4c9e26f1 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -9,6 +9,7 @@ Do not modify directly.* * com.microsoft.BiasDropout * com.microsoft.BiasGelu * com.microsoft.BiasSoftmax + * com.microsoft.BiasSplitGelu * com.microsoft.BifurcationDetector * com.microsoft.BitmaskBiasDropout * com.microsoft.BitmaskDropout @@ -34,6 +35,7 @@ Do not modify directly.* * com.microsoft.GemmFastGelu * com.microsoft.GreedySearch * com.microsoft.GridSample + * com.microsoft.GroupNorm * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention @@ -590,6 +592,39 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.BiasSplitGelu** + + A 
fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left
+  tensor multiplies the Gelu activation result of right tensor.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Inputs
+
+<dl>
+<dt><tt>X</tt> : T</dt>
+<dd>Input tensor. Dimensions are (N, S, D), where N is the batch size, S are image size, and D is hidden dimension</dd>
+<dt><tt>bias</tt> : T</dt>
+<dd>Bias tensor. Dimensions are (D), where D is the same hidden dimension as input tensor</dd>
+</dl>
+
+#### Outputs
+
+<dl>
+<dt><tt>Y</tt> : T</dt>
+<dd>The output tensor with dimensions (N, S, D/2)</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
+<dd>Constrain input X and output Y types to float tensors.</dd>
+</dl>
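The semantics of the new op are easy to cross-check outside CUDA: add the bias, split the last dimension in half, and multiply the left half by Gelu of the right half. A minimal NumPy sketch for review purposes (the function name is illustrative and `scipy.special.erf` is assumed to be available; this is not part of the patch):

```python
import numpy as np
from scipy.special import erf  # exact erf, matching the erff() used by the CUDA kernel

def bias_split_gelu_reference(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
    # x: (batch, sequence, D), bias: (D,) -> output: (batch, sequence, D/2)
    y = x + bias                                  # bias broadcasts over the last axis
    d = y.shape[-1] // 2
    left, right = y[..., :d], y[..., d:]
    gelu_right = right * 0.5 * (erf(right / np.sqrt(2.0)) + 1.0)
    return left * gelu_right                      # Gelu is applied to the right half only
```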
+
+
 ### **com.microsoft.BifurcationDetector**
 
   Component for aggressive decoding. Find the bifurcation index of predicted tokens, between source tokens,
@@ -1811,6 +1846,61 @@ This version of the operator has been available since version 1 of the 'com.micr
+
+### **com.microsoft.GroupNorm**
+
+  Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494).
+
+  This operator transforms input according to
+    y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
+
+  The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard-deviation are calculated separately over each group.
+  The weight and bias are per-channel affine transform parameter vectors of size num_channels.
+
+  The activation attribute can be used to enable activation after group normalization.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+<dl>
+<dt><tt>activation</tt> : int (required)</dt>
+<dd>Activation after group normalization: 0 for None, 1 for Swish</dd>
+<dt><tt>epsilon</tt> : float</dt>
+<dd>The epsilon value to use to avoid division by zero</dd>
+<dt><tt>groups</tt> : int (required)</dt>
+<dd>The number of groups of channels. It should be a divisor of the number of channels C</dd>
+</dl>
+
+#### Inputs
+
+<dl>
+<dt><tt>X</tt> : T</dt>
+<dd>Input data tensor. Dimensions are (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width of the data</dd>
+<dt><tt>gamma</tt> : M</dt>
+<dd>1D gamma tensor for normalization with shape (C), where C is number of channels</dd>
+<dt><tt>beta</tt> : M</dt>
+<dd>1D beta tensor for normalization with shape (C), where C is number of channels</dd>
+</dl>
+
+#### Outputs
+
+<dl>
+<dt><tt>Y</tt> : T</dt>
+<dd>The output tensor of the same shape as X</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
+<dd>Constrain input X and output Y types to float tensors.</dd>
+<dt><tt>M</tt> : tensor(float)</dt>
+<dd>Constrain gamma and beta to float tensors.</dd>
+</dl>
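To make the NHWC layout and the activation attribute concrete, here is a hedged NumPy sketch of what the operator computes (the function name and the `swish` keyword are illustrative, not the operator's API):

```python
import numpy as np

def group_norm_nhwc_reference(x, gamma, beta, groups, epsilon=1e-5, swish=False):
    # x: (N, H, W, C) with C divisible by groups; gamma, beta: (C,)
    n, h, w, c = x.shape
    g = x.reshape(n, h, w, groups, c // groups)
    mean = g.mean(axis=(1, 2, 4), keepdims=True)   # statistics per (batch, group)
    var = g.var(axis=(1, 2, 4), keepdims=True)
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)
    y = y * gamma + beta                           # per-channel affine transform
    if swish:                                      # activation attribute == 1
        y = y * (1.0 / (1.0 + np.exp(-y)))         # Swish(y) = y * sigmoid(y)
    return y
```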
+ + ### **com.microsoft.Inverse** #### Version @@ -2132,16 +2222,16 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Number of attention heads</dd>
 </dl>
 
-#### Inputs (4 - 5)
+#### Inputs (2 - 5)
 
 <dl>
 <dt><tt>query</tt> : T</dt>
 <dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
 <dt><tt>key</tt> : T</dt>
-<dd>Key with shape (batch_size, kv_sequence_length, hidden_size)</dd>
-<dt><tt>value</tt> : T</dt>
+<dd>Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)</dd>
+<dt><tt>value</tt> (optional) : T</dt>
 <dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
-<dt><tt>bias</tt> : T</dt>
+<dt><tt>bias</tt> (optional) : T</dt>
 <dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
 <dt><tt>key_padding_mask</tt> (optional) : M</dt>
 <dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
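The packed KV layout stores key and value interleaved per head. A small NumPy sketch of the unpacking implied by the shape above (illustrative only, not the runtime's implementation):

```python
import numpy as np

def unpack_kv(packed_kv: np.ndarray):
    # packed_kv: (batch, kv_sequence_length, num_heads, 2, head_size)
    assert packed_kv.ndim == 5 and packed_kv.shape[3] == 2
    key = packed_kv[..., 0, :]    # (batch, kv_sequence_length, num_heads, head_size)
    value = packed_kv[..., 1, :]
    return key, value
```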
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index ad571dacb20d7..765d16c0f23cf 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -175,7 +175,8 @@ Do not modify directly.*
 |||[11, 12]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[1, 10]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |LpNormalization|*in* input:**T**<br/> *out* output:**T**|1+|**T** = tensor(double), tensor(float)|
-|LpPool|*in* X:**T**<br/> *out* Y:**T**|11+|**T** = tensor(float)|
+|LpPool|*in* X:**T**<br/> *out* Y:**T**|18+|**T** = tensor(float)|
+|||[11, 17]|**T** = tensor(float)|
 |||[2, 10]|**T** = tensor(float)|
 |MatMul|*in* A:**T**<br/> *in* B:**T**<br/> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
@@ -789,6 +790,7 @@ Do not modify directly.*
 |BiasDropout|*in* data:**T**<br/> *in* bias:**T**<br/> *in* residual:**T**<br/> *in* ratio:**T1**<br/> *in* training_mode:**T2**<br/> *out* output:**T**<br/> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
 |BiasGelu|*in* A:**T**<br/> *in* B:**T**<br/> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |BiasSoftmax|*in* data:**T**<br/> *in* bias:**T**<br/> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|BiasSplitGelu|*in* X:**T**<br/> *in* bias:**T**<br/> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BitmaskBiasDropout|*in* data:**T**<br/> *in* bias:**T**<br/> *in* residual:**T**<br/> *in* ratio:**T1**<br/> *in* training_mode:**T2**<br/> *out* output:**T**<br/> *out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)<br/> **T3** = tensor(uint32)|
 |BitmaskDropout|*in* data:**T**<br/> *in* ratio:**T1**<br/> *in* training_mode:**T2**<br/> *out* output:**T**<br/> *out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)<br/> **T3** = tensor(uint32)|
 |ComplexMul|*in* A:**T**<br/> *in* B:**T**<br/> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -804,6 +806,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**<br/> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |GreedySearch|*in* input_ids:**I**<br/> *in* max_length:**I**<br/> *in* min_length:**I**<br/> *in* repetition_penalty:**T**<br/> *in* vocab_mask:**I**<br/> *in* prefix_vocab_mask:**I**<br/> *in* attention_mask:**I**<br/> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br/> *in* Grid:**T1**<br/> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
+|GroupNorm|*in* X:**T**<br/> *in* gamma:**M**<br/> *in* beta:**M**<br/> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**<br/> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br/> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br/> *in* weight:**T**<br/> *in* bias:**T**<br/> *in* mask:**T**<br/> *in* global_weight:**T**<br/> *in* global_bias:**T**<br/> *in* global:**G**<br/>
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index ce109a83720b9..8c3af05972c95 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -21,11 +21,15 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, int max_threads_per_block) { - // query (Q) : (B, S, D) - // key (K) : (B, L, D) - // value (V) : (B, L, D_v) - // bias (Q/K/V) : (D + D + D_v) - // key_padding_mask (K/V) : (B, L) or (L) + // query (Q) : (B, S, D) + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // bias (Q/K/V) : (D + D + D_v) + // key_padding_mask (K/V) : (B) or (B, L) or None + // When packed kv is used: + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None const auto& query_dims = query->Shape().GetDims(); if (query_dims.size() != 3) { @@ -34,15 +38,50 @@ Status CheckInputs(const T* query, } const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + if (key_dims.size() != 3 && key_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ", key_dims.size()); } + if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ", - bias_dims.size()); + int batch_size = static_cast(query_dims[0]); + int sequence_length = static_cast(query_dims[1]); + int hidden_size = static_cast(query_dims[2]); + int head_size = static_cast(hidden_size) / num_heads; + int kv_sequence_length = static_cast(key_dims[1]); + + if (key_dims.size() == 3) { + if (key_dims[2] != query_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + } + } else // if (key_dims.size() == 5) + { + if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); + } + } + + if (bias != nullptr) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ", + bias_dims.size()); + } + + // Currently, bias is not allowed for packed KV. This constraint can be removed later. + // Here we assume that fusion tool will not include bias for packed KV. + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. 
"); + } } AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; @@ -61,47 +100,39 @@ Status CheckInputs(const T* query, } } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - int64_t batch_size = query_dims[0]; - int64_t sequence_length = query_dims[1]; - int64_t kv_sequence_length = key_dims[1]; - int64_t q_hidden_size = query_dims[2]; - int64_t v_hidden_size = 0; - - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } + int v_hidden_size = hidden_size; + if (value != nullptr) { + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } - if (key_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have same same dim 1 (sequence_length)"); + if (key_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)"); + } + v_hidden_size = static_cast(value_dims[2]); } - v_hidden_size = value_dims[2]; if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); - output_parameters->batch_size = static_cast(batch_size); - output_parameters->sequence_length = static_cast(sequence_length); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; output_parameters->past_sequence_length = 0; - output_parameters->kv_sequence_length = static_cast(kv_sequence_length); - output_parameters->total_sequence_length = static_cast(kv_sequence_length); + output_parameters->kv_sequence_length = kv_sequence_length; + output_parameters->total_sequence_length = kv_sequence_length; output_parameters->max_sequence_length = 0; output_parameters->input_hidden_size = 0; - output_parameters->hidden_size = static_cast(q_hidden_size); - output_parameters->v_hidden_size = static_cast(v_hidden_size); - output_parameters->head_size = static_cast(q_hidden_size) / num_heads; - output_parameters->v_head_size = static_cast(v_hidden_size) / num_heads; + output_parameters->hidden_size = hidden_size; + output_parameters->v_hidden_size = v_hidden_size; + output_parameters->head_size = hidden_size / num_heads; + output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; output_parameters->is_unidirectional = false; output_parameters->past_present_share_buffer = false; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 7b641cbef046a..6b092d3e99f4e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -53,14 +53,14 @@ struct IBeamSearchCpuState { template struct IGreedySearchState { - gsl::span 
sequences_space; // shape (2, batch_size, max_length) - gsl::span sequence_lengths; // shape (batch_size) - gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. - gsl::span eos_meet; // shape (batch_size) - gsl::span next_token_scores; // shape (batch_size, vocab_size) - gsl::span next_tokens; // shape (batch_size) - gsl::span temp_topk_scores_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1 (GPU only) - gsl::span temp_topk_tokens_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1(GPU only) + gsl::span sequences_space; // shape (2, batch_size, max_length) + gsl::span sequence_lengths; // shape (batch_size) + gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. + gsl::span eos_meet; // shape (batch_size) + gsl::span next_token_scores; // shape (batch_size, vocab_size) + gsl::span next_tokens; // shape (batch_size) + gsl::span temp_topk_scores_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1 (GPU only) + gsl::span temp_topk_tokens_buffer; // shape (batch_size, parts_of_vocab), temp buffer for topk stage 1(GPU only) gsl::span topk_scores_buffer; // shape (batch_size), output buffer for topk stage 2 (GPU only) gsl::span topk_tokens_buffer; // shape (batch_size), output buffer for topk stage 2 (GPU only) }; @@ -163,15 +163,14 @@ struct IGenerationParameters { bool custom_sampling = false; }; - #ifndef NDEBUG -//#define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc) +// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc) #endif #ifdef DEBUG_GENERATION #define DUMP_TENSOR_LEVEL 2 #else -#define DUMP_TENSOR_LEVEL 0 // change it to 0 if want to disable dumping for code not in generation. +#define DUMP_TENSOR_LEVEL 0 // change it to 0 if want to disable dumping for code not in generation. #endif #if DUMP_TENSOR_LEVEL > 0 @@ -187,7 +186,6 @@ struct IGenerationParameters { #define DUMP_TENSOR_D(...) #endif - class IConsoleDumper { public: IConsoleDumper() : is_enabled_(true) {} diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index b7eebb9d48785..e86736726c224 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -366,6 +366,39 @@ __global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* outp } } +template +__global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output) { + // Format 4 to unpack TRT packed input format for memory efficient attention. 
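+ // Grid: (sequence_length, batch_size, M); block: (head_size, num_heads).
+ // Each block therefore unpacks all heads of one (s, b, m) slice, adding bias when provided.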
+ // Input: BxSxNxMxH + // Output: MxBxSxNxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const int NH = num_heads * head_size; + const int NHS = NH * sequence_length; + + int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M; + const int out_offset = n * head_size + s * NH + b * NHS + m * NHS * batch_size; + + const int h = threadIdx.x; + if (h < head_size) { + if (biases != nullptr) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + } else { + output[out_offset + h] = input[in_offset + h]; + } + } +} + template __global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) { // Format 3 for cutlass memory efficient attention @@ -481,7 +514,6 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const } } - template void InvokeAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, @@ -506,7 +538,9 @@ void InvokeAddBiasTranspose( ORT_ENFORCE(total_matrix_count == 3); AddBiasTransposeCutlass<<>>(input, biases, output, v_head_size); } - } else { // format == 0 + } else if (format == 4) { // format == 4 + AddBiasUnpack<<>>(total_matrix_count, input, biases, output); + } else { // format == 0 AddBiasTranspose<<>>(input, biases, output); } } else { @@ -528,6 +562,8 @@ void InvokeAddBiasTranspose( } else { ORT_THROW("AddBiasTranspose (format 3) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size"); } + } else if (format == 4) { // format == 4 + ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block"); } else { // format 0 AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } @@ -551,7 +587,7 @@ void LaunchAddBiasTranspose( InvokeAddBiasTranspose(stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, H, input2, biases2, output2, qkv_add_bias2, H_v, total_matrix_count); - } else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) { + } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) { const int H = qk_head_size / 2; const int H_v = v_head_size / 2; const half2* input2 = reinterpret_cast(input); @@ -610,7 +646,6 @@ void InvokeAddBiasTransposeTrt( const int batch_size, const int sequence_length, const int num_heads, const int head_size, const T* biases, const T* query, const T* key, const T* value, T* output, bool is_cross_attention, int kv_sequence_length) { - if (!is_cross_attention) { ORT_ENFORCE(sequence_length == kv_sequence_length); constexpr int num_matrices = 3; @@ -696,52 +731,51 @@ void LaunchAddBiasTransposeTrt( } } - template void InvokeAddBias( cudaStream_t stream, const int max_threads_per_block, const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v) { - constexpr int num_matrices = 1; - // Q - { - const dim3 grid(sequence_length, batch_size, num_matrices); - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(query, 
biases, q); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); - } + constexpr int num_matrices = 1; + // Q + { + const dim3 grid(sequence_length, batch_size, num_matrices); + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(query, biases, q); + } else { + const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); } - // K - { - const dim3 grid(kv_sequence_length, batch_size, num_matrices); - const T* biases_k = biases + num_heads * head_size; + } + // K + { + const dim3 grid(kv_sequence_length, batch_size, num_matrices); + const T* biases_k = biases + num_heads * head_size; - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(key, biases_k, k); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k); - } + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(key, biases_k, k); + } else { + const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, key, biases_k, k); } + } - // V - { - const dim3 grid(kv_sequence_length, batch_size, num_matrices); + // V + { + const dim3 grid(kv_sequence_length, batch_size, num_matrices); - const T* biases_v = biases + 2 * num_heads * head_size; - if (v_head_size * num_heads <= max_threads_per_block) { - const dim3 block(v_head_size, num_heads, 1); - AddBiasTransposeTrt<<>>(value, biases_v, v); - } else { - const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); - AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v); - } + const T* biases_v = biases + 2 * num_heads * head_size; + if (v_head_size * num_heads <= max_threads_per_block) { + const dim3 block(v_head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(value, biases_v, v); + } else { + const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1); + AddBiasTransposeTrtLarge<<>>(v_head_size, value, biases_v, v); } + } } template <> @@ -750,7 +784,7 @@ void LaunchAddBias( const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const float* biases, const float* query, const float* key, const float* value, float* q, float* k, float* v) { -if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { + if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { const int H = head_size / 4; const int H_v = v_head_size / 4; const float4* query2 = reinterpret_cast(query); @@ -761,8 +795,8 @@ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { float4* k2 = reinterpret_cast(k); float4* v2 = reinterpret_cast(v); InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, - biases2, query2, key2, value2, q2, k2, v2); + batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, + biases2, query2, key2, value2, q2, k2, v2); } else if (0 == (head_size & 1) && 0 == (v_head_size & 1)) { const int H = head_size / 2; const int H_v = v_head_size / 2; @@ -774,14 +808,13 @@ if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { float2* k2 = reinterpret_cast(k); float2* v2 = reinterpret_cast(v); 
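  // For reference, each InvokeAddBias call below is equivalent to slicing the packed
  // projection bias and broadcasting it over Q/K/V (pseudocode; names illustrative):
  //   q = query + biases[0     : N*H]
  //   k = key   + biases[N*H   : 2*N*H]
  //   v = value + biases[2*N*H : 2*N*H + N*H_v]
  // with outputs kept in BSNH layout. The float4/float2 reinterpret casts only change
  // the per-thread vector width, not the arithmetic.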
InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, - biases2, query2, key2, value2, q2, k2, v2); + batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, + biases2, query2, key2, value2, q2, k2, v2); } else { InvokeAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, num_heads, head_size, v_head_size, - biases, query, key, value, q, k, v); + batch_size, sequence_length, kv_sequence_length, num_heads, head_size, v_head_size, + biases, query, key, value, q, k, v); } - } template <> @@ -790,8 +823,7 @@ void LaunchAddBias( const int batch_size, const int sequence_length, const int kv_sequence_length, const int num_heads, const int head_size, const int v_head_size, const half* biases, const half* query, const half* key, const half* value, half* q, half* k, half* v) { - - if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { + if (0 == (head_size % 4) && 0 == (v_head_size % 4)) { const int H = head_size / 4; const int H_v = v_head_size / 4; const Half4* query2 = reinterpret_cast(query); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index 8cc36637054e7..a2c3265284a4d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -24,6 +24,10 @@ namespace cuda { // format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (batch_size, sequence_length, num_matrices, num_heads, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) +// format 4: (requires qk_head_size = v_head_size) +// input: (batch_size, sequence_length, num_heads, num_matrices, head_size) +// output: (num_matrices, batch_size, sequence_length, num_heads, head_size) + template void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 187f1bb37edc5..7731cb011c8d6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -317,7 +317,34 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, 3); } - } else { // gemm_buffer == nullptr + } else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv + // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. + // CheckInputs verified this constraint. + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_ATTENTION_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + + if (use_memory_efficient_attention) { + // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? 
data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, k, + true, v_head_size, qkv_add_bias, 2); + DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (data.fused_cross_attention_kernel == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } + } else { // gemm_buffer == nullptr and not packed kv assert(data.query != nullptr && data.key != nullptr && data.value != nullptr && data.bias != nullptr); DUMP_ATTENTION_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); @@ -419,8 +446,7 @@ Status QkvToContext( void* fused_runner = data.fused_runner; // At most one fused kernel is enabled. - assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr) <= 1); + assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1); const int batches = batch_size * num_heads; const int size_per_batch_q = sequence_length * qk_head_size; @@ -481,7 +507,7 @@ Status QkvToContext( ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, - use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer + use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer data.gemm_buffer, data.present)); present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; @@ -514,18 +540,26 @@ Status QkvToContext( FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); + // When there is no bias, we can directly use q and packed kv from inputs. TODO: not need qkv in workspace. 
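+      // (Otherwise q and k hold the copies prepared in PrepareQkv, with any bias
+      //  already added.)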
+ void const* query = q; + void const* packed_kv = k; + if (data.value == nullptr && data.bias == nullptr) { + query = data.query; + packed_kv = data.key; + } + run_fused_cross_attention( - q, // Q - k, // packed KV - q_sequence_offset, // cumulated sequence length of Q - kv_sequence_offset, // cumulated sequence length of KV - data.output, // output - cross_attention_kernel, // kernels - batch_size, // batch size - num_heads, // number of heads - qk_head_size, // head size of Q/K/V - sequence_length, // sequence length of Q - kv_sequence_length, // sequence length of KV + query, // Q + packed_kv, // packed KV + q_sequence_offset, // cumulated sequence length of Q + kv_sequence_offset, // cumulated sequence length of KV + data.output, // output + cross_attention_kernel, // kernels + batch_size, // batch size + num_heads, // number of heads + qk_head_size, // head size of Q/K/V + sequence_length, // sequence length of Q + kv_sequence_length, // sequence length of KV stream); DUMP_ATTENTION("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); @@ -570,6 +604,13 @@ Status QkvToContext( assert(data.mask_index == nullptr); assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + const void* query = q; + const void* key = k; + const void* value = v; + if (data.gemm_buffer == nullptr && data.value == nullptr) { // packed KV + query = data.query; + } + MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; @@ -582,9 +623,9 @@ Status QkvToContext( p.causal = parameters.is_unidirectional; p.cu_seqlens_q = nullptr; p.cu_seqlens_k = nullptr; - p.query = q; - p.key = k; - p.value = v; + p.query = query; + p.key = key; + p.value = value; p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr; p.stream = stream; @@ -610,7 +651,7 @@ Status QkvToContext( // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; + : parameters.scale; float alpha = use_raw_attention_mask ? 
one : scale; cublasSetStream(cublas, stream); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index c7e5d34e1691b..93e5e59ed00ae 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -94,6 +94,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && + (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV parameters.hidden_size == parameters.v_hidden_size && has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length); @@ -111,6 +112,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_runner = !disable_fused_runner_ && fused_cross_attention_kernel == nullptr && + value != nullptr && // fused runner requires packed qkv instead of packed kv (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && @@ -162,10 +164,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = nullptr; - data.bias = reinterpret_cast(bias->Data()); + data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? 
gsl::span() : key_padding_mask->Shape().GetDims(); data.past = nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index b4ad4d64e7ddb..68e3985651123 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -127,32 +127,41 @@ void RestorePaddingTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) - // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, kv_sequence_length, num_heads, 2, head_size) + // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or nullptr // Output 0 has shape (batch_size, sequence_length, v_hidden_size) // Type inference ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); // Shape inference - if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + if (hasInputShape(ctx, 0)) { auto& query_shape = getInputShape(ctx, 0); auto& query_dims = query_shape.dim(); if (query_dims.size() != 3) { fail_shape_inference("Inputs 0 (query) shall be 3 dimensions"); } - auto& value_shape = getInputShape(ctx, 2); - auto& value_dims = value_shape.dim(); - if (value_dims.size() != 3) { - fail_shape_inference("Inputs 2 (value) shall be 3 dimensions"); + if (hasInputShape(ctx, 2)) { + auto& value_shape = getInputShape(ctx, 2); + auto& value_dims = value_shape.dim(); + if (value_dims.size() != 3) { + fail_shape_inference("Inputs 2 (value) shall be 3 dimensions"); + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + *output_shape.add_dim() = value_dims[2]; + updateOutputShape(ctx, 0, output_shape); } - ONNX_NAMESPACE::TensorShapeProto output_shape; - *output_shape.add_dim() = query_dims[0]; - *output_shape.add_dim() = query_dims[1]; - *output_shape.add_dim() = value_dims[2]; - updateOutputShape(ctx, 0, output_shape); + if (hasInputShape(ctx, 1)) { + auto& key_shape = getInputShape(ctx, 1); + if (key_shape.dim().size() == 5) { + ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx); + } + } } } @@ -287,16 +296,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Input(1, "key", - "Key with shape (batch_size, kv_sequence_length, hidden_size)", + "Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)", "T") .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, v_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "bias", "Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection", - "T") + "T", + OpSchema::Optional) .Input(4, "key_padding_mask", "Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)", diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 898225212459f..a9c93efd40149 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2001,15 +2001,23 @@ def _infer_BiasGelu(self, node): def _infer_MultiHeadAttention(self, node): # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - # Input 1 
(key) has shape (batch_size, kv_sequence_length, hidden_size) - # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + # Without packed KV: + # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) + # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + # With packed KV: + # Input 1 (key) has shape (batch_size, kv_sequence_length, num_heads, 2, head_size) + # Input 2 (value) is nullptr # Output 0 has shape (batch_size, sequence_length, v_hidden_size) query_shape = self._get_shape(node, 0) - value_shape = self._get_shape(node, 2) + key_shape = self._get_shape(node, 1) + assert len(query_shape) == 3 - assert len(query_shape) == 3 and len(value_shape) == 3 + # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. output_shape = query_shape - output_shape[2] = value_shape[2] + if len(key_shape) == 3: + value_shape = self._get_shape(node, 2) + assert len(value_shape) == 3 + output_shape[2] = value_shape[2] output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 41daad4283b4c..682255ce17713 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -19,11 +19,14 @@ class FusionAttentionUnet(Fusion): Fuse Attention subgraph of UNet into one Attention node. """ - def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool): + def __init__( + self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool, enable_packed_kv: bool + ): super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"]) self.hidden_size = hidden_size self.num_heads = num_heads self.is_cross_attention = is_cross_attention + self.enable_packed_kv = enable_packed_kv # Flags to show warning only once self.num_heads_warning = True @@ -105,7 +108,7 @@ def create_attention_node( if is_self_attention: if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input: logger.debug( - "For self attention, input hidden state for q and k/v shall be different. Got %s, %s, %s", + "For self attention, input hidden state for q and k/v shall be same. 
Got %s, %s, %s", q_matmul.input[0], k_matmul.input[0], v_matmul.input[0], @@ -176,8 +179,63 @@ def create_attention_node( ) self.model.add_initializer(weight, self.this_graph_name) - else: + else: # cross attention attention_node_name = self.model.create_node_name("MultiHeadAttention") + if self.enable_packed_kv: + if kw.shape != vw.shape: + return None + + kw_in_size = kw.shape[0] + vw_in_size = vw.shape[0] + assert kw_in_size == vw_in_size + + qw_out_size = qw.shape[1] + kw_out_size = kw.shape[1] + vw_out_size = vw.shape[1] + assert qw_out_size == vw_out_size == vw_out_size + + C = kw_in_size + N = num_heads + H = kw_out_size // N + + # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape + kv_weight = np.dstack([kw.reshape(C, N, H), vw.reshape(C, N, H)]).reshape(C, N * 2 * H) + + matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") + weight = helper.make_tensor( + name=matmul_node_name + "_weight", + data_type=TensorProto.FLOAT, + dims=[kv_weight.shape[0], kv_weight.shape[1]], + vals=kv_weight.flatten().tolist(), + ) + + self.model.add_initializer(weight, self.this_graph_name) + + matmul_node = helper.make_node( + "MatMul", + inputs=[k_matmul.input[0], matmul_node_name + "_weight"], + outputs=[matmul_node_name + "_out"], + name=matmul_node_name, + ) + self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name + + shape_tensor = helper.make_tensor( + name=matmul_node_name + "_reshape_shape", + data_type=TensorProto.INT64, + dims=[5], + vals=[0, 0, N, 2, H], + ) + self.model.add_initializer(shape_tensor, self.this_graph_name) + + reshape_node = helper.make_node( + "Reshape", + inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], + outputs=[k_matmul.output[0]], + name=matmul_node_name + "_reshape", + ) + self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name + self.nodes_to_add.extend([matmul_node, reshape_node]) + self.nodes_to_remove.extend([k_matmul, v_matmul]) # No bias, use zeros qkv_bias = np.zeros([3, hidden_size], dtype=np.float32) @@ -198,12 +256,18 @@ def create_attention_node( attention_node_name + "_qkv_bias", ] else: - attention_inputs = [ - q_matmul.output[0], - k_matmul.output[0], - v_matmul.output[0], - attention_node_name + "_qkv_bias", - ] + if not self.enable_packed_kv: + attention_inputs = [ + q_matmul.output[0], + k_matmul.output[0], + v_matmul.output[0], + attention_node_name + "_qkv_bias", + ] + else: + attention_inputs = [ + q_matmul.output[0], + k_matmul.output[0], + ] attention_node = helper.make_node( "Attention" if is_self_attention else "MultiHeadAttention", @@ -214,6 +278,14 @@ def create_attention_node( attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + counter_name = ( + "Attention (self attention)" + if is_self_attention + else "MultiHeadAttention ({})".format( + "cross attention with packed kv" if self.enable_packed_kv else "cross attention" + ) + ) + self.increase_counter(counter_name) return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 16802a59f9a8c..cdfa2c626fc57 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -6,9 +6,16 @@ class AttentionMaskFormat: + # Build 1D mask indice (sequence length). 
It requires right side padding! Recommended for BERT model to get best performance. MaskIndexEnd = 0 + + # For experiment only. Do not use it in production. MaskIndexEndAndStart = 1 + + # Raw attention mask with 0 means padding (or no attention) and 1 otherwise. AttentionMask = 2 + + # No attention mask NoMask = 3 @@ -36,9 +43,17 @@ def __init__(self, model_type): self.enable_shape_inference = True self.enable_gemm_fast_gelu = False + + # Set default to sequence length for BERT model to use fused attention to speed up. + # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. + self.attention_mask_format = ( + AttentionMaskFormat.MaskIndexEnd if model_type == "bert" else AttentionMaskFormat.AttentionMask + ) + + # options for stable diffusion self.enable_group_norm = model_type == "unet" self.enable_bias_splitgelu = model_type == "unet" - self.attention_mask_format = AttentionMaskFormat.AttentionMask + self.enable_packed_kv = model_type == "unet" def use_raw_attention_mask(self, use_raw_mask=True): if use_raw_mask: @@ -76,10 +91,14 @@ def parse(args): options.enable_gemm_fast_gelu = True if args.use_mask_index: options.use_raw_attention_mask(False) + if args.use_raw_attention_mask: + options.use_raw_attention_mask(True) if args.no_attention_mask: options.disable_attention_mask() - if args.enable_group_norm: - options.enable_group_norm = True + if args.disable_group_norm: + options.enable_group_norm = False + if args.disable_packed_kv: + options.enable_packed_kv = False return options @staticmethod @@ -168,10 +187,18 @@ def add_arguments(parser: ArgumentParser): "--use_mask_index", required=False, action="store_true", - help="use mask index instead of raw attention mask in attention operator", + help="use mask index to activate fused attention to speed up. It requires right-side padding!", ) parser.set_defaults(use_mask_index=False) + parser.add_argument( + "--use_raw_attention_mask", + required=False, + action="store_true", + help="use raw attention mask. Use this option if your input is not right-side padding. This might deactivate fused attention and get worse performance.", + ) + parser.set_defaults(use_raw_attention_mask=False) + parser.add_argument( "--no_attention_mask", required=False, @@ -191,9 +218,17 @@ def add_arguments(parser: ArgumentParser): parser.set_defaults(use_multi_head_attention=False) parser.add_argument( - "--enable_group_norm", + "--disable_group_norm", + required=False, + action="store_true", + help="not fuse GroupNorm. Only works for model_type=unet", + ) + parser.set_defaults(disable_group_norm=False) + + parser.add_argument( + "--disable_packed_kv", required=False, action="store_true", - help="fuse GroupNorm. Only works for model_type=unet", + help="not use packed kv in cross attention. Only works for model_type=unet", ) - parser.set_defaults(enable_group_norm=False) + parser.set_defaults(disable_packed_kv=False) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 77de06c509e12..26b41c575b165 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -6,19 +6,14 @@ # This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference. 
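# The optimization includes operator fusions such as Attention/MultiHeadAttention (optionally
# with packed KV for cross attention), GroupNorm and BiasSplitGelu; for model_type=unet these
# can be turned off individually via --disable_packed_kv and --disable_group_norm.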
# # Before running this script, you need convert checkpoint to float32 onnx models like the following -# git clone https://github.com/huggingface/diffusers -# cd diffusers -# pip install -e . -# huggingface-cli login -# python scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ../stable-diffusion-v1-5-fp32 -# Or use diffusers packages: # export ONNX_ROOT=./sd_onnx -# pip install diffusers==0.11.1 transformers==4.21.2 +# pip install -r requirements.txt # huggingface-cli login -# wget https://raw.githubusercontent.com/huggingface/diffusers/v0.11.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py +# wget https://raw.githubusercontent.com/huggingface/diffusers/v0.12.1/scripts/convert_stable_diffusion_checkpoint_to_onnx.py # python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path $ONNX_ROOT/stable-diffusion-v1-5-fp32 # python convert_stable_diffusion_checkpoint_to_onnx.py --model_path stabilityai/stable-diffusion-2-1 --output_path $ONNX_ROOT/stable-diffusion-v2-1-fp32 -# +# Note that this script might not be compatible with older or newer version of diffusers/transformers. It is because fusion script need change accordingly when onnx graph is changed. + # Then you can use this script to convert them to float16 like the following: # python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v1-5-fp32 -o $ONNX_ROOT/stable-diffusion-v1-5-fp16 --float16 # python optimize_pipeline.py -i $ONNX_ROOT/stable-diffusion-v2-1-fp32 -o $ONNX_ROOT/stable-diffusion-v2-1-fp16 --float16 @@ -43,6 +38,8 @@ logger = logging.getLogger(__name__) +DEBUG = True + def optimize_stable_diffusion_onnx_pipeline( source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool @@ -60,7 +57,7 @@ def optimize_stable_diffusion_onnx_pipeline( RuntimeError: input onnx model does not exist RuntimeError: output onnx model path existed """ - dirs_with_onnx = ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"] + dirs_with_onnx = ["unet"] if DEBUG else ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"] for name in dirs_with_onnx: onnx_model_path = source_dir / name / "model.onnx" @@ -121,7 +118,20 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): RuntimeError: source path does not exist RuntimeError: output path exists but overwrite is false. 
""" - extra_dirs = ["scheduler", "tokenizer", "feature_extractor"] + extra_dirs = ( + [ + "vae_encoder", + "vae_decoder", + "text_encoder", + "safety_checker", + "scheduler", + "tokenizer", + "feature_extractor", + ] + if DEBUG + else ["scheduler", "tokenizer", "feature_extractor"] + ) + for name in extra_dirs: source_path = source_dir / name diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 5464665b18d0d..feba717bd8f6f 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -63,10 +63,11 @@ def optimize(self, options: Optional[FusionOptions] = None): bias_split_gelu_fusion.apply() if (options is None) or options.enable_attention: - self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False) + self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False) self_attention_fusion.apply() - cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True) + enable_packed_kv = (options is None) or options.enable_packed_kv + cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv) cross_attention_fusion.apply() if (options is None) or options.enable_skip_layer_norm: diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index 74d20295a0a63..657d52cc15a31 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -40,6 +40,7 @@ def test_multi_head_attention_fusion(self): onnx.save(model, model_path) options = FusionOptions("bert") options.use_multi_head_attention = True + options.use_raw_attention_mask(True) optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) self.verify_fusion(optimized_model, "attention_mha.onnx") @@ -49,7 +50,9 @@ def test_attention_fusion(self): dir = "." model_path = os.path.join(dir, "attention.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) self.verify_fusion(optimized_model, "attention_opt.onnx") @@ -64,7 +67,9 @@ def test_attention_fusion_pruned_model(self): dir = "." model_path = os.path.join(dir, "pruned_attention.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) self.verify_fusion(optimized_model, "pruned_attention_opt.onnx") @@ -80,7 +85,9 @@ def test_attention_fusion_reverse_add_order(self): dir = "." model_path = os.path.join(dir, "bert_attention_reverse_add_order.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) # reverse add input order will get same optimized model @@ -96,7 +103,9 @@ def test_attention_fusion_for_varied_qkv_dimensions(self): dir = "." 
model_path = os.path.join(dir, "attention_with_varied_qkv.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, optimization_options=options) os.remove(model_path) self.verify_fusion(optimized_model, "attention_with_varied_qkv_opt.onnx") @@ -113,7 +122,9 @@ def test_attention_fusion_for_varied_qkv_dimensions_with_wrong_opt_parameters(se onnx.save(model, model_path) # wrong num_heads and hidden_size - optimized_model = optimize_model(model_path, "bert", num_heads=8, hidden_size=8) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, "bert", num_heads=8, hidden_size=8, optimization_options=options) os.remove(model_path) From 966b3e72ce904317c8929cbacad3bc72529da7ec Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 31 Jan 2023 07:35:25 +0000 Subject: [PATCH 16/27] fix pyright warnings --- .../python/tools/symbolic_shape_infer.py | 79 ++++++++++--------- .../transformers/fusion_biassplitgelu.py | 6 +- .../tools/transformers/fusion_group_norm.py | 5 ++ .../python/tools/transformers/fusion_utils.py | 1 - .../python/tools/transformers/onnx_model.py | 2 +- 5 files changed, 50 insertions(+), 43 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index a9c93efd40149..689235b630d94 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1967,34 +1967,35 @@ def _infer_ZipMap(self, node): def _infer_Attention(self, node): shape = self._get_shape(node, 0) shape_bias = self._get_shape(node, 2) - assert len(shape) == 3 and len(shape_bias) == 1 - qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") - if qkv_hidden_sizes_attr is not None: - assert len(qkv_hidden_sizes_attr) == 3 - shape[2] = int(qkv_hidden_sizes_attr[2]) - else: - shape[2] = int(shape_bias[0] / 3) - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) - - if len(node.output) > 1: - # input shape: (batch_size, sequence_length, hidden_size) - # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) - # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) - # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length - input_shape = self._get_shape(node, 0) - past_shape = self._get_shape(node, 4) - mask_shape = self._get_shape(node, 3) - if len(past_shape) == 5: - if len(mask_shape) in [2, 3]: - past_shape[3] = mask_shape[-1] - elif isinstance(input_shape[1], int) and isinstance(past_shape[3], int): - past_shape[3] = input_shape[1] + past_shape[3] - else: - past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" - vi = self.known_vi_[node.output[1]] - vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + if shape and len(shape) == 3 and shape_bias and len(shape_bias) == 1: + qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[2] = int(qkv_hidden_sizes_attr[2]) + elif isinstance(shape_bias[0], int): + shape[2] = int(shape_bias[0] / 
3) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) + + if len(node.output) > 1: + # input shape: (batch_size, sequence_length, hidden_size) + # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) + # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) + # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length + input_shape = self._get_shape(node, 0) + past_shape = self._get_shape(node, 4) + mask_shape = self._get_shape(node, 3) + if past_shape and len(past_shape) == 5: + if mask_shape and len(mask_shape) in [2, 3]: + past_shape[3] = mask_shape[-1] + elif input_shape and len(input_shape) == 3: + if isinstance(input_shape[1], int) and isinstance(past_shape[3], int): + past_shape[3] = input_shape[1] + past_shape[3] + else: + past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) def _infer_BiasGelu(self, node): self._propagate_shape_and_type(node) @@ -2010,18 +2011,18 @@ def _infer_MultiHeadAttention(self, node): # Output 0 has shape (batch_size, sequence_length, v_hidden_size) query_shape = self._get_shape(node, 0) key_shape = self._get_shape(node, 1) - assert len(query_shape) == 3 + if query_shape is not None and len(query_shape) == 3: - # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. - output_shape = query_shape - if len(key_shape) == 3: - value_shape = self._get_shape(node, 2) - assert len(value_shape) == 3 - output_shape[2] = value_shape[2] + # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. + output_shape = query_shape + if key_shape and len(key_shape) == 3: + value_shape = self._get_shape(node, 2) + if value_shape and len(value_shape) == 3: + output_shape[2] = value_shape[2] - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_FastGelu(self, node): self._propagate_shape_and_type(node) @@ -2074,7 +2075,7 @@ def _infer_GroupNorm(self, node): def _infer_BiasSplitGelu(self, node): input_shape = self._get_shape(node, 0) bias_shape = self._get_shape(node, 1) - if input_shape and bias_shape: + if input_shape and bias_shape and isinstance(bias_shape[0], int): output_shape = input_shape output_shape[2] = int(bias_shape[0] / 2) vi = self.known_vi_[node.output[0]] diff --git a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py index 17e57223b96d3..d4f251c855e48 100644 --- a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, Optional +from typing import Dict from fusion_base import Fusion from onnx import helper @@ -94,7 +94,9 @@ def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): return add_node = start_index_nodes[-1] - bias_index, _bias = self.model.get_constant_input(add_node) + bias_index, _value = self.model.get_constant_input(add_node) + if not isinstance(bias_index, int): + return self.nodes_to_remove.extend(subgraph_nodes) node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu") fused_node = helper.make_node( diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index 3feb3b8a00289..1f9455fd34de0 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -64,10 +64,15 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): return weight = self.model.get_constant_value(weight_input) + if weight is None: + return + if not (len(weight.shape) == 3 and weight.shape[1] == 1 and weight.shape[2] == 1): return bias = self.model.get_constant_value(bias_input) + if weight is None: + return if not (len(bias.shape) == 3 and bias.shape[1] == 1 and bias.shape[2] == 1): return diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index d84c433687b5e..8363f2674cd40 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -41,7 +41,6 @@ def cast_input(self, input_name: str, target_type="int32"): cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output]) - to_type = -1 if target_type == "int32": to_type = int(TensorProto.INT32) elif target_type == "float32": diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index cfca8fb0578b5..96c22b5894c60 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -979,7 +979,7 @@ def save_model_to_file(self, output_path, use_external_data_format=False, all_te # Note: After the model is saved to another directory with external data, # You need reload the onnx model if you want to read tensor from self.model object. - # It is because the base directory is not updated for self.model object so attemp to read tensor data + # It is because the base directory is not updated for self.model object so attempt to read tensor data # might encounter error since external data cannot be located. 
OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file) logger.info(f"Model saved to {output_path}") From 4a8583e9f0d8c1f1ce23c03a72c2db04db0727aa Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 31 Jan 2023 19:27:47 +0000 Subject: [PATCH 17/27] Add unit test of bias split gelu --- .../cuda/diffusion/bias_split_gelu_impl.cu | 5 +- .../contrib_ops/bias_split_gelu_op_test.cc | 145 ++++++++++++++++++ 2 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu index 9399f98e23dff..1ad23c691cce9 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -23,6 +23,7 @@ #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" namespace onnxruntime { @@ -37,8 +38,8 @@ __global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) { #pragma unroll for (int32_t i = 0; i < HHS / TPB; ++i) { - auto value_left = static_cast(input[index_input] + bias[index_bias]); - auto value_right = static_cast(input[index_input + HHS] + bias[index_bias + HHS]); + auto value_left = float(input[index_input] + bias[index_bias]); + auto value_right = float(input[index_input + HHS] + bias[index_bias + HHS]); // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0) float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f); diff --git a/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc new file mode 100644 index 0000000000000..3fac765d898da --- /dev/null +++ b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +using namespace onnxruntime::test; + +namespace onnxruntime { +namespace test { +namespace bias_split_gelu_test { +std::vector ComputeGelu(const std::vector& input_data) { + std::vector output; + output.reserve(input_data.size()); + + for (size_t i = 0; i < input_data.size(); i++) { + float x = input_data[i]; + float y = x * (0.5f * (1.0f + std::erff(x / 1.41421356237f))); + output.push_back(y); + } + return output; +} + +std::vector AddBias(const std::vector& input_data, const std::vector& bias_data) { + size_t bias_length = bias_data.size(); + + std::vector output; + output.reserve(input_data.size()); + + for (size_t i = 0; i < input_data.size(); i++) { + output.push_back(input_data[i] + bias_data[i % bias_length]); + } + return output; +} + +void Split(const std::vector& input_data, + const std::vector& input_dims, + std::vector& left_half_data, std::vector& right_half_data) { + std::size_t length = input_data.size(); + left_half_data.reserve(length / 2); + right_half_data.reserve(length / 2); + + int64_t index = 0; + for (int64_t i = 0; i < input_dims[0]; i++) { + for (int64_t j = 0; j < input_dims[1]; j++) { + for (int64_t k = 0; k < input_dims[2]; k++, index++) { + if (k < input_dims[2] / 2) { + left_half_data.push_back(input_data[index]); + } else { + right_half_data.push_back(input_data[index]); + } + } + } + } +} + +std::vector GetExpectedResult(const std::vector& input_data, + const std::vector& input_dims, + const std::vector& bias_data) { + std::vector add_bias_data = AddBias(input_data, bias_data); + std::vector left_half_data; + std::vector right_half_data; + Split(add_bias_data, input_dims, left_half_data, right_half_data); + std::vector right_gelu_data = ComputeGelu(right_half_data); + + std::vector output_data; + output_data.reserve(left_half_data.size()); + for (std::size_t i = 0; i < left_half_data.size(); i++) { + output_data.push_back(left_half_data[i] * right_gelu_data[i]); + } + return output_data; +} +} // namespace bias_split_gelu_test + +#if defined(USE_CUDA) // The operator has only CUDA implementation right now + +static void RunBiasSplitGeluGpuTest(const std::vector& input_data, + const std::vector& bias_data, + const std::vector& output_data, + const std::vector& input_dims, + const std::vector& bias_dims, + const std::vector& output_dims, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 
530 : 0; + if (!HasCudaEnvironment(min_cuda_architecture)) { + return; + } + + OpTester tester("BiasSplitGelu", 1, onnxruntime::kMSDomain); + + if (use_float16) { + tester.AddInput("X", input_dims, ToFloat16(input_data)); + tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); + tester.AddOutput("Y", output_dims, ToFloat16(output_data)); + } else { + tester.AddInput("X", input_dims, input_data); + tester.AddInput("bias", bias_dims, bias_data); + tester.AddOutput("Y", output_dims, output_data); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +static void RunBiasSplitGeluTest(int64_t batch_size, int64_t sequence_length, int64_t hidden_size) { + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector bias_dims = {hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size / 2}; + + RandomValueGenerator random{}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 0.3f); + std::vector bias_data = random.Gaussian(bias_dims, 0.0f, 0.3f); + std::vector output_data = bias_split_gelu_test::GetExpectedResult(input_data, input_dims, bias_data); + + RunBiasSplitGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims); +} + +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_2560) { + constexpr int64_t batch_size = 2; + constexpr int64_t sequence_length = 5; + constexpr int64_t hidden_size = 2560; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_5120) { + constexpr int64_t batch_size = 2; + constexpr int64_t sequence_length = 1; + constexpr int64_t hidden_size = 5120; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_10240) { + constexpr int64_t batch_size = 1; + constexpr int64_t sequence_length = 2; + constexpr int64_t hidden_size = 10240; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + +#endif + +} // namespace test +} // namespace onnxruntime From 982663a04341009a6b9bb6fa6ab25b9db33d4f33 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 31 Jan 2023 19:44:39 +0000 Subject: [PATCH 18/27] fix typo --- .../python/tools/transformers/fusion_attention_unet.py | 8 ++++---- .../python/tools/transformers/fusion_group_norm.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 682255ce17713..a148b155caaa4 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -153,7 +153,7 @@ def create_attention_node( kw_in_size = kw.shape[0] vw_in_size = vw.shape[0] - assert qw_in_size == kw_in_size == vw_in_size + assert qw_in_size == kw_in_size and kw_in_size == vw_in_size if hidden_size > 0 and hidden_size != qw_in_size: raise ValueError( @@ -192,7 +192,7 @@ def create_attention_node( qw_out_size = qw.shape[1] kw_out_size = kw.shape[1] vw_out_size = vw.shape[1] - assert qw_out_size == vw_out_size == vw_out_size + assert qw_out_size == vw_out_size and kw_out_size == vw_out_size C = kw_in_size N = num_heads @@ -330,11 +330,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", 
"Mul", "MatMul"], [0, 0, 0]) if qk_nodes is not None: - (softmax_qk, mul_qk, matmul_qk) = qk_nodes + (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes else: qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) if qk_nodes is not None: - (softmax_qk, add_zero, mul_qk, matmul_qk) = qk_nodes + (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes else: logger.debug("fuse_attention: failed to match qk path") return diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index 1f9455fd34de0..95c58aa8eebf8 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -71,7 +71,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): return bias = self.model.get_constant_value(bias_input) - if weight is None: + if bias is None: return if not (len(bias.shape) == 3 and bias.shape[1] == 1 and bias.shape[2] == 1): return From 73045bbe4542ebc271140e8d99a4161ee5df46d0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 31 Jan 2023 21:23:37 +0000 Subject: [PATCH 19/27] fix code scanning warnings --- .../python/tools/transformers/fusion_biassplitgelu.py | 1 - onnxruntime/python/tools/transformers/fusion_group_norm.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py index d4f251c855e48..106d3de25d39d 100644 --- a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py @@ -108,4 +108,3 @@ def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[node_name] = self.this_graph_name - return True diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index 95c58aa8eebf8..f75f2a5cdcca6 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -76,8 +76,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): if not (len(bias.shape) == 3 and bias.shape[1] == 1 and bias.shape[2] == 1): return - weight_elements = np.prod(weight.shape) - bias_elements = np.prod(bias.shape) + weight_elements = int(np.prod(weight.shape)) + bias_elements = int(np.prod(bias.shape)) if weight_elements != bias_elements: return @@ -196,4 +196,3 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name - return True From 86d5795090c9419c3a839cdff4012d1f4c62bda6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 31 Jan 2023 22:17:04 +0000 Subject: [PATCH 20/27] fix code scanning warnings --- .../python/tools/transformers/fusion_attention_unet.py | 10 +++++----- .../python/tools/transformers/fusion_group_norm.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index a148b155caaa4..0441ce494d560 100644 
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -194,12 +194,12 @@ def create_attention_node( vw_out_size = vw.shape[1] assert qw_out_size == vw_out_size and kw_out_size == vw_out_size - C = kw_in_size - N = num_heads - H = kw_out_size // N + c = kw_in_size + n = num_heads + h = kw_out_size // num_heads # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape - kv_weight = np.dstack([kw.reshape(C, N, H), vw.reshape(C, N, H)]).reshape(C, N * 2 * H) + kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h) matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") weight = helper.make_tensor( @@ -223,7 +223,7 @@ def create_attention_node( name=matmul_node_name + "_reshape_shape", data_type=TensorProto.INT64, dims=[5], - vals=[0, 0, N, 2, H], + vals=[0, 0, n, 2, h], ) self.model.add_initializer(shape_tensor, self.this_graph_name) diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index f75f2a5cdcca6..a0a4d7c16de0b 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -94,7 +94,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): and instance_norm_scale.shape == instance_norm_bias.shape and instance_norm_scale.shape[0] == 32 ): - logger.info(f"InstanceNormalization groups={instance_norm_scale.shape[0]}") + logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0]) return if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): @@ -105,7 +105,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") if weight_elements not in [320, 640, 960, 1280, 1920, 2560] + [128, 256, 512]: - logger.info(f"GroupNorm channels={weight_elements}") + logger.info("GroupNorm channels=%d", weight_elements) gamma = helper.make_tensor( name=group_norm_name + "_gamma", From efa6d4f4d53c6e191f8cb513970044aba7c7102e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 31 Jan 2023 23:51:37 +0000 Subject: [PATCH 21/27] address review feedback --- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 6 +- .../contrib_ops/cuda/diffusion/group_norm.cc | 83 ++++++++++++------- .../contrib_ops/cuda/diffusion/group_norm.h | 1 - 3 files changed, 53 insertions(+), 37 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 91e9043978a9f..f01e1740e5d0c 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -73,8 +73,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupNorm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GroupNorm); +class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); @@ -198,8 +197,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 088b0c9be5a05..36a2bd11257d6 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -9,24 +9,51 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GroupNorm, \ - kMSDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - GroupNorm); - -REGISTER_KERNEL_TYPED(MLFloat16); -REGISTER_KERNEL_TYPED(float); +#define GROUP_NORM_TYPES float, MLFloat16 + +ONNX_OPERATOR_KERNEL_EX( + GroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); using namespace ONNX_NAMESPACE; +namespace { template -GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { +struct DispatchGroupNorm { + Status operator()(cudaStream_t stream, + Tensor* output, + const Tensor* input, + const Tensor* gamma, + const Tensor* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_swish_activation) { + typedef typename ToCudaType::MappedType CudaT; + return LaunchGroupNormKernel( + stream, + reinterpret_cast(output->MutableData()), + reinterpret_cast(input->Data()), + gamma->Data(), + beta->Data(), + workspace, + epsilon, + batch_size, + num_channels, + height, + width, + num_groups, + use_swish_activation); + } +}; + +} // namespace + +GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(epsilon_ >= 0); @@ -41,8 +68,7 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { use_swish_activation_ = (activation == 1); } -template -Status GroupNorm::ComputeInternal(OpKernelContext* context) const { +Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* gamma = context->Input(1); const Tensor* beta = context->Input(2); @@ -87,22 +113,15 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); - typedef typename ToCudaType::MappedType CudaT; - - return LaunchGroupNormKernel( - Stream(context), - reinterpret_cast(output->MutableData()), - reinterpret_cast(input->Data()), - gamma->Data(), - beta->Data(), - reinterpret_cast(workspace.get()), - epsilon_, - batch_size, - num_channels, - height, - width, - num_groups_, - use_swish_activation_); + utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); + return 
dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + epsilon_, + batch_size, + num_channels, + height, + width, + num_groups_, + use_swish_activation_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 099c083084527..8578a1642198f 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -11,7 +11,6 @@ namespace cuda { using namespace onnxruntime::cuda; -template class GroupNorm final : public CudaKernel { public: GroupNorm(const OpKernelInfo& op_kernel_info); From 7a75ce18a9db3c1b92830d6c272f89bc89b9945f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 1 Feb 2023 19:59:22 +0000 Subject: [PATCH 22/27] Add NhwcConv --- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 + .../contrib_ops/cuda/diffusion/nhwc_conv.cc | 31 +++ onnxruntime/contrib_ops/cuda/fused_conv.cc | 6 +- .../core/providers/cpu/nn/conv_attributes.h | 20 +- .../core/providers/cuda/cudnn_common.cc | 18 +- .../core/providers/cuda/cudnn_common.h | 5 + onnxruntime/core/providers/cuda/nn/conv.cc | 101 ++++++--- onnxruntime/core/providers/cuda/nn/conv.h | 4 +- .../test/contrib_ops/nhwc_conv_op_test.cc | 207 ++++++++++++++++++ 9 files changed, 349 insertions(+), 47 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc create mode 100644 onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index f01e1740e5d0c..a239e528af148 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -74,6 +74,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler); @@ -198,6 +200,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc new file mode 100644 index 0000000000000..79f0a18ba515f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/nhwc_conv.cc @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/common/span_utils.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/nn/conv.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + NhwcConv, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Conv); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc index 39c3bb282d912..48881ddca4063 100644 --- a/onnxruntime/contrib_ops/cuda/fused_conv.cc +++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc @@ -9,10 +9,10 @@ namespace contrib { namespace cuda { template -class FusedConv : public onnxruntime::cuda::Conv { +class FusedConv : public onnxruntime::cuda::Conv { public: - using Base = onnxruntime::cuda::Conv; - FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv(info) { + using Base = onnxruntime::cuda::Conv; + FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv(info) { std::string activation; if (info.GetAttr("activation", &activation) == Status::OK() && MapMode(activation) == Status::OK() && diff --git a/onnxruntime/core/providers/cpu/nn/conv_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_attributes.h index 51a1e7acafe11..b31030acc52c1 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_attributes.h @@ -73,7 +73,7 @@ struct ConvAttributes { ~ConvAttributes() = default; - Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape) const { + Status ComputeKernelShape(const TensorShape& weight_shape, TensorShapeVector& kernel_shape, bool weight_channels_last = false) const { if (kernel_shape_specified) { kernel_shape = kernel_shape_; if (kernel_shape.size() + 2 != weight_shape.NumDimensions()) { @@ -82,15 +82,20 @@ struct ConvAttributes { " W: ", weight_shape.ToString().c_str()); } for (size_t i = 0; i < kernel_shape.size(); ++i) { - if (kernel_shape[i] != weight_shape[i + 2]) { + if (kernel_shape[i] != weight_shape[i + (weight_channels_last ? 
1 : 2)]) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.", " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", weight_shape.ToString().c_str()); + " W: ", weight_shape.ToString().c_str(), + " channels_last: ", weight_channels_last); } } } else { auto weight_dims = weight_shape.GetDims(); - kernel_shape.assign(weight_dims.begin() + 2, weight_dims.end()); + if (weight_channels_last) { + kernel_shape.assign(weight_dims.begin() + 1, weight_dims.end() - 1); + } else { + kernel_shape.assign(weight_dims.begin() + 2, weight_dims.end()); + } } return Status::OK(); @@ -98,7 +103,8 @@ struct ConvAttributes { Status ValidateInputShape(const TensorShape& input_shape, const TensorShape& weight_shape, - bool channels_last = false) const { + bool input_channels_last = false, + bool weight_channels_last = false) const { if (input_shape.NumDimensions() != weight_shape.NumDimensions()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "X num_dims does not match W num_dims.", " X: ", input_shape.ToString().c_str(), @@ -106,9 +112,9 @@ struct ConvAttributes { } const int64_t M = weight_shape[0]; - const int64_t C = channels_last ? input_shape.GetDims().back() : input_shape[1]; + const int64_t C = input_channels_last ? input_shape.GetDims().back() : input_shape[1]; - if (C != weight_shape[1] * group) { + if (C != (weight_channels_last ? weight_shape.GetDims().back() : weight_shape[1]) * group) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input channels C is not equal to kernel channels * group.", " C: ", C, " kernel channels: ", weight_shape[1], diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index d62a651880a85..6f53bfabd75ed 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -42,6 +42,12 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat return Status::OK(); } +Status CudnnTensor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int64_t n, int64_t c, int64_t h, int64_t w) { + ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); + CUDNN_RETURN_IF_ERROR(cudnnSetTensor4dDescriptor(tensor_, format, dataType, n, c, h, w)); + return Status::OK(); +} + Status CudnnTensor::Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode) { ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); CUDNN_RETURN_IF_ERROR(cudnnDeriveBNTensorDescriptor(tensor_, x_desc, mode)); @@ -113,15 +119,23 @@ Status CudnnFilterDescriptor::Set(gsl::span filter_dims, cudnnDat return Status::OK(); } +Status CudnnFilterDescriptor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int64_t k, int64_t c, int64_t h, int64_t w) { + if (!desc_) + CUDNN_RETURN_IF_ERROR(cudnnCreateFilterDescriptor(&desc_)); + + CUDNN_RETURN_IF_ERROR(cudnnSetFilter4dDescriptor(desc_, dataType, format, k, c, h, w)); + return Status::OK(); +} + template cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN engine currently supports only single/double/half/int8/uint8 precision data types. 
Got:", - typeid(ElemType).name()); + typeid(ElemType).name()); // Not reachable but GCC complains return CUDNN_DATA_FLOAT; } -template<> +template <> cudnnDataType_t CudnnTensor::GetDataType() { return CUDNN_DATA_FLOAT; } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index f104373b9413a..ebe352bcf1b7e 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -18,6 +18,8 @@ class CudnnTensor final { Status Set(gsl::span input_dims, cudnnDataType_t dataType); Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode); + // Set 4D tensor format (for NHWC) + Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int64_t n, int64_t c, int64_t h, int64_t w); operator cudnnTensorDescriptor_t() const { return tensor_; } @@ -58,6 +60,9 @@ class CudnnFilterDescriptor final { Status Set(gsl::span filter_dims, cudnnDataType_t data_typ); + // Set 4D filter where k is output channels, c is input channels, h and w is rows and columns per filter. + Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int64_t k, int64_t c, int64_t h, int64_t w); + operator cudnnFilterDescriptor_t() const { return desc_; } private: diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index f1590bc51388d..2ccb0930caa38 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -20,7 +20,7 @@ namespace cuda { T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ + Conv); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Conv, \ kOnnxDomain, \ @@ -28,14 +28,14 @@ namespace cuda { T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); + Conv); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) REGISTER_KERNEL_TYPED(MLFloat16) -template -const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { +template +const cudnnConvolutionFwdAlgo_t Conv::kAllAlgos[] = { CUDNN_CONVOLUTION_FWD_ALGO_GEMM, CUDNN_CONVOLUTION_FWD_ALGO_FFT, CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, @@ -86,8 +86,8 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream, return SliceCuda::Impl(stream, input_data, input_dims, output_data, compute_metadata, element_size); } -template -Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { +template +Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const { //set X const Tensor* X = context->Input(0); const TensorShape& x_shape = X->Shape(); @@ -125,48 +125,61 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const s_.cached_benchmark_results.clear(); } - const int64_t N = X->Shape()[0]; - const int64_t M = W->Shape()[0]; - - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); + constexpr bool channels_last = NHWC; + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); - auto rank = kernel_shape.size(); + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape, channels_last)); + + const size_t kernel_rank = kernel_shape.size(); + ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { - pads.resize(rank * 2, 0); + pads.resize(kernel_rank * 2, 0); } TensorShapeVector 
dilations(conv_attrs_.dilations); if (dilations.empty()) { - dilations.resize(rank, 1); + dilations.resize(kernel_rank, 1); } TensorShapeVector strides(conv_attrs_.strides); if (strides.empty()) { - strides.resize(rank, 1); + strides.resize(kernel_rank, 1); } TensorShapeVector y_dims; - y_dims.reserve(2 + rank); // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C' - y_dims.insert(y_dims.begin(), {N, M}); + y_dims.reserve(2 + kernel_rank); // add 2 to account for 'N' and 'C' - TensorShapeVector y_dims_with_adjusted_pads; - y_dims_with_adjusted_pads.reserve(2 + rank); // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C' - y_dims_with_adjusted_pads.insert(y_dims_with_adjusted_pads.begin(), {N, M}); + const int64_t N = X->Shape()[0]; + const int64_t M = W->Shape()[0]; + if (channels_last) { + y_dims.push_back(N); + } else { + y_dims.insert(y_dims.begin(), {N, M}); + } bool post_slicing_required = false; TensorShapeVector slice_starts; - slice_starts.reserve(rank); + slice_starts.reserve(kernel_rank); TensorShapeVector slice_ends; - slice_ends.reserve(rank); + slice_ends.reserve(kernel_rank); TensorShapeVector slice_axes; - slice_axes.reserve(rank); + slice_axes.reserve(kernel_rank); - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(x_shape.Slice(2), kernel_shape, + const size_t spatial_dim_start = channels_last ? 1 : 2; + const size_t spatial_dim_end = spatial_dim_start + kernel_rank; + TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); + + TensorShapeVector y_dims_with_adjusted_pads(y_dims); + ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(spatial_shape, kernel_shape, strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, post_slicing_required, slice_starts, slice_ends, slice_axes)); + if (channels_last) { + y_dims.push_back(M); + y_dims_with_adjusted_pads.push_back(M); + } + ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size()); s_.y_dims = gsl::make_span(y_dims); s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads; @@ -190,7 +203,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; - if (rank < 2) { + if (kernel_rank < 2) { // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] // especially for EXHAUSTIVE algo search which may result in a better algo selection. 
// ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to @@ -203,7 +216,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); w_dims.insert(w_dims.begin() + 2, 1); - pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.begin(), 0); kernel_shape.insert(kernel_shape.begin(), 1); strides.insert(strides.begin(), 1); @@ -212,7 +225,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const x_dims_cudnn.push_back(1); y_dims_cudnn.push_back(1); w_dims.push_back(1); - pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.end(), 0); kernel_shape.push_back(1); strides.push_back(1); @@ -220,16 +233,30 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const } } - if (w_dims_changed) - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + if (w_dims_changed) { + if (!channels_last) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), w_dims[0], w_dims[3], w_dims[1], w_dims[2])); + } + } // We must delay returning early until here so that the weight dims have been cached properly if (s_.Y->Shape().Size() == 0) { return Status::OK(); } - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + if (channels_last) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + x_dims_cudnn[0], x_dims_cudnn[3], x_dims_cudnn[1], x_dims_cudnn[2])); + + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + y_dims_cudnn[0], y_dims_cudnn[3], y_dims_cudnn[1], y_dims_cudnn[2])); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + } + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); @@ -331,8 +358,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const return Status::OK(); } -template -Status Conv::ComputeInternal(OpKernelContext* context) const { +template +Status Conv::ComputeInternal(OpKernelContext* context) const { std::lock_guard lock(s_.mutex); ORT_RETURN_IF_ERROR(UpdateState(context)); if (s_.Y->Shape().Size() == 0) { @@ -367,7 +394,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { s_.slice_ends, s_.slice_axes, s_.element_size)); } return Status::OK(); -} // namespace cuda +} CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() : desc_(nullptr) { } @@ -424,5 +451,11 @@ Status CudnnConvolutionDescriptor::Set( return Status::OK(); } +#ifndef DISABLE_CONTRIB_OPS +// template instantiation for NhwcConv +template class Conv; +template class Conv; +#endif + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index ae179de0070b0..07825b93204ca 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -177,7 +177,9 @@ enum : size_t { AlgoSearchWorkspaceSize = 32 * 1024 * 1024, }; 
-template +// ONNX Conv operator uses NCHW format for input, weights and output. +// NhwcConv contrib ops uses NHWC format: last dimension of input, weights and output are channels. +template class Conv : public CudaKernel { public: using CudaT = typename ToCudaType::MappedType; diff --git a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc new file mode 100644 index 0000000000000..724232ec0560d --- /dev/null +++ b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +using namespace std; +namespace onnxruntime { +namespace test { + +namespace { + +struct NhwcConvOpAndTestAttributes { + string auto_pad; + vector dilations; + int64_t group; + vector kernel_shape; + vector pads; + vector strides; + std::unordered_set excluded_providers; +}; + +void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes, + const vector>& inputs, + const vector>& input_shapes, + const std::initializer_list& expected_output, + const vector& expected_output_shape, + bool weight_is_initializer = false, + OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, + const std::string& err_str = "") { + OpTester test("NhwcConv", 1, onnxruntime::kMSDomain); + test.AddAttribute("group", attributes.group); + test.AddAttribute("kernel_shape", attributes.kernel_shape); + + if (!attributes.dilations.empty()) { + test.AddAttribute("dilations", attributes.dilations); + } + + // Only one of pads / auto_pad can be present + if (!attributes.pads.empty()) { + test.AddAttribute("pads", attributes.pads); + } else { + test.AddAttribute("auto_pad", attributes.auto_pad); + } + + if (!attributes.strides.empty()) { + test.AddAttribute("strides", attributes.strides); + } + + ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs"); + const char* szNames[] = {"X", "W", "B"}; + test.AddInput(szNames[0], input_shapes[0], inputs[0]); + test.AddInput(szNames[1], input_shapes[1], inputs[1], weight_is_initializer); + if (inputs.size() == 3) + test.AddInput(szNames[2], input_shapes[2], inputs[2]); + + test.AddOutput("Y", expected_output_shape, expected_output); + + std::unordered_set excluded_providers(attributes.excluded_providers); + // Disable TensorRT because weight as input is not supported + excluded_providers.insert(kTensorrtExecutionProvider); + + test.Run(expect_result, err_str, excluded_providers); +} + +} // namespace + +TEST(NhwcConvTest, Conv2D_2) { + NhwcConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + 0.45246148109436035f, 0.15498268604278564f, 0.11199361085891724f, -0.39421093463897705f, + 0.2626858949661255f, 0.13414543867111206f, -0.27184486389160156f, -0.43028733134269714f, + -0.26825493574142456f, 0.3893144130706787f, -0.13631996512413025f, -0.009590476751327515f, + -0.48771554231643677f, -0.25256502628326416f, -0.2812897562980652f, 0.4043201804161072f, + 0.07795023918151855f, 0.326981782913208f, 0.13114392757415771f, -0.4416425824165344f, + 0.12446999549865723f, 0.36739975214004517f, 0.1698915958404541f, 0.2008744478225708f, + 0.23339951038360596f, 0.38613730669021606f, 0.11117297410964966f, 0.3877097964286804f, + 0.20812749862670898f, -0.34297940135002136f, 
-0.029246658086776733f, -0.20483523607254028f, + -0.19244328141212463f, -0.11104947328567505f, -0.32830488681793213f, -0.01800677180290222f, + 0.3618946671485901f, -0.40949052572250366f, -0.18248388171195984f, -0.3349453806877136f, + -0.34091079235076904f, 0.006497859954833984f, 0.4537564516067505f, 0.08006560802459717f, + -0.14788749814033508f, 0.034442365169525146f, -0.33322954177856445f, 0.06049239635467529f, + 0.42619407176971436f}; + vector X_shape = {1, 7, 7, 1}; + vector W = {-0.4406261742115021f}; + vector W_shape = {1, 1, 1, 1}; + vector Y_shape = {1, 7, 7, 1}; + auto expected_vals = { + -0.19936637580394745f, -0.06828942894935608f, -0.04934731498360634f, 0.17369966208934784f, + -0.11574628204107285f, -0.05910799279808998f, 0.1197819635272026f, 0.18959586322307587f, + 0.1182001456618309f, -0.17154212296009064f, 0.06006614491343498f, 0.0042258151806890965f, + 0.21490024030208588f, 0.11128675937652588f, 0.12394362688064575f, -0.17815405130386353f, + -0.034346915781497955f, -0.14407673478126526f, -0.05778544768691063f, 0.19459928572177887f, + -0.05484473705291748f, -0.16188594698905945f, -0.07485868036746979f, -0.08851054310798645f, + -0.10284193605184555f, -0.17014220356941223f, -0.04898572340607643f, -0.17083507776260376f, + -0.09170642495155334f, 0.1511256992816925f, 0.012886842712759972f, 0.09025576710700989f, + 0.08479554951190948f, 0.0489313043653965f, 0.14465972781181335f, 0.007934254594147205f, + -0.15946026146411896f, 0.1804322451353073f, 0.08040717244148254f, 0.1475857049226761f, + 0.15021422505378723f, -0.0028631272725760937f, -0.19993697106838226f, -0.03527900204062462f, + 0.06516310572624207f, -0.015176207758486271f, 0.14682966470718384f, -0.02665453404188156f, + -0.18779225647449493f}; + TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(NhwcConvTest, Conv2D_Bias_1) { + NhwcConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; + vector X_shape = {1, 3, 3, 1}; + vector W = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + vector W_shape = {2, 2, 2, 1}; + vector Y_shape = {1, 2, 2, 2}; + vector B = {1.0f, -1.0f}; + vector B_shape = {2}; + auto expected_vals = {13.0f, 11.0f, 17.0f, 15.0f, 25.0f, 23.0f, 29.0f, 27.0f}; + + TestNhwcConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestNhwcConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + +TEST(NhwcConvTest, Conv2D_AutoPad1) { + NhwcConvOpAndTestAttributes attrs = { + "SAME_UPPER", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + {}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = vector(25, 1.0f); + vector X_shape = {1, 5, 5, 1}; + vector W = {0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f}; + + vector W_shape = {1, 3, 3, 1}; + vector Y_shape = {1, 5, 5, 1}; + auto expected_vals = {24.0f, 33.0f, 33.0f, 33.0f, 20.0f, + 27.0f, 36.0f, 36.0f, 36.0f, 21.0f, + 27.0f, 36.0f, 36.0f, 36.0f, 21.0f, + 27.0f, 36.0f, 36.0f, 36.0f, 21.0f, + 12.0f, 15.0f, 15.0f, 15.0f, 8.0f}; + TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, 
expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(NhwcConvTest, Conv2D_AutoPad2) { + NhwcConvOpAndTestAttributes attrs = { + "SAME_LOWER", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + {}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = {1.0f, 0.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.0f, 1.0f}; + vector X_shape = {1, 5, 5, 1}; + vector W = {0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f}; + + vector W_shape = {1, 3, 3, 1}; + vector Y_shape = {1, 5, 5, 1}; + auto expected_vals = {11.0f, 22.0f, 11.0f, 22.0f, 11.0f, + 12.0f, 24.0f, 12.0f, 24.0f, 12.0f, + 12.0f, 24.0f, 12.0f, 24.0f, 12.0f, + 12.0f, 24.0f, 12.0f, 24.0f, 12.0f, + 5.0f, 10.0f, 5.0f, 10.0f, 5.0f}; + TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +} // namespace test +} // namespace onnxruntime From f4d41033af81f783e45c21398af95c31be2f49b3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 1 Feb 2023 20:00:11 +0000 Subject: [PATCH 23/27] fix training api build error --- .../cuda/diffusion/bias_split_gelu_impl.cu | 13 +++++++------ .../cuda/diffusion/bias_split_gelu_impl.h | 2 -- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu index 1ad23c691cce9..3cb95dad26b36 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -19,10 +19,7 @@ * limitations under the License. 
*/ -#include -#include #include -#include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h" @@ -38,9 +35,13 @@ __global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) { #pragma unroll for (int32_t i = 0; i < HHS / TPB; ++i) { - auto value_left = float(input[index_input] + bias[index_bias]); - auto value_right = float(input[index_input + HHS] + bias[index_bias + HHS]); - +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + auto value_left = (float)(input[index_input] + bias[index_bias]); + auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]); +#else + auto value_left = (float)(input[index_input]) + (float)(bias[index_bias]); + auto value_right = (float)(input[index_input + HHS]) + (float)(bias[index_bias + HHS]); +#endif // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0) float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f); float result = value_left * gelu_right; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h index aebbbc5a70956..a04201bd12e3c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.h @@ -4,9 +4,7 @@ #pragma once #include "core/common/common.h" #include "core/common/status.h" -#include #include -#include namespace onnxruntime { namespace contrib { From 55a746802eeeebac196ed4aa0b299a0fcf95489f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 1 Feb 2023 22:11:02 +0000 Subject: [PATCH 24/27] Add float16 test --- .../test/contrib_ops/nhwc_conv_op_test.cc | 116 ++++++++++-------- 1 file changed, 66 insertions(+), 50 deletions(-) diff --git a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc index 724232ec0560d..6cffaa4d57bf4 100644 --- a/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc +++ b/onnxruntime/test/contrib_ops/nhwc_conv_op_test.cc @@ -3,6 +3,9 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" + using namespace std; namespace onnxruntime { namespace test { @@ -24,42 +27,67 @@ void TestNhwcConvOp(const NhwcConvOpAndTestAttributes& attributes, const vector>& input_shapes, const std::initializer_list& expected_output, const vector& expected_output_shape, - bool weight_is_initializer = false, - OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, - const std::string& err_str = "") { - OpTester test("NhwcConv", 1, onnxruntime::kMSDomain); - test.AddAttribute("group", attributes.group); - test.AddAttribute("kernel_shape", attributes.kernel_shape); - - if (!attributes.dilations.empty()) { - test.AddAttribute("dilations", attributes.dilations); - } - - // Only one of pads / auto_pad can be present - if (!attributes.pads.empty()) { - test.AddAttribute("pads", attributes.pads); - } else { - test.AddAttribute("auto_pad", attributes.auto_pad); + bool use_float16, + bool weight_is_initializer = false) { + int min_cuda_architecture = use_float16 ? 
530 : 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + if (enable_cuda) { + OpTester test("NhwcConv", 1, onnxruntime::kMSDomain); + test.AddAttribute("group", attributes.group); + test.AddAttribute("kernel_shape", attributes.kernel_shape); + + if (!attributes.dilations.empty()) { + test.AddAttribute("dilations", attributes.dilations); + } + + // Only one of pads / auto_pad can be present + if (!attributes.pads.empty()) { + test.AddAttribute("pads", attributes.pads); + } else { + test.AddAttribute("auto_pad", attributes.auto_pad); + } + + if (!attributes.strides.empty()) { + test.AddAttribute("strides", attributes.strides); + } + + ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs"); + const char* szNames[] = {"X", "W", "B"}; + + if (use_float16) { + test.AddInput(szNames[0], input_shapes[0], ToFloat16(inputs[0])); + test.AddInput(szNames[1], input_shapes[1], ToFloat16(inputs[1]), weight_is_initializer); + if (inputs.size() == 3) { + test.AddInput(szNames[2], input_shapes[2], ToFloat16(inputs[2])); + } + test.AddOutput("Y", expected_output_shape, ToFloat16(expected_output)); + } else { + test.AddInput(szNames[0], input_shapes[0], inputs[0]); + test.AddInput(szNames[1], input_shapes[1], inputs[1], weight_is_initializer); + if (inputs.size() == 3) { + test.AddInput(szNames[2], input_shapes[2], inputs[2]); + } + test.AddOutput("Y", expected_output_shape, expected_output); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +} - if (!attributes.strides.empty()) { - test.AddAttribute("strides", attributes.strides); - } - - ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs"); - const char* szNames[] = {"X", "W", "B"}; - test.AddInput(szNames[0], input_shapes[0], inputs[0]); - test.AddInput(szNames[1], input_shapes[1], inputs[1], weight_is_initializer); - if (inputs.size() == 3) - test.AddInput(szNames[2], input_shapes[2], inputs[2]); - - test.AddOutput("Y", expected_output_shape, expected_output); - - std::unordered_set excluded_providers(attributes.excluded_providers); - // Disable TensorRT because weight as input is not supported - excluded_providers.insert(kTensorrtExecutionProvider); - - test.Run(expect_result, err_str, excluded_providers); +void RunNhwcConv(const NhwcConvOpAndTestAttributes& attributes, + const vector>& inputs, + const vector>& input_shapes, + const std::initializer_list& expected_output, + const vector& expected_output_shape) { + bool use_float16 = true; + bool weight_is_initializer = true; + TestNhwcConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, use_float16, weight_is_initializer); + + use_float16 = false; + weight_is_initializer = false; + TestNhwcConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, use_float16, weight_is_initializer); } } // namespace @@ -107,10 +135,7 @@ TEST(NhwcConvTest, Conv2D_2) { 0.15021422505378723f, -0.0028631272725760937f, -0.19993697106838226f, -0.03527900204062462f, 0.06516310572624207f, -0.015176207758486271f, 0.14682966470718384f, -0.02665453404188156f, -0.18779225647449493f}; - TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); - - // NNAPI/CoreML EP requires weight to be an initializer - TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); + RunNhwcConv(attrs, {X, W}, {X_shape, W_shape}, expected_vals, 
Y_shape); } TEST(NhwcConvTest, Conv2D_Bias_1) { @@ -133,10 +158,7 @@ TEST(NhwcConvTest, Conv2D_Bias_1) { vector B_shape = {2}; auto expected_vals = {13.0f, 11.0f, 17.0f, 15.0f, 25.0f, 23.0f, 29.0f, 27.0f}; - TestNhwcConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); - - // NNAPI/CoreML EP requires weight to be an initializer - TestNhwcConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); + RunNhwcConv(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); } TEST(NhwcConvTest, Conv2D_AutoPad1) { @@ -163,10 +185,7 @@ TEST(NhwcConvTest, Conv2D_AutoPad1) { 27.0f, 36.0f, 36.0f, 36.0f, 21.0f, 27.0f, 36.0f, 36.0f, 36.0f, 21.0f, 12.0f, 15.0f, 15.0f, 15.0f, 8.0f}; - TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); - - // NNAPI/CoreML EP requires weight to be an initializer - TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); + RunNhwcConv(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } TEST(NhwcConvTest, Conv2D_AutoPad2) { @@ -197,10 +216,7 @@ TEST(NhwcConvTest, Conv2D_AutoPad2) { 12.0f, 24.0f, 12.0f, 24.0f, 12.0f, 12.0f, 24.0f, 12.0f, 24.0f, 12.0f, 5.0f, 10.0f, 5.0f, 10.0f, 5.0f}; - TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); - - // NNAPI/CoreML EP requires weight to be an initializer - TestNhwcConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); + RunNhwcConv(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } } // namespace test From 3ff1fe6742c78a0dab5119cae1a4f0caaf6d154b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 2 Feb 2023 05:52:38 +0000 Subject: [PATCH 25/27] fix type warning --- .../core/providers/cuda/cudnn_common.cc | 4 +-- .../core/providers/cuda/cudnn_common.h | 4 +-- onnxruntime/core/providers/cuda/nn/conv.cc | 25 ++++++++++++++----- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index 6f53bfabd75ed..4c9cbbe605a7a 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -42,7 +42,7 @@ Status CudnnTensor::Set(gsl::span input_dims, cudnnDataType_t dat return Status::OK(); } -Status CudnnTensor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int64_t n, int64_t c, int64_t h, int64_t w) { +Status CudnnTensor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w) { ORT_RETURN_IF_ERROR(CreateTensorIfNeeded()); CUDNN_RETURN_IF_ERROR(cudnnSetTensor4dDescriptor(tensor_, format, dataType, n, c, h, w)); return Status::OK(); @@ -119,7 +119,7 @@ Status CudnnFilterDescriptor::Set(gsl::span filter_dims, cudnnDat return Status::OK(); } -Status CudnnFilterDescriptor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int64_t k, int64_t c, int64_t h, int64_t w) { +Status CudnnFilterDescriptor::Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int k, int c, int h, int w) { if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateFilterDescriptor(&desc_)); diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index ebe352bcf1b7e..ba75ab4f2c029 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -19,7 +19,7 @@ class CudnnTensor final { Status Set(gsl::span input_dims, cudnnDataType_t dataType); Status Set(const CudnnTensor& x_desc, cudnnBatchNormMode_t mode); 
// Set 4D tensor format (for NHWC) - Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int64_t n, int64_t c, int64_t h, int64_t w); + Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int n, int c, int h, int w); operator cudnnTensorDescriptor_t() const { return tensor_; } @@ -61,7 +61,7 @@ class CudnnFilterDescriptor final { Status Set(gsl::span filter_dims, cudnnDataType_t data_typ); // Set 4D filter where k is output channels, c is input channels, h and w is rows and columns per filter. - Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int64_t k, int64_t c, int64_t h, int64_t w); + Status Set(cudnnTensorFormat_t format, cudnnDataType_t dataType, int k, int c, int h, int w); operator cudnnFilterDescriptor_t() const { return desc_; } diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 2ccb0930caa38..ab67783e5070c 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -237,7 +237,12 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (!channels_last) { ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), w_dims[0], w_dims[3], w_dims[1], w_dims[2])); + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(w_dims[0]), + static_cast(w_dims[3]), + static_cast(w_dims[1]), + static_cast(w_dims[2]))); } } @@ -247,11 +252,19 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) } if (channels_last) { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - x_dims_cudnn[0], x_dims_cudnn[3], x_dims_cudnn[1], x_dims_cudnn[2])); - - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - y_dims_cudnn[0], y_dims_cudnn[3], y_dims_cudnn[1], y_dims_cudnn[2])); + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(x_dims_cudnn[0]), + static_cast(x_dims_cudnn[3]), + static_cast(x_dims_cudnn[1]), + static_cast(x_dims_cudnn[2]))); + + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, + CudnnTensor::GetDataType(), + static_cast(y_dims_cudnn[0]), + static_cast(y_dims_cudnn[3]), + static_cast(y_dims_cudnn[1]), + static_cast(y_dims_cudnn[2]))); } else { ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); From b3a4c014a68f7c596d1196b38dd4309b054c70cc Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 2 Feb 2023 09:55:26 +0000 Subject: [PATCH 26/27] update op doc; exclude from hipify --- cmake/onnxruntime_rocm_hipify.cmake | 5 +++++ docs/OperatorKernels.md | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 376779cc01179..85246ec8bd37c 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -31,6 +31,11 @@ set(contrib_ops_excluded_files "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" + "diffusion/bias_split_gelu_impl.h" + "diffusion/bias_split_gelu_impl.cu" + "diffusion/bias_split_gelu.h" + "diffusion/bias_split_gelu.cc" + "diffusion/nhwc_conv.cc" "math/complex_mul.cc" "math/complex_mul.h" "math/complex_mul_impl.cu" diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md 
index 765d16c0f23cf..7e4eb38be780b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -812,6 +812,7 @@ Do not modify directly.*
 |LongformerAttention|*in* input:**T**<br/> *in* weight:**T**<br/> *in* bias:**T**<br/> *in* mask:**T**<br/> *in* global_weight:**T**<br/> *in* global_bias:**T**<br/> *in* global:**G**<br/> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |MultiHeadAttention|*in* query:**T**<br/> *in* key:**T**<br/> *in* value:**T**<br/> *in* bias:**T**<br/> *in* key_padding_mask:**M**<br/> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |NGramRepeatBlock|*in* input_ids:**Tid**<br/> *in* scores:**T**<br/> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
+|NhwcConv|*in* X:**T**<br/> *in* W:**T**<br/> *in* B:**T**<br/> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |QAttention|*in* input:**T1**<br/> *in* weight:**T2**<br/> *in* bias:**T3**<br/> *in* input_scale:**T3**<br/> *in* weight_scale:**T3**<br/> *in* mask_index:**T4**<br/> *in* input_zero_point:**T1**<br/> *in* weight_zero_point:**T2**<br/> *in* past:**T3**<br/> *out* output:**T3**<br/> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
 |QOrderedAttention|*in* input:**Q**<br/> *in* scale_input:**S**<br/> *in* scale_Q_gemm:**S**<br/> *in* scale_K_gemm:**S**<br/> *in* scale_V_gemm:**S**<br/> *in* Q_weight:**Q**<br/> *in* K_weight:**Q**<br/> *in* V_weight:**Q**<br/> *in* scale_Q_weight:**S**<br/> *in* scale_K_weight:**S**<br/> *in* scale_V_weight:**S**<br/> *in* Q_bias:**S**<br/> *in* K_bias:**S**<br/> *in* V_bias:**S**<br/> *in* scale_QKT_gemm:**S**<br/> *in* scale_QKT_softmax:**S**<br/> *in* scale_values_gemm:**S**<br/> *in* mask_index:**G**<br/> *in* past:**Q**<br/> *in* extra_add:**S**<br/> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
 |QOrderedGelu|*in* X:**Q**<br/> *in* scale_X:**S**<br/> *in* scale_Y:**S**<br/> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
@@ -1089,7 +1090,8 @@ Do not modify directly.*
 |Scatter|*in* data:**T**<br/> *in* indices:**Tind**<br/> *in* updates:**T**<br/> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
 |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
 |||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
-|ScatterElements|*in* data:**T**<br/> *in* indices:**Tind**<br/> *in* updates:**T**<br/> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
+|ScatterElements|*in* data:**T**<br/> *in* indices:**Tind**<br/> *in* updates:**T**<br/> *out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
 |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
 |ScatterND|*in* data:**T**<br/> *in* indices:**tensor(int64)**<br/> *in* updates:**T**<br/> *out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

From 1fe78af9149145f081518fa806f4b985d90052a0 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Thu, 2 Feb 2023 23:36:47 +0000
Subject: [PATCH 27/27] add input checks; clean debug code

---
 .../cpu/transformers/generation_shared.h      |  5 +-
 .../contrib_ops/cuda/bert/attention_impl.cu   | 69 ++++++-----
 .../cuda/diffusion/bias_split_gelu.cc         |  5 ++
 onnxruntime/core/providers/cuda/nn/conv.cc    | 10 ++-
 .../stable_diffusion/optimize_pipeline.py     | 42 ++++-------
 5 files changed, 56 insertions(+), 75 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 6b092d3e99f4e..630c533c47323 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -163,14 +163,11 @@ struct IGenerationParameters {
   bool custom_sampling = false;
 };

-#ifndef NDEBUG
 // #define DEBUG_GENERATION 1  // uncomment it for debugging generation (like beam search etc)
-#endif
-
 #ifdef DEBUG_GENERATION
 #define DUMP_TENSOR_LEVEL 2
 #else
-#define DUMP_TENSOR_LEVEL 0  // change it to 0 if want to disable dumping for code not in generation.
+#define DUMP_TENSOR_LEVEL 0  // change it to 1 or 2 if you want to enable dumping for code not in generation.
 #endif

 #if DUMP_TENSOR_LEVEL > 0
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 7731cb011c8d6..8c7ef9f919519 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -48,21 +48,6 @@ using namespace cub;
 #define CHECK_CUDA(expr) CUDA_RETURN_IF_ERROR(expr)
 #define CUDA_MEMORY_ALIGNMENT 256

-#define DUMP_ATTENTION_LEVEL 0
-#if DUMP_ATTENTION_LEVEL > 1
-#define DUMP_ATTENTION_INIT() transformers::CudaTensorConsoleDumper dumper
-#define DUMP_ATTENTION(...) dumper.Print(__VA_ARGS__)
-#define DUMP_ATTENTION_D(...) dumper.Print(__VA_ARGS__)
-#elif DUMP_ATTENTION_LEVEL > 0
-#define DUMP_ATTENTION_INIT() transformers::CudaTensorConsoleDumper dumper
-#define DUMP_ATTENTION(...) dumper.Print(__VA_ARGS__)
-#define DUMP_ATTENTION_D(...)
-#else
-#define DUMP_ATTENTION_INIT()
-#define DUMP_ATTENTION(...)
-#define DUMP_ATTENTION_D(...)
-#endif
-
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
@@ -283,7 +268,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,

   // Default format for memory efficient attention.
   // When there is past state, the format shall be BxNxSxH, so we disable memory efficient attention when there is past.
-  DUMP_ATTENTION_INIT();
+  DUMP_TENSOR_INIT();
   if (nullptr != data.gemm_buffer) {
     if (data.bias == nullptr) {
       // For quantized attention, bias has been added so only need transpose here.
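The NhwcConv rows added to the operator table above, together with the kernel registrations and unit tests from the earlier commits in this series, are enough to exercise the new contrib op directly from Python. The following is a minimal sketch rather than part of the patch: it assumes a CUDA build of onnxruntime with these commits applied, uses the com.microsoft domain that kMSDomain maps to, and borrows the SAME_UPPER shapes from nhwc_conv_op_test.cc; the opset versions and tensor values are illustrative only.

# Sketch only: build a one-node com.microsoft NhwcConv model and run it on CUDA.
# X is NHWC and W is channels-last ([M, kH, kW, C/group]), matching the C++ tests.
import numpy as np
from onnx import TensorProto, helper

import onnxruntime

x_shape = [1, 5, 5, 1]  # NHWC input
w_shape = [1, 3, 3, 1]  # [M, kH, kW, C]
y_shape = [1, 5, 5, 1]  # NHWC output

node = helper.make_node(
    "NhwcConv",
    inputs=["X", "W"],
    outputs=["Y"],
    domain="com.microsoft",
    auto_pad="SAME_UPPER",
    group=1,
    kernel_shape=[3, 3],
    strides=[1, 1],
)

graph = helper.make_graph(
    [node],
    "nhwc_conv_smoke_test",
    [
        helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, w_shape),
    ],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, y_shape)],
)

model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)

session = onnxruntime.InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider"])
x = np.ones(x_shape, dtype=np.float32)
w = np.arange(9, dtype=np.float32).reshape(w_shape)
y = session.run(["Y"], {"X": x, "W": w})[0]
print(y.reshape(5, 5))

With an all-ones input and weights 0..8, the printed values should match the expected_vals of the Conv2D_AutoPad1 test above (24, 33, ..., 8).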
@@ -323,7 +308,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, assert(data.bias == nullptr); assert(qk_head_size == v_head_size); - DUMP_ATTENTION_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); if (use_memory_efficient_attention) { // unpack kv to BSNH. Note that there is no bias so we need not output query to q. @@ -334,8 +319,8 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, batch_size, kv_sequence_length, num_heads, qk_head_size, data.key, kv_bias, k, true, v_head_size, qkv_add_bias, 2); - DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } else { if (data.fused_cross_attention_kernel == nullptr) { @@ -347,12 +332,12 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, } else { // gemm_buffer == nullptr and not packed kv assert(data.query != nullptr && data.key != nullptr && data.value != nullptr && data.bias != nullptr); - DUMP_ATTENTION_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_ATTENTION_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_ATTENTION_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); - DUMP_ATTENTION_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); if (data.fused_cross_attention_kernel != nullptr) { assert(qk_head_size == v_head_size); @@ -374,9 +359,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, num_heads, qk_head_size, v_head_size, data.bias, data.query, data.key, data.value, q, k, v); - DUMP_ATTENTION_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } #endif @@ -389,7 +374,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, batch_size, sequence_length, num_heads, qk_head_size, data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - 
DUMP_ATTENTION_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); + DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); qkv_format = AttentionQkvFormat::QKV_BSN3H; } else { // unfused kernel @@ -414,9 +399,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.value, data.bias + 2 * num_heads * qk_head_size, v, true, -1); - DUMP_ATTENTION_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_ATTENTION_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size); - DUMP_ATTENTION_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size); + DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size); qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } } @@ -517,7 +502,7 @@ Status QkvToContext( } // Q, K and V are ready now - DUMP_ATTENTION_INIT(); + DUMP_TENSOR_INIT(); if (data.fused_cross_attention_kernel != nullptr) { assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); @@ -525,7 +510,7 @@ Status QkvToContext( LaunchTrtSequenceOffset(q_sequence_offset, nullptr, batch_size, sequence_length, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_ATTENTION_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. @@ -535,7 +520,7 @@ Status QkvToContext( LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, kv_sequence_length, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_ATTENTION_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); + DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); @@ -562,7 +547,7 @@ Status QkvToContext( kv_sequence_length, // sequence length of KV stream); - DUMP_ATTENTION("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); return Status::OK(); } @@ -588,11 +573,11 @@ Status QkvToContext( if (use_fused_kernel) { assert(qkv_format == AttentionQkvFormat::QKV_BSN3H); fused_fp16_runner->run(qkv, sequence_offset, data.output, stream); - DUMP_ATTENTION("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); } else { assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_ATTENTION("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size); } return Status::OK(); } @@ -631,7 +616,7 @@ Status QkvToContext( p.stream = stream; run_memory_efficient_attention(p); - DUMP_ATTENTION("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("cutlass output", 
     return Status::OK();
   }
 #endif
@@ -663,7 +648,7 @@ Status QkvToContext(
                            q, qk_head_size, sequence_length * qk_head_size,
                            &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
 
-  DUMP_ATTENTION_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length);
+  DUMP_TENSOR_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length);
 
   const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
                                                sequence_length, total_sequence_length);
@@ -697,7 +682,7 @@ Status QkvToContext(
                                  scratch1, scratch2, parameters.is_unidirectional));
   }
 
-  DUMP_ATTENTION_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length);
+  DUMP_TENSOR_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length);
 
   // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
   T* temp_output = qkv;
@@ -711,7 +696,7 @@ Status QkvToContext(
   // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
   Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
                                  max_threads_per_block, false, temp_output, data.output);
-  DUMP_ATTENTION("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+  DUMP_TENSOR("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
   return result;
 }
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
index 265cbb79e1801..2b13cdbd803ef 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
@@ -39,6 +39,11 @@ Status BiasSplitGelu<T>::ComputeInternal(OpKernelContext* context) const {
                            "input is expected to have 3 dimensions, got ", input_dims.size());
   }
 
+  if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]);
+  }
+
   const Tensor* bias = context->Input<Tensor>(1);
   const auto& bias_dims = bias->Shape().GetDims();
   if (bias_dims.size() != 1) {
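Note: the hidden sizes admitted by the new BiasSplitGelu check are the doubled feed-forward intermediate sizes of the Stable Diffusion UNet; the kernel splits the last dimension in half, so 2560, 5120 and 10240 correspond to GEGLU widths of 1280, 2560 and 5120. A minimal Python sketch of the same contract, useful for pre-validating inputs; the helper name is illustrative and not part of onnxruntime:

# Sketch: mirrors the BiasSplitGelu input validation added above.
SUPPORTED_HIDDEN_SIZES = (2560, 5120, 10240)

def check_bias_split_gelu_input(shape) -> None:
    # shape is (batch, height * width, hidden)
    if len(shape) != 3:
        raise ValueError(f"input is expected to have 3 dimensions, got {len(shape)}")
    if shape[2] not in SUPPORTED_HIDDEN_SIZES:
        raise ValueError(f"hidden size should be 2560, 5120 or 10240, got {shape[2]}")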
diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc
index ab67783e5070c..b0df77db96744 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv.cc
@@ -52,7 +52,7 @@ cudnnStatus_t GetWorkspaceSize(cudnnHandle_t handle, const CudnnConvState<cudnnConvolutionFwdAlgoPerf_t>& s,
                                const cudnnConvolutionFwdAlgo_t* algo, int n_algo) {
-  // TODO: get maximum available size from memory areana
+  // TODO: get maximum available size from memory arena
   size_t free, total;
   CUDA_CALL_THROW(cudaMemGetInfo(&free, &total));
   // Assuming 10% of fragmentation
@@ -99,6 +99,13 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
   const TensorShape& w_shape = W->Shape();
   auto w_dims = w_shape.AsShapeVector();
   s_.w_data = reinterpret_cast<const CudaT*>(W->Data());
+
+  // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC.
+  constexpr bool channels_last = NHWC;
+  if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Number of dimensions of X and W should be 4 for channels_last format (NHWC)");
+  }
+
   // set B
   if (context->InputCount() >= 3) {
     const Tensor* B = context->Input<Tensor>(2);
@@ -125,7 +132,6 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
       s_.cached_benchmark_results.clear();
     }
 
-    constexpr bool channels_last = NHWC;
     ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last));
 
     TensorShapeVector kernel_shape;
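Note: because the NHWC path binds a 4D cuDNN descriptor, the kernel now rejects non-4D X and W up front. One way to anticipate this from the model side is to confirm that every Conv node in the exported ONNX graph consumes a rank-4 tensor before enabling channels-last. The sketch below uses the onnx package; the helper itself is hypothetical:

import onnx
from onnx import shape_inference

def conv_inputs_are_4d(model_path: str) -> bool:
    # Returns True when every Conv node's data input has rank 4.
    model = shape_inference.infer_shapes(onnx.load(model_path))
    ranks = {}
    for value_info in list(model.graph.value_info) + list(model.graph.input):
        dims = value_info.type.tensor_type.shape.dim
        if len(dims) > 0:
            ranks[value_info.name] = len(dims)
    # Treat inputs with unknown rank as acceptable; only flag known non-4D ones.
    return all(ranks.get(node.input[0], 4) == 4
               for node in model.graph.node if node.op_type == "Conv")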
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
index 26b41c575b165..0979f0d2ddcb5 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -34,12 +34,11 @@
 import coloredlogs
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
+from fusion_options import FusionOptions
 from optimizer import optimize_model  # noqa: E402
 
 logger = logging.getLogger(__name__)
 
-DEBUG = True
-
 
 def optimize_stable_diffusion_onnx_pipeline(
     source_dir: Path, target_dir: Path, overwrite: bool, use_external_data_format: bool, float16: bool
@@ -57,7 +56,7 @@ def optimize_stable_diffusion_onnx_pipeline(
         RuntimeError: input onnx model does not exist
         RuntimeError: output onnx model path existed
     """
-    dirs_with_onnx = ["unet"] if DEBUG else ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"]
+    dirs_with_onnx = ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"]
     for name in dirs_with_onnx:
         onnx_model_path = source_dir / name / "model.onnx"
 
@@ -73,24 +72,25 @@ def optimize_stable_diffusion_onnx_pipeline(
         # Graph fusion before fp16 conversion, otherwise they cannot be fused later.
         # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead.
         logger.info(f"optimize {onnx_model_path}...")
+
+        fusion_options = FusionOptions("unet")
+        # Packed KV requires compute capability >= 7.5 (like T4, A100, RTX 2060~4090; see https://developer.nvidia.com/cuda-gpus).
+        # We suggest disabling it on older GPUs (like V100 or GTX 1060/1070/1080), or when using a float32 model.
+        fusion_options.enable_packed_kv = float16
+
         m = optimize_model(
             str(onnx_model_path),
             model_type="unet",
             num_heads=num_heads,
             hidden_size=hidden_size,
             opt_level=0,
-            optimization_options=None,
+            optimization_options=fusion_options,
             use_gpu=False,
         )
 
         if float16:
-            # VAE-decoder in fp16 reduced quality thus we exclude it here
-            # TODO: enable mixed precision conversion for VAE-decoder.
-            if name != "vae_decoder":
-                logger.info(f"convert to float16 ...")
-                m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize", "GroupNorm"])
-            else:
-                logger.info("skip convert vae_decoder to fp16.")
+            logger.info("convert %s to float16 ...", name)
+            m.convert_float_to_float16(op_block_list=["RandomNormalLike", "Resize", "GroupNorm"])
 
         optimized_model_path = target_dir / name / "model.onnx"
         output_dir = optimized_model_path.parent
@@ -103,7 +103,7 @@ def optimize_stable_diffusion_onnx_pipeline(
             output_dir.mkdir(parents=True, exist_ok=True)
 
         m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
-        logger.info(f"{onnx_model_path} => {optimized_model_path}")
+        logger.info("%s => %s", onnx_model_path, optimized_model_path)
 
 
 def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool):
@@ -118,19 +118,7 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool):
         RuntimeError: source path does not exist
         RuntimeError: output path exists but overwrite is false.
     """
-    extra_dirs = (
-        [
-            "vae_encoder",
-            "vae_decoder",
-            "text_encoder",
-            "safety_checker",
-            "scheduler",
-            "tokenizer",
-            "feature_extractor",
-        ]
-        if DEBUG
-        else ["scheduler", "tokenizer", "feature_extractor"]
-    )
+    extra_dirs = ["scheduler", "tokenizer", "feature_extractor"]
 
     for name in extra_dirs:
         source_path = source_dir / name
@@ -148,7 +136,7 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool):
             shutil.rmtree(target_path)
 
         shutil.copytree(source_path, target_path)
-        logger.info(f"{source_path} => {target_path}")
+        logger.info("%s => %s", source_path, target_path)
 
     extra_files = ["model_index.json"]
     for name in extra_files:
@@ -162,7 +150,7 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool):
             raise RuntimeError(f"output path existed: {target_path}")
         os.remove(target_path)
         shutil.copyfile(source_path, target_path)
-        logger.info(f"{source_path} => {target_path}")
+        logger.info("%s => %s", source_path, target_path)
 
 
 def parse_arguments():
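Note: with the DEBUG switch removed, every pipeline component is optimized, and packed KV fusion follows the float16 flag. A hypothetical invocation of the function defined above (directory names are placeholders):

from pathlib import Path

optimize_stable_diffusion_onnx_pipeline(
    source_dir=Path("./sd-onnx"),        # exported pipeline (placeholder path)
    target_dir=Path("./sd-onnx-fp16"),   # output location (placeholder path)
    overwrite=True,
    use_external_data_format=False,
    float16=True,  # also turns on packed KV fusion; needs compute capability >= 7.5
)

Passing float16=False keeps the models in float32 and leaves packed KV fusion disabled, which is the safer choice on pre-Turing GPUs.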