
Stable Diffusion CUDA Optimizations #14428

Merged: 28 commits from tlwu/optimize_sd into main on Feb 3, 2023.
The diff below shows changes from the first 4 commits.

Commits:
- 6d8402e Add benchmark (tianleiwu, Jan 25, 2023)
- 0ac952e Add GroupNorm fusion (tianleiwu, Jan 25, 2023)
- c1de9fa Merge branch 'main' into tlwu/optimize_sd (tianleiwu, Jan 25, 2023)
- 29c47dc [CUDA] Add GroupNormalization operator (tianleiwu, Jan 25, 2023)
- 5f40aa8 Add Cast for fp16 group_norm (tianleiwu, Jan 26, 2023)
- a4c4302 Add SplitGelu fusion (tianleiwu, Jan 27, 2023)
- f722c5a support float type in GroupNorm (tianleiwu, Jan 27, 2023)
- 4a7bf0d Add SplitGelu operator (tianleiwu, Jan 27, 2023)
- 98b90ca format (tianleiwu, Jan 28, 2023)
- ea69aec format (tianleiwu, Jan 28, 2023)
- 9eacd84 misc (tianleiwu, Jan 29, 2023)
- c566679 update group norm test data to NHWC (tianleiwu, Jan 29, 2023)
- a9ebeec Fuse Bias and SplitGelu (tianleiwu, Jan 29, 2023)
- 53a539f update bias split gelu (tianleiwu, Jan 29, 2023)
- a0c4957 update GroupNorm doc (tianleiwu, Jan 30, 2023)
- 82383dc packed kv in cross attention (tianleiwu, Jan 31, 2023)
- 966b3e7 fix pyright warnings (tianleiwu, Jan 31, 2023)
- 4a8583e Add unit test of bias split gelu (tianleiwu, Jan 31, 2023)
- 982663a fix typo (tianleiwu, Jan 31, 2023)
- 73045bb fix code scanning warnings (tianleiwu, Jan 31, 2023)
- 86d5795 fix code scanning warnings (tianleiwu, Jan 31, 2023)
- efa6d4f address review feedback (tianleiwu, Jan 31, 2023)
- 7a75ce1 Add NhwcConv (tianleiwu, Feb 1, 2023)
- f4d4103 fix training api build error (tianleiwu, Feb 1, 2023)
- 55a7468 Add float16 test (tianleiwu, Feb 1, 2023)
- 3ff1fe6 fix type warning (tianleiwu, Feb 2, 2023)
- b3a4c01 update op doc; exclude from hipify (tianleiwu, Feb 2, 2023)
- 1fe78af add input checks; clean debug code (tianleiwu, Feb 2, 2023)
cmake/onnxruntime_rocm_hipify.cmake (4 additions, 0 deletions)

@@ -27,6 +27,10 @@ set(contrib_ops_excluded_files
"bert/tensorrt_fused_multihead_attention/*"
"bert/transformer_common.h"
"bert/transformer_common.cc"
"diffusion/group_norm.h"
"diffusion/group_norm.cc"
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"math/complex_mul.cc"
"math/complex_mul.h"
"math/complex_mul_impl.cu"
onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc (2 additions, 0 deletions)

@@ -71,6 +71,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler);
@@ -192,6 +193,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupNorm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler)>,
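Note that only the MLFloat16 kernel is registered at this point in the PR. Commit f722c5a ("support float type in GroupNorm") later adds a float variant; a sketch of the mirrored registration it would add, following the pattern of the MLFloat16 lines above:

// Sketch (assumption): float registration mirroring the MLFloat16 lines above,
// added later in this PR by commit f722c5a ("support float type in GroupNorm").
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GroupNorm);
// ...plus the matching entry in RegisterCudaContribKernels:
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GroupNorm)>,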
onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc (new file, 106 additions)

@@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/diffusion/group_norm.h"
#include "contrib_ops/cuda/diffusion/group_norm_impl.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      GroupNorm,                                                  \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      GroupNorm<T>);

REGISTER_KERNEL_TYPED(MLFloat16);

using namespace ONNX_NAMESPACE;

template <typename T>
GroupNorm<T>::GroupNorm(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
  ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
  ORT_ENFORCE(epsilon_ >= 0);

  int64_t num_groups;
  ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("groups", &num_groups).IsOK());
  ORT_ENFORCE(num_groups >= 0);
  num_groups_ = static_cast<int>(num_groups);

  // ONNX attributes have no boolean type, so the swish flag is carried as an int.
  int64_t swish = 0;
  ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("swish", &swish).IsOK());
  swish_ = (swish != 0);
}

template <typename T>
Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* gamma = context->Input<Tensor>(1);
  const Tensor* beta = context->Input<Tensor>(2);
  Tensor* output = context->Output(0, input->Shape());

  // Input is expected in NHWC layout, so the channel count is the last dimension.
  const auto& input_dims = input->Shape().GetDims();
  if (input_dims.size() != 4) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "input is expected to have 4 dimensions, got ", input_dims.size());
  }

  const auto& gamma_dims = gamma->Shape().GetDims();
  if (gamma_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "gamma is expected to have 1 dimension, got ", gamma_dims.size());
  }
  if (gamma_dims[0] != input_dims[3]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Last dimension of gamma and input does not match");
  }

  const auto& beta_dims = beta->Shape().GetDims();
  if (beta_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "beta is expected to have 1 dimension, got ", beta_dims.size());
  }
  if (beta_dims[0] != input_dims[3]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Last dimension of beta and input does not match");
  }

  int batch_size = static_cast<int>(input_dims[0]);
  int height = static_cast<int>(input_dims[1]);
  int width = static_cast<int>(input_dims[2]);
  int num_channels = static_cast<int>(input_dims[3]);

  if (num_channels % num_groups_ != 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "number of channels should be divisible by num_groups");
  }

  auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());

  typedef typename ToCudaType<T>::MappedType CudaT;

  return LaunchGroupNormKernel<CudaT>(
      Stream(context),
      reinterpret_cast<CudaT*>(output->MutableData<T>()),
      reinterpret_cast<const CudaT*>(input->Data<T>()),
      gamma->Data<float>(),  // gamma and beta are float even when T is fp16
      beta->Data<float>(),
      reinterpret_cast<float*>(workspace.get()),
      epsilon_,
      batch_size,
      num_channels,
      height,
      width,
      num_groups_,
      swish_);
}

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
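For readers unfamiliar with the operator, here is a minimal CPU reference sketch of the computation LaunchGroupNormKernel is expected to perform, assuming the NHWC layout implied by the last-dimension checks above. The function name and loop structure are illustrative only, not part of this PR:

#include <cmath>
#include <cstdint>

// Reference GroupNorm with optional Swish for NHWC float tensors (illustrative).
// Each of the num_groups channel groups is normalized over (H, W, C/num_groups),
// then scaled and shifted per channel; Swish applies x * sigmoid(x) afterwards.
void GroupNormReference(float* output, const float* input,
                        const float* gamma, const float* beta,
                        float epsilon, int batch_size, int height, int width,
                        int num_channels, int num_groups, bool swish) {
  const int channels_per_group = num_channels / num_groups;  // divisibility checked by the op
  const int64_t hw = static_cast<int64_t>(height) * width;
  for (int n = 0; n < batch_size; ++n) {
    const float* x = input + n * hw * num_channels;
    float* y = output + n * hw * num_channels;
    for (int g = 0; g < num_groups; ++g) {
      const int c0 = g * channels_per_group;
      const int c1 = c0 + channels_per_group;
      // First pass: mean and variance of the group.
      double sum = 0.0;
      double sq_sum = 0.0;
      for (int64_t p = 0; p < hw; ++p) {
        for (int c = c0; c < c1; ++c) {
          const double v = x[p * num_channels + c];
          sum += v;
          sq_sum += v * v;
        }
      }
      const double count = static_cast<double>(hw) * channels_per_group;
      const double mean = sum / count;
      const double var = sq_sum / count - mean * mean;
      const float inv_std = static_cast<float>(1.0 / std::sqrt(var + epsilon));
      // Second pass: normalize, apply per-channel gamma/beta, optional Swish.
      for (int64_t p = 0; p < hw; ++p) {
        for (int c = c0; c < c1; ++c) {
          float v = (x[p * num_channels + c] - static_cast<float>(mean)) * inv_std;
          v = v * gamma[c] + beta[c];
          if (swish) {
            v = v / (1.0f + std::exp(-v));  // x * sigmoid(x)
          }
          y[p * num_channels + c] = v;
        }
      }
    }
  }
}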
onnxruntime/contrib_ops/cuda/diffusion/group_norm.h (new file, 28 additions)

@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

using namespace onnxruntime::cuda;

template <typename T>
class GroupNorm final : public CudaKernel {
 public:
  GroupNorm(const OpKernelInfo& op_kernel_info);
  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  bool swish_;
  float epsilon_;
  int num_groups_;
};

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
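The companion header contrib_ops/cuda/diffusion/group_norm_impl.h is not included in this diff view. Reconstructed from the call site in group_norm.cc, its interface is presumably close to the following sketch; the parameter names and the exact signature are assumptions:

// Assumed interface of contrib_ops/cuda/diffusion/group_norm_impl.h,
// inferred from the LaunchGroupNormKernel call in group_norm.cc above.
#pragma once
#include <cstddef>
#include <cuda_runtime.h>
#include "core/common/common.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

// Size of the float scratch buffer used for per-group reductions (assumed purpose).
size_t GetGroupNormWorkspaceSizeInBytes();

template <typename T>
Status LaunchGroupNormKernel(
    cudaStream_t stream,
    T* output,           // NHWC output, same shape as input
    const T* input,      // NHWC input of shape [N, H, W, C]
    const float* gamma,  // per-channel scale [C], float even when T is half
    const float* beta,   // per-channel shift [C], float even when T is half
    float* workspace,    // scratch buffer sized by GetGroupNormWorkspaceSizeInBytes()
    float epsilon,
    int batch_size,
    int num_channels,
    int height,
    int width,
    int num_groups,
    bool use_swish_activation);

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime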