-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Stable Diffusion CUDA optimizations Part 2 (#14597)
### Description This is a follow-up of #14428 for Stable Diffusion CUDA optimizations: (1) use NhwcConv to replace Conv in the onnx graph and add Transpose nodes accordingly; (2) reduce sequential Transpose nodes to at most one; (3) symbolic shape inference for NhwcConv; (4) fix the add-bias transpose which causes a CUDA error (launching more than 1024 threads per block) when running inference on an fp32 model; (5) add models (bert, bart, stable_diffusion subdirectories) to the package; (6) remove the option --disable_channels_last. Note that (1) We can add a few graph transformations to reduce Transpose nodes further. It is not done in this PR due to time limits. (2) The Stable Diffusion 2.1 model outputs black images. It seems that forcing Attention to float32 could avoid the issue. However, it is much slower to use float32 Attention. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
- Loading branch information
Showing
13 changed files
with
386 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
|
||
from logging import getLogger | ||
from typing import List | ||
|
||
from fusion_base import Fusion | ||
from onnx import TensorProto, helper, numpy_helper | ||
from onnx_model import OnnxModel | ||
|
||
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__) | ||
|
||
|
||
class FusionNhwcConv(Fusion):
    """Replace ``Conv`` nodes with ``NhwcConv`` nodes in the ``com.microsoft`` domain.

    Each fused ``Conv`` is surrounded by ``Transpose`` nodes: one permuting the
    data input with ``[0, 2, 3, 1]`` before the new node and one permuting the
    result with ``[0, 3, 1, 2]`` afterwards, so the rest of the graph keeps
    seeing the original output name and layout.
    """

    def __init__(self, model: OnnxModel, update_weight=False):
        """Set up the fusion.

        Args:
            model: Model wrapper whose ``Conv`` nodes are searched.
            update_weight: When true, the permuted weight is stored as a new
                initializer. When false, a ``Transpose`` node is inserted on
                the existing weight instead.
        """
        super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
        self.update_weight = update_weight

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Build (but do not register) a ``Transpose`` node reading ``input_name``.

        Args:
            input_name: Name of the tensor to transpose.
            perm: Axis permutation stored in the node's ``perm`` attribute.
            output_name: Name for the node's output. When ``None`` a name is
                derived from the generated node name and ``input_name``.

        Returns:
            The newly created ``Transpose`` node.
        """
        node_name = self.model.create_node_name("Transpose")

        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name

        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])

        return transpose_node

    def fuse(self, conv, input_name_to_nodes, output_name_to_node):
        """Rewrite one ``Conv`` node as Transpose -> NhwcConv -> Transpose.

        The fusion is skipped (nothing is changed) when the weight is not an
        initializer of the model or is not 4-dimensional.

        Args:
            conv: The ``Conv`` node to replace.
            input_name_to_nodes: Unused here; part of the fusion callback signature.
            output_name_to_node: Unused here; part of the fusion callback signature.
        """
        # Validate first so that no node or name is created for a Conv we cannot fuse.
        weight_tensor = self.model.get_initializer(conv.input[1])
        if weight_tensor is None:
            logger.debug("Skip NhwcConv fusion for %s: weight is not an initializer", conv.name)
            return

        weight = numpy_helper.to_array(weight_tensor)
        if weight.ndim != 4:
            logger.debug("Skip NhwcConv fusion for %s: weight is not 4D", conv.name)
            return

        node_name = self.model.create_node_name("NhwcConv")

        # Permute the data input with [0, 2, 3, 1] (NCHW -> NHWC).
        input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])
        nhwc_conv_input = input_transpose_node.output[0]

        # Producers are listed before the node that consumes them.
        nodes_to_add = [input_transpose_node]

        if self.update_weight:
            # Store the permuted weight as a new initializer. from_array keeps the
            # original element type (e.g. float16) instead of forcing float32.
            weight_name = node_name + "_weight_NHWC"
            nhwc_weight = numpy_helper.from_array(weight.transpose(0, 2, 3, 1), weight_name)
            self.model.add_initializer(nhwc_weight, self.this_graph_name)
        else:
            # Leave the initializer untouched and permute it in the graph instead.
            weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
            weight_name = weight_transpose_node.output[0]
            nodes_to_add.append(weight_transpose_node)

        nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
        nhwc_conv = helper.make_node(
            "NhwcConv",
            # Any further inputs of the original Conv (such as bias) are forwarded unchanged.
            inputs=[nhwc_conv_input, weight_name] + conv.input[2:],
            outputs=[nhwc_output_name],
            name=node_name + "-" + conv.name,
        )
        nhwc_conv.attribute.extend(conv.attribute)
        nhwc_conv.domain = "com.microsoft"

        # Permute the result back with [0, 3, 1, 2] and reuse the original output
        # name so downstream consumers need no rewiring.
        output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0])

        nodes_to_add.extend([nhwc_conv, output_transpose_node])

        self.nodes_to_remove.append(conv)

        for node in nodes_to_add:
            self.node_name_to_graph_name[node.name] = self.this_graph_name
        self.nodes_to_add.extend(nodes_to_add)

        self.increase_counter("NhwcConv")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.