From 16f1b4c57c895bc1944bfe54a9939ff1ecac4fc1 Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Mon, 6 Feb 2023 11:16:53 +0000 Subject: [PATCH] add tuning --- cmake/onnxruntime_rocm_hipify.cmake | 1 + .../core/providers/rocm/math/softmax.cc | 30 +++-- .../core/providers/rocm/math/softmax.h | 74 ++++++++++++ .../core/providers/rocm/math/softmax_impl.cu | 106 ++++++------------ .../rocm/math/softmax_tunable_op.cuh | 62 +++++++++- .../rocm/math/softmax_warpwise_impl.cuh | 33 +++--- .../kernel_explorer/kernels/rocm/softmax.cu | 29 +++++ 7 files changed, 233 insertions(+), 102 deletions(-) create mode 100644 onnxruntime/core/providers/rocm/math/softmax.h diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index f13d95474cd94..2c13b5cbb56eb 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -124,6 +124,7 @@ set(provider_excluded_files "math/softmax_common.cc" "math/softmax_common.h" "math/softmax.cc" + "math/softmax.h" "nn/conv.cc" "nn/conv.h" "nn/conv_transpose.cc" diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc index 22bcaecf34f65..013c810f843bc 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.cc +++ b/onnxruntime/core/providers/rocm/math/softmax.cc @@ -11,13 +11,14 @@ namespace onnxruntime { namespace rocm { -template +template Status SoftMaxComputeHelper( hipStream_t stream, const T* X, const TensorShape& input_shape, T* Y, - int64_t axis) { + int64_t axis, + RocmTuningContext* tuning_ctx) { typedef typename ToHipType::MappedType HipT; int64_t N = input_shape.SizeToDimension(axis); @@ -26,17 +27,20 @@ Status SoftMaxComputeHelper( auto X_data = reinterpret_cast(X); if (D <= 1024 && D * sizeof(T) <= 4096) { - return dispatch_warpwise_softmax_forward, is_log_softmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N)); + return dispatch_warpwise_softmax_forward, IsLogSoftmax>( + stream, Y_data, X_data, gsl::narrow_cast(D), + gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); } - return dispatch_blockwise_softmax_forward, is_log_softmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(D), - gsl::narrow_cast(N)); + return dispatch_blockwise_softmax_forward, IsLogSoftmax>( + stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), + gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); } -#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \ - template Status SoftMaxComputeHelper(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, int64_t axis); \ - template Status SoftMaxComputeHelper(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, int64_t axis); +#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \ + template Status SoftMaxComputeHelper(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, \ + int64_t axis, RocmTuningContext* tuning_ctx); \ + template Status SoftMaxComputeHelper(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, \ + int64_t axis, RocmTuningContext* tuning_ctx); SPECIALIZED_SOFTMAX_HELPER_IMPL(float) // MIOpen double data type not supported @@ -173,11 +177,13 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { if (log_softmax_) { status = SoftMaxComputeHelper(Stream(ctx), X_data, *compute_input_shape, Y_data, is_transpose_required ? 
static_cast(rank) - 1 - : static_cast(axis)); + : static_cast(axis), + GetTuningContext()); } else { status = SoftMaxComputeHelper(Stream(ctx), X_data, *compute_input_shape, Y_data, is_transpose_required ? static_cast(rank) - 1 - : static_cast(axis)); + : static_cast(axis), + GetTuningContext()); } if (!status.IsOK()) diff --git a/onnxruntime/core/providers/rocm/math/softmax.h b/onnxruntime/core/providers/rocm/math/softmax.h new file mode 100644 index 0000000000000..89f75c5eb10c2 --- /dev/null +++ b/onnxruntime/core/providers/rocm/math/softmax.h @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/gsl.h" +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace rocm { + +using tunable::RocmTuningContext; + +template +Status SoftMaxComputeHelper( + hipStream_t stream, + const T* input, + const TensorShape& shape, + T* Y, + int64_t axis, + RocmTuningContext* tuning_ctx = nullptr); + +template +Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements, + int softmax_elements_stride, int batch_count, + RocmTuningContext* tuning_ctx = nullptr); + +template +Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements, + int input_stride, int output_stride, int batch_count, + RocmTuningContext* tuning_ctx = nullptr); + +template +class Softmax final : public RocmKernel { + public: + Softmax(const OpKernelInfo& info) : RocmKernel{info} { + const auto& node = info.node(); + opset_ = node.SinceVersion(); + + int64_t axis; + Status status = info.GetAttr("axis", &axis); + + if (status.IsOK()) { + axis_ = gsl::narrow_cast(axis); + } else { + if (opset_ < 13) { + axis_ = 1; // opset-12 and below, the default axis value is 1 + } else { + axis_ = -1; // opset-13, the default axis value is -1 + } + } + + log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax"; + + // We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method + // TODO: Clean up the ROCMExecutionProvider interface to avoid this + rocm_ep_ = const_cast( + static_cast(info.GetExecutionProvider())); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t axis_; + bool log_softmax_; + int opset_; + + // We need to access to the ROCM EP instance to get the rocblas handle to use + // for transposing(if applicable) + ROCMExecutionProvider* rocm_ep_; +}; + +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu index ad36240926f54..462051c8cd28c 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_impl.cu +++ b/onnxruntime/core/providers/rocm/math/softmax_impl.cu @@ -20,9 +20,9 @@ #include "hip/hip_runtime.h" #include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/math/softmax_warpwise_impl.cuh" -#include "core/providers/rocm/math/softmax_blockwise_impl.cuh" #include "core/providers/rocm/math/softmax.h" +#include "core/providers/rocm/math/softmax_common.h" +#include "core/providers/rocm/math/softmax_tunable_op.cuh" #include @@ -30,58 +30,24 @@ namespace onnxruntime { namespace rocm { template -Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements, int softmax_elements_stride, int batch_count) { - if (softmax_elements == 0) { - 
return Status::OK(); - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = 1; - // use 256 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 256; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \ - case L2E: \ - softmax_warp_forward \ - <<>>(dst, src, batch_count, \ - softmax_elements_stride, softmax_elements); \ - break; - LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 - LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 - LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 - LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 - LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 - LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 - LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 - LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 - LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 - LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 - LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024 - default: - break; - } +Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements, + int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx) { + SoftmaxParams params(tuning_ctx, stream, dst, src, softmax_elements, softmax_elements_stride, + softmax_elements_stride, batch_count, IsLogSoftmax); + if (tuning_ctx->IsTunableOpEnabled()) { + static SoftmaxTunableOp op; + return op(¶ms); } - return HIP_CALL(hipGetLastError()); + return SoftmaxWarpwiseStaticSelection(¶ms); } -#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT) \ - template Status dispatch_warpwise_softmax_forward( \ - hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \ - int softmax_elements_stride, int batch_count); \ - template Status dispatch_warpwise_softmax_forward( \ - hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \ - int softmax_elements_stride, int batch_count); +#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT) \ + template Status dispatch_warpwise_softmax_forward( \ + hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \ + int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx); \ + template Status dispatch_warpwise_softmax_forward( \ + hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \ + int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx); SPECIALIZED_SOFTMAX_IMPL(float, float, float) SPECIALIZED_SOFTMAX_IMPL(half, half, float) @@ -89,30 +55,26 @@ SPECIALIZED_SOFTMAX_IMPL(double, double, double) SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float) template -Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements, - int input_stride, int output_stride, int batch_count) { - dim3 grid(batch_count); - constexpr int ILP = sizeof(float4) / sizeof(InputT); - dim3 block = SoftMax_getBlockSize(ILP, softmax_elements); - if (IsLogSoftmax) { - 
softmax_block_forward - <<>>(output, const_cast(input), - softmax_elements, input_stride, output_stride); - } else { - softmax_block_forward - <<>>(output, const_cast(input), - softmax_elements, input_stride, output_stride); +Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, + const InputT* input, int softmax_elements, + int input_stride, int output_stride, + int batch_count, RocmTuningContext* tuning_ctx) { + SoftmaxParams params(tuning_ctx, stream, output, input, softmax_elements, input_stride, + output_stride, batch_count, IsLogSoftmax); + if (tuning_ctx->IsTunableOpEnabled()) { + static SoftmaxTunableOp op; + return op(¶ms); } - return HIP_CALL(hipGetLastError()); + return SoftmaxBlockwiseStaticSelection(¶ms); } -#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT) \ - template Status dispatch_blockwise_softmax_forward( \ - hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \ - int input_stride, int output_stride, int batch_count); \ - template Status dispatch_blockwise_softmax_forward( \ - hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \ - int input_stride, int output_stride, int batch_count); +#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT) \ + template Status dispatch_blockwise_softmax_forward( \ + hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \ + int input_stride, int output_stride, int batch_count, RocmTuningContext* tuning_ctx); \ + template Status dispatch_blockwise_softmax_forward( \ + hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \ + int input_stride, int output_stride, int batch_count, RocmTuningContext* tuning_ctx); SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float) SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, half, float) diff --git a/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh b/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh index 7347cd2c035b9..8c761c5b83a3a 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh +++ b/onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh @@ -33,6 +33,55 @@ Status SoftmaxBlockwiseOp(const SoftmaxParams* params) { return HIP_CALL(hipGetLastError()); } +template +Status SoftmaxWarpwiseStaticSelection(const SoftmaxParams* params) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !(params->input_stride <= 1024 && params->input_stride * sizeof(InputT) <= 4096)); + if (params->softmax_elements == 0) { + return Status::OK(); + } else { + int log2_elements = log2_ceil(params->softmax_elements); + const int next_power_of_two = 1 << log2_elements; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. 
+ int batches_per_warp = 1; + // use 256 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 256; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (params->batch_count + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \ + case L2E: \ + softmax_warp_forward \ + <<stream>>>(params->output, params->input, params->batch_count, \ + params->input_stride, params->softmax_elements, \ + params->is_log_softmax); \ + break; + LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 + LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024 + default: + break; + } + } + return HIP_CALL(hipGetLastError()); +} + template Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams* params) { dim3 grid(params->batch_count); @@ -53,9 +102,20 @@ Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams* par } template -class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp> { +Status SoftmaxStaticSelection(const SoftmaxParams* params) { + auto status = SoftmaxWarpwiseStaticSelection(params); + if (!status.IsOK()) { + status = SoftmaxBlockwiseStaticSelection(params); + } + return status; +} + +template +class SoftmaxTunableOp : public tunable::TunableOp> { public: SoftmaxTunableOp() { + this->RegisterOp(SoftmaxStaticSelection); + this->RegisterOp(SoftmaxWarpwiseStaticSelection); this->RegisterOp(SoftmaxBlockwiseStaticSelection); this->RegisterOp(SoftmaxBlockwiseOp); this->RegisterOp(SoftmaxBlockwiseOp); diff --git a/onnxruntime/core/providers/rocm/math/softmax_warpwise_impl.cuh b/onnxruntime/core/providers/rocm/math/softmax_warpwise_impl.cuh index 2cfddce972e04..f30bb970e0177 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_warpwise_impl.cuh +++ b/onnxruntime/core/providers/rocm/math/softmax_warpwise_impl.cuh @@ -1,18 +1,18 @@ /** -* Copyright (c) 2016-present, Facebook, Inc. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ // The code below is mostly copied from Pytorch PersistentSoftmax.cuh @@ -56,7 +56,6 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { } } - // The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension. // Each sample contains element_count scalar elements. element_count can be any integer value <= 1024. // The template arguments have the following meaning: @@ -74,8 +73,8 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { // input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor. // input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor. -template -__global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batch_size, int stride, int element_count) { +template +__global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batch_size, int stride, int element_count, bool is_log_softmax) { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE; diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu index 8128f73243804..c514c6515709d 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu @@ -50,6 +50,32 @@ class SoftmaxBlockwise : public IKernelExplorer { std::string type_string_{}; }; +template +class SoftmaxWarpwiseStaticSelection : public IKernelExplorer { + public: + SoftmaxWarpwiseStaticSelection(DeviceArray& output, DeviceArray& input, int softmax_elements, + int input_stride, int output_stride, int batch_count, bool is_log_softmax) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(input.ptr()), + softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {} + + void Run() override { + ORT_THROW_IF_ERROR((rocm::SoftmaxWarpwiseStaticSelection>(¶ms_))); + } + + std::vector ListOps() const { + return {"SoftmaxWarpwiseStaticSelection"}; + } + + bool SelectOp(const std::string& name) { + auto status = rocm::SoftmaxWarpwiseStaticSelection>(¶ms_); + return status.IsOK() && name == "SoftmaxWarpwiseStaticSelection"; + } + + private: + using ParamsT = rocm::SoftmaxParams; + ParamsT params_{}; +}; + template class SoftmaxBlockwiseStaticSelection : public IKernelExplorer { public: @@ -176,6 +202,9 @@ void InitSoftmax(py::module m) { REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, half); REGISTER_OP_FOR_ALL_VEC_SIZE(SoftmaxBlockwise, float); + REGISTER_OP_TYPED(SoftmaxWarpwiseStaticSelection, half); + REGISTER_OP_TYPED(SoftmaxWarpwiseStaticSelection, float); + REGISTER_OP_TYPED(SoftmaxBlockwiseStaticSelection, half); REGISTER_OP_TYPED(SoftmaxBlockwiseStaticSelection, float);
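Notes on the dispatch logic above. SoftMaxComputeHelper keeps the existing size-based split between the warpwise and blockwise kernels, and the new SoftmaxStaticSelection helper reproduces that split on the tunable side: it tries SoftmaxWarpwiseStaticSelection first (which rejects rows wider than 1024 elements or 4096 bytes of input via TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF) and falls back to SoftmaxBlockwiseStaticSelection when that returns a non-OK status. A minimal host-only sketch of the eligibility check follows; UseWarpwisePath is an illustrative name, not part of the patch.

    // Illustrative only: mirrors the D <= 1024 && D * sizeof(T) <= 4096 check above.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // The warpwise kernel keeps an entire row in registers, so it is only
    // eligible when the reduced dimension D fits a single warp's working set.
    bool UseWarpwisePath(std::int64_t d, std::size_t element_size) {
      return d <= 1024 && d * static_cast<std::int64_t>(element_size) <= 4096;
    }

    int main() {
      std::printf("float,  D=512  -> %s\n", UseWarpwisePath(512, sizeof(float)) ? "warpwise" : "blockwise");
      std::printf("float,  D=2048 -> %s\n", UseWarpwisePath(2048, sizeof(float)) ? "warpwise" : "blockwise");
      std::printf("double, D=1024 -> %s\n", UseWarpwisePath(1024, sizeof(double)) ? "warpwise" : "blockwise");
      return 0;
    }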
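The warpwise static selection derives its launch geometry from the row length: it rounds softmax_elements up to a power of two, clamps the per-row "warp" width to the hardware wavefront (GPU_WARP_SIZE_HOST), and packs as many such warps as fit into a 256-thread block. The host-only sketch below reproduces that arithmetic; the 64-lane wavefront value is an assumption of the sketch, not something the patch states.

    // Launch-geometry arithmetic used by SoftmaxWarpwiseStaticSelection (sketch).
    #include <cstdio>

    constexpr int kWarpSizeHost = 64;      // assumed stand-in for GPU_WARP_SIZE_HOST
    constexpr int kThreadsPerBlock = 256;  // fixed in the dispatcher to keep occupancy high

    int Log2Ceil(int value) {
      int log2 = 0;
      while ((1 << log2) < value) ++log2;
      return log2;
    }

    int main() {
      int softmax_elements = 20;  // reduced dimension D
      int batch_count = 1000;     // number of rows N

      int log2_elements = Log2Ceil(softmax_elements);       // 5
      int next_power_of_two = 1 << log2_elements;            // 32
      // Small rows use a "warp" narrower than the hardware wavefront.
      int warp_size = next_power_of_two < kWarpSizeHost ? next_power_of_two : kWarpSizeHost;
      int batches_per_warp = 1;                               // one row per warp
      int warps_per_block = kThreadsPerBlock / warp_size;     // 8
      int batches_per_block = warps_per_block * batches_per_warp;
      int blocks = (batch_count + batches_per_block - 1) / batches_per_block;  // ceil(1000 / 8) = 125

      std::printf("warp_size=%d warps_per_block=%d blocks=%d\n", warp_size, warps_per_block, blocks);
      // The kernel would then be launched with block dimensions dim3(warp_size, warps_per_block, 1).
      return 0;
    }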
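When tuning is enabled, both dispatchers build a SoftmaxParams and route through SoftmaxTunableOp, which profiles the registered candidates and reuses the fastest. The sketch below is a schematic stand-in for that mechanism, not onnxruntime's TunableOp (which times over multiple iterations and caches the winner per parameter signature); it only illustrates the registration order used above. My reading, not stated in the patch, is that the first registered entry acts as the default when no tuning result is available, which is why SoftmaxStaticSelection is registered before the specialized warpwise and vectorized blockwise candidates.

    // Schematic model of a tunable op; names and behavior are simplified.
    #include <chrono>
    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <vector>

    struct Params { /* would hold stream, input/output pointers, sizes, is_log_softmax */ };
    using Candidate = std::function<bool(const Params&)>;  // returns false if arguments are unsupported

    class TinyTunableOp {
     public:
      void RegisterOp(Candidate op) { ops_.push_back(std::move(op)); }

      // Without tuning, entry 0 is used; with tuning, every supported candidate
      // is timed once (wall clock only here) and the fastest one is re-run.
      bool Run(const Params& p, bool tuning_enabled) {
        if (!tuning_enabled) return ops_.front()(p);
        std::size_t best = 0;
        double best_ms = 1e30;
        for (std::size_t i = 0; i < ops_.size(); ++i) {
          auto t0 = std::chrono::steady_clock::now();
          if (!ops_[i](p)) continue;  // candidate rejected these arguments
          double ms = std::chrono::duration<double, std::milli>(
                          std::chrono::steady_clock::now() - t0).count();
          if (ms < best_ms) { best_ms = ms; best = i; }
        }
        return ops_[best](p);
      }

     private:
      std::vector<Candidate> ops_;
    };

    int main() {
      TinyTunableOp op;
      // Registration order mirrors SoftmaxTunableOp: static selection first,
      // then the warpwise path, then the blockwise candidates.
      op.RegisterOp([](const Params&) { std::puts("static selection"); return true; });
      op.RegisterOp([](const Params&) { std::puts("warpwise"); return true; });
      op.RegisterOp([](const Params&) { std::puts("blockwise"); return true; });
      op.Run(Params{}, /*tuning_enabled=*/false);  // prints "static selection"
      return 0;
    }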