Stable Diffusion CUDA optimizations Part 2 (#14597)
### Description
This is a follow-up of #14428 for Stable Diffusion CUDA optimizations:
(1) Use NhwcConv to replace Conv in the ONNX graph, adding Transpose nodes accordingly (a sketch of the resulting pattern appears after the notes below).
(2) Reduce sequential Transpose nodes to at most one.
(3) Add symbolic shape inference for NhwcConv.
(4) Fix add-bias-transpose kernel launches that requested more than 1024 threads per block and caused a CUDA error when inferencing the fp32 model.
(5) Add model subdirectories (bart, bert, stable_diffusion) to the Python package.
(6) Remove the --disable_channels_last option.

Note that:
(1) A few more graph transformations could reduce Transpose nodes further; they are not included in this PR due to time constraints.
(2) The Stable Diffusion 2.1 model outputs black images. Forcing Attention to float32 appears to avoid the issue, but float32 Attention is much slower.
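A minimal sketch of the pattern produced by (1), built with the onnx helper directly. Tensor names (`x`, `w_nhwc`, `y`) are hypothetical; the perms match those used by `FusionNhwcConv` in this PR:

```python
# Sketch only: each Conv becomes Transpose (NCHW->NHWC) -> NhwcConv -> Transpose (NHWC->NCHW).
from onnx import helper

to_nhwc = helper.make_node("Transpose", inputs=["x"], outputs=["x_nhwc"], perm=[0, 2, 3, 1])
nhwc_conv = helper.make_node(
    "NhwcConv",                    # contrib op, registered in the com.microsoft domain
    inputs=["x_nhwc", "w_nhwc"],   # weight assumed already transposed to [M, kH, kW, C]
    outputs=["y_nhwc"],
    domain="com.microsoft",
)
to_nchw = helper.make_node("Transpose", inputs=["y_nhwc"], outputs=["y"], perm=[0, 3, 1, 2])
```

Step (2) then cancels back-to-back Transpose pairs between adjacent convolutions, so most of these inserted nodes disappear from the final graph.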

tianleiwu authored and rui-ren committed Feb 7, 2023
1 parent fc790c1 commit b0a2990
Showing 13 changed files with 386 additions and 61 deletions.
21 changes: 21 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -467,12 +467,21 @@ file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DE
file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py"
)
+ file(GLOB onnxruntime_python_transformers_models_bart_src CONFIGURE_DEPENDS
+ "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bart/*.py"
+ )
+ file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS
+ "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bert/*.py"
+ )
file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py"
)
file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py"
)
+ file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS
+ "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py"
+ )
file(GLOB onnxruntime_python_transformers_models_t5_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/t5/*.py"
)
@@ -526,8 +535,11 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/tools/ort_format_model/ort_flatbuffers_py
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models
+ COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart
+ COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
+ COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/operators
@@ -606,12 +618,21 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/
+ COMMAND ${CMAKE_COMMAND} -E copy
+ ${onnxruntime_python_transformers_models_bart_src}
+ $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart/
+ COMMAND ${CMAKE_COMMAND} -E copy
+ ${onnxruntime_python_transformers_models_bert_src}
+ $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_gpt2_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_longformer_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer/
+ COMMAND ${CMAKE_COMMAND} -E copy
+ ${onnxruntime_python_transformers_models_stable_diffusion_src}
+ $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_t5_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5/
24 changes: 13 additions & 11 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -519,6 +519,7 @@ void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count) {
+ assert(num_heads <= max_threads_per_block);
const dim3 grid(sequence_length, batch_size, num_matrices);
if (qk_head_size * num_heads <= max_threads_per_block) {
const dim3 block(qk_head_size, num_heads, 1);
@@ -544,7 +545,7 @@ void InvokeAddBiasTranspose(
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
}
} else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
if (format == 2) {
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
} else if (format == 1) {
@@ -577,7 +578,7 @@ void LaunchAddBiasTranspose(
const half* input, const half* biases, half* output,
bool enable_half4, const int v_head_size, half* qkv_add_bias, int total_matrix_count) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
- if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
+ if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) {
const int H = qk_head_size / 4;
const int H_v = v_head_size / 4;
const Half4* input2 = reinterpret_cast<const Half4*>(input);
@@ -587,7 +588,7 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2,
qkv_add_bias2, H_v, total_matrix_count);
- } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
+ } else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) {
const int H = qk_head_size / 2;
const int H_v = v_head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
@@ -612,7 +613,7 @@ void LaunchAddBiasTranspose(
const float* input, const float* biases, float* output,
bool /*enable_half4*/, const int v_head_size, float* qkv_add_bias, int total_matrix_count) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
- if (0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
+ if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4))) {
const int H = qk_head_size / 4;
const float4* input2 = reinterpret_cast<const float4*>(input);
const float4* biases2 = reinterpret_cast<const float4*>(biases);
@@ -622,7 +623,7 @@
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2,
qkv_add_bias2, v_head_size / 4, total_matrix_count);
- } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
+ } else if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1))) {
const int H = qk_head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
const float2* biases2 = reinterpret_cast<const float2*>(biases);
@@ -654,7 +655,7 @@ void InvokeAddBiasTransposeTrt(
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, key, value, biases, output);
} else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, key, value, biases, output);
}
} else { // cross attention
@@ -666,7 +667,7 @@
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, biases, output);
} else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, biases, output);
}
}
@@ -680,7 +681,7 @@
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrtKV<T><<<grid, block, 0, stream>>>(key, value, biases, packed_kv);
} else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtKVLarge<T><<<grid, block, 0, stream>>>(head_size, key, value, biases, packed_kv);
}
}
@@ -737,6 +738,7 @@ void InvokeAddBias(
const int batch_size, const int sequence_length, const int kv_sequence_length,
const int num_heads, const int head_size, const int v_head_size,
const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v) {
+ assert(num_heads <= max_threads_per_block);
constexpr int num_matrices = 1;
// Q
{
@@ -745,7 +747,7 @@
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(query, biases, q);
} else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, query, biases, q);
}
}
@@ -758,7 +760,7 @@
const dim3 block(head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(key, biases_k, k);
} else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, key, biases_k, k);
}
}
@@ -772,7 +774,7 @@
const dim3 block(v_head_size, num_heads, 1);
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(value, biases_v, v);
} else {
- const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
+ const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(v_head_size, value, biases_v, v);
}
}
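Regarding fix (4): the old `CeilDiv` rounding could produce a block whose total thread count exceeds the CUDA limit of 1024 whenever `num_heads` does not divide `max_threads_per_block`; floor division cannot. A standalone sketch of the arithmetic (the head count here is hypothetical):

```python
# Sketch: block.x * block.y must stay <= max_threads_per_block (1024 on most GPUs).
def ceil_div(a: int, b: int) -> int:
    return -(a // -b)

max_threads_per_block = 1024
num_heads = 13  # hypothetical value that does not divide 1024

old_block_x = ceil_div(max_threads_per_block, num_heads)  # 79
new_block_x = max_threads_per_block // num_heads          # 78

print(old_block_x * num_heads)  # 1027 -> exceeds 1024, kernel launch fails
print(new_block_x * num_heads)  # 1014 -> valid; the new assert guards num_heads <= 1024
```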
33 changes: 26 additions & 7 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -202,6 +202,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
"GroupNorm": self._infer_GroupNorm,
"BiasSplitGelu": self._infer_BiasSplitGelu,
"NhwcConv": self._infer_NhwcConv,
}
self.aten_op_dispatcher_ = {
"embedding": self._infer_Gather,
@@ -438,6 +439,7 @@ def _onnx_infer_single_node(self, node):
"MultiHeadAttention",
"GroupNorm",
"BiasSplitGelu",
"NhwcConv",
]

if not skip_infer:
@@ -619,13 +621,13 @@ def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
def _new_symbolic_shape(self, rank, node, out_idx=0):
return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]

- def _compute_conv_pool_shape(self, node):
+ def _compute_conv_pool_shape(self, node, channels_last=False):
sympy_shape = self._get_sympy_shape(node, 0)
if len(node.input) > 1:
W_shape = self._get_sympy_shape(node, 1)
rank = len(W_shape) - 2 # number of spatial axes
- kernel_shape = W_shape[-rank:]
- sympy_shape[1] = W_shape[0]
+ kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
+ sympy_shape[3 if channels_last else 1] = W_shape[0]
else:
W_shape = None
kernel_shape = get_attribute(node, "kernel_shape")
@@ -634,13 +636,17 @@ def _compute_conv_pool_shape(self, node):
assert len(sympy_shape) == rank + 2

# only need to symbolic shape inference if input has symbolic dims in spatial axes
- is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]]
+ spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
+ is_symbolic_dims = [not is_literal(i) for i in spatial_shape]

if not any(is_symbolic_dims):
shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
if len(shape) > 0:
assert len(sympy_shape) == len(shape)
- sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
+ if channels_last:
+ sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
+ else:
+ sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
return sympy_shape

dilations = get_attribute(node, "dilations", [1] * rank)
@@ -671,7 +677,7 @@ def _compute_conv_pool_shape(self, node):

ceil_mode = get_attribute(node, "ceil_mode", 0)
for i in range(rank):
- effective_input_size = sympy_shape[-rank + i]
+ effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)]
if len(total_pads) > 0:
effective_input_size = effective_input_size + total_pads[i]
if ceil_mode:
@@ -680,7 +686,7 @@
)
else:
strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
- sympy_shape[-rank + i] = strided_kernel_positions + 1
+ sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1
return sympy_shape

def _check_merged_dims(self, dims, allow_broadcast=True):
@@ -914,6 +920,18 @@ def _infer_Conv(self, node):
)
)

+ def _infer_NhwcConv(self, node):
+ sympy_shape = self._compute_conv_pool_shape(node, channels_last=True)
+ self._update_computed_dims(sympy_shape)
+ vi = self.known_vi_[node.output[0]]
+ vi.CopyFrom(
+ helper.make_tensor_value_info(
+ node.output[0],
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+ get_shape_from_sympy_shape(sympy_shape),
+ )
+ )
+
def _infer_Einsum(self, node):
# ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275
equation = get_attribute(node, "equation")
@@ -2455,6 +2473,7 @@ def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=F
all_shapes_inferred = symbolic_shape_inference._infer_impl()
symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred:
+ onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
raise Exception("Incomplete symbolic shape inference")
return symbolic_shape_inference.out_mp_

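A hedged sketch of exercising the new inference path end to end — the model file name is hypothetical, and `infer_shapes` is the entry point shown in the last hunk above:

```python
# Sketch: symbolic shape inference over a model containing NhwcConv nodes.
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("unet_nhwc.onnx")  # hypothetical model with NhwcConv ops
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "unet_nhwc_inferred.onnx")

# The channels_last indexing at a glance: for an NhwcConv weight laid out as
# [M, kH, kW, C], the spatial kernel dims sit at W_shape[-rank - 1 : -1] and
# M replaces the last (channel) axis of the NHWC output shape.
W_shape = [320, 3, 3, 4]                  # [M, kH, kW, C]
rank = len(W_shape) - 2                   # two spatial axes
assert W_shape[-rank - 1 : -1] == [3, 3]  # kernel_shape, skipping the trailing C dim
```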
90 changes: 90 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
@@ -0,0 +1,90 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import List

from fusion_base import Fusion
from onnx import TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionNhwcConv(Fusion):
"""Convert Conv to NhwcConv"""

def __init__(self, model: OnnxModel, update_weight=False):
super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
self.update_weight = update_weight

def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
"""Append a Transpose node after an input"""
node_name = self.model.create_node_name("Transpose")

if output_name is None:
output_name = node_name + "_out" + "-" + input_name

transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])

return transpose_node

def fuse(self, conv, input_name_to_nodes, output_name_to_node):
# Add Transpose node to convert input from NCHW to NHWC
input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])

nhwc_conv_input = input_transpose_node.output[0]

# Create a tensor for transposed weights (already in NHWC format).
node_name = self.model.create_node_name("NhwcConv")

# Make sure the weight tensor is 4D
weight_tensor = self.model.get_initializer(conv.input[1])
if weight_tensor is None:
return
weight = numpy_helper.to_array(weight_tensor)
if len(weight.shape) != 4:
return

if self.update_weight:
# Transpose weights from NCHW to NHWC
weight = weight.transpose(0, 2, 3, 1)

weight_name = node_name + "_weight_NHWC"
nhwc_weight = helper.make_tensor(
name=weight_name,
data_type=TensorProto.FLOAT,
dims=list(weight.shape),
vals=weight.flatten().tolist(),
)
self.model.add_initializer(nhwc_weight, self.this_graph_name)
weight_transpose_node = None
else:
weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
weight_name = weight_transpose_node.output[0]

nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
nhwc_conv = helper.make_node(
"NhwcConv",
inputs=[nhwc_conv_input, weight_name] + conv.input[2:],
outputs=[nhwc_output_name],
name=node_name + "-" + conv.name,
)
nhwc_conv.attribute.extend(conv.attribute)
nhwc_conv.domain = "com.microsoft"

output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0])

self.nodes_to_remove.append(conv)

nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node]
if weight_transpose_node:
nodes_to_add.append(weight_transpose_node)
for node in nodes_to_add:
self.node_name_to_graph_name[node.name] = self.this_graph_name
self.nodes_to_add.extend(nodes_to_add)

self.increase_counter("NhwcConv")
7 changes: 3 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_reshape.py
@@ -119,16 +119,15 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
shape_nodes.extend([path2[-1], path3[-1]])
shape.append(-1)
elif len(concat_node.input) > 2:
- concat_2 = self.model.get_initializer(concat_node.input[2])
- if concat_2 is None:
+ concat_value = self.model.get_constant_value(concat_node.input[2])
+ if concat_value is None:
return
- concat_value = numpy_helper.to_array(concat_2)
if isinstance(concat_value, np.ndarray):
shape.extend(concat_value.tolist())
else:
shape.append(concat_value)

- if len(concat_node.input) == 4 and self.model.get_initializer(concat_node.input[3]) is None:
+ if len(concat_node.input) == 4 and self.model.get_constant_value(concat_node.input[3]) is None:
if -1 in shape:
return

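The `get_initializer` → `get_constant_value` change matters when a shape component is produced by a `Constant` node rather than stored as a graph initializer: the former lookup returns None in that case. A self-contained sketch (the graph and names are hypothetical):

```python
# Sketch: a Constant-node output has no initializer entry, but
# get_constant_value still resolves it to a numpy value.
from onnx import TensorProto, helper
from onnxruntime.transformers.onnx_model import OnnxModel

const_node = helper.make_node(
    "Constant", inputs=[], outputs=["dim"],
    value=helper.make_tensor("dim_t", TensorProto.INT64, [1], [64]),
)
graph = helper.make_graph(
    [const_node], "g", [],
    [helper.make_tensor_value_info("dim", TensorProto.INT64, [1])],
)
model = OnnxModel(helper.make_model(graph))

print(model.get_initializer("dim"))     # None
print(model.get_constant_value("dim"))  # array([64])
```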