
Commit

add tuning
PeixuanZuo committed Feb 10, 2023
1 parent 0b6c029 commit 16f1b4c
Showing 7 changed files with 233 additions and 102 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -124,6 +124,7 @@ set(provider_excluded_files
"math/softmax_common.cc"
"math/softmax_common.h"
"math/softmax.cc"
"math/softmax.h"
"nn/conv.cc"
"nn/conv.h"
"nn/conv_transpose.cc"
30 changes: 18 additions & 12 deletions onnxruntime/core/providers/rocm/math/softmax.cc
@@ -11,13 +11,14 @@
namespace onnxruntime {
namespace rocm {

template <typename T, bool is_log_softmax>
template <typename T, bool IsLogSoftmax>
Status SoftMaxComputeHelper(
hipStream_t stream,
const T* X,
const TensorShape& input_shape,
T* Y,
int64_t axis) {
int64_t axis,
RocmTuningContext* tuning_ctx) {
typedef typename ToHipType<T>::MappedType HipT;

int64_t N = input_shape.SizeToDimension(axis);
@@ -26,17 +27,20 @@ Status SoftMaxComputeHelper(
auto X_data = reinterpret_cast<const HipT*>(X);

if (D <= 1024 && D * sizeof(T) <= 4096) {
return dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
return dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, IsLogSoftmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N), tuning_ctx);
}
return dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(N));
return dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, IsLogSoftmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N), tuning_ctx);
}

#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
template Status SoftMaxComputeHelper<T, false>(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, int64_t axis); \
template Status SoftMaxComputeHelper<T, true>(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, int64_t axis);
#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
template Status SoftMaxComputeHelper<T, false>(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, \
int64_t axis, RocmTuningContext* tuning_ctx); \
template Status SoftMaxComputeHelper<T, true>(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, \
int64_t axis, RocmTuningContext* tuning_ctx);

SPECIALIZED_SOFTMAX_HELPER_IMPL(float)
// MIOpen double data type not supported
@@ -173,11 +177,13 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
if (log_softmax_) {
status = SoftMaxComputeHelper<T, true>(Stream(ctx), X_data, *compute_input_shape, Y_data,
is_transpose_required ? static_cast<int64_t>(rank) - 1
: static_cast<int64_t>(axis));
: static_cast<int64_t>(axis),
GetTuningContext());
} else {
status = SoftMaxComputeHelper<T, false>(Stream(ctx), X_data, *compute_input_shape, Y_data,
is_transpose_required ? static_cast<int64_t>(rank) - 1
: static_cast<int64_t>(axis));
: static_cast<int64_t>(axis),
GetTuningContext());
}

if (!status.IsOK())
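The D <= 1024 && D * sizeof(T) <= 4096 check above (unchanged here apart from the extra tuning_ctx argument) decides between the warpwise and blockwise kernels. The standalone snippet below is illustrative only and not part of the commit; it just works the threshold out for a few element types (for 2-byte types such as half, any D up to 1024 also qualifies).

#include <cstddef>
#include <cstdio>

// Mirrors the warpwise-vs-blockwise split in SoftMaxComputeHelper: the
// warpwise kernel handles reduced dimensions of at most 1024 elements whose
// row also fits in 4096 bytes; everything else takes the blockwise path.
static bool UsesWarpwisePath(long long d, std::size_t elem_size) {
  return d <= 1024 && static_cast<unsigned long long>(d) * elem_size <= 4096;
}

int main() {
  std::printf("float,  D=1024 -> warpwise: %d\n", UsesWarpwisePath(1024, sizeof(float)));   // 1: 4096 bytes
  std::printf("double, D=1024 -> warpwise: %d\n", UsesWarpwisePath(1024, sizeof(double)));  // 0: 8192 bytes
  std::printf("double, D=512  -> warpwise: %d\n", UsesWarpwisePath(512, sizeof(double)));   // 1: 4096 bytes
  return 0;
}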
74 changes: 74 additions & 0 deletions onnxruntime/core/providers/rocm/math/softmax.h
@@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/gsl.h"
#include "core/providers/rocm/rocm_kernel.h"

namespace onnxruntime {
namespace rocm {

using tunable::RocmTuningContext;

template <typename T, bool IsLogSoftmax>
Status SoftMaxComputeHelper(
hipStream_t stream,
const T* input,
const TensorShape& shape,
T* Y,
int64_t axis,
RocmTuningContext* tuning_ctx = nullptr);

template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements,
int softmax_elements_stride, int batch_count,
RocmTuningContext* tuning_ctx = nullptr);

template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements,
int input_stride, int output_stride, int batch_count,
RocmTuningContext* tuning_ctx = nullptr);

template <typename T>
class Softmax final : public RocmKernel {
public:
Softmax(const OpKernelInfo& info) : RocmKernel{info} {
const auto& node = info.node();
opset_ = node.SinceVersion();

int64_t axis;
Status status = info.GetAttr<int64_t>("axis", &axis);

if (status.IsOK()) {
axis_ = gsl::narrow_cast<int>(axis);
} else {
if (opset_ < 13) {
axis_ = 1; // opset-12 and below, the default axis value is 1
} else {
axis_ = -1; // opset-13, the default axis value is -1
}
}

log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax";

// We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method
// TODO: Clean up the ROCMExecutionProvider interface to avoid this
rocm_ep_ = const_cast<ROCMExecutionProvider*>(
static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider()));
}

Status ComputeInternal(OpKernelContext* context) const override;

private:
int64_t axis_;
bool log_softmax_;
int opset_;

// We need access to the ROCM EP instance to get the rocblas handle to use
// for transposing (if applicable)
ROCMExecutionProvider* rocm_ep_;
};

} // namespace rocm
} // namespace onnxruntime
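The axis handling in the Softmax constructor above mirrors the ONNX specification change between opset 12 and opset 13. A minimal restatement of those rules, using invented helper names purely for illustration:

#include <cstdint>
#include <string>

// Default softmax axis when the "axis" attribute is absent: opset-12 and
// earlier default to 1, opset-13 and later default to -1 (the last dimension).
static int64_t DefaultSoftmaxAxis(int opset) {
  return opset < 13 ? 1 : -1;
}

// The same kernel class serves both Softmax and LogSoftmax; only the
// epilogue differs, selected by the registered op name.
static bool IsLogSoftmaxOp(const std::string& op_name) {
  return op_name == "LogSoftmax";
}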
106 changes: 34 additions & 72 deletions onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -20,99 +20,61 @@
#include "hip/hip_runtime.h"

#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/math/softmax_warpwise_impl.cuh"
#include "core/providers/rocm/math/softmax_blockwise_impl.cuh"
#include "core/providers/rocm/math/softmax.h"
#include "core/providers/rocm/math/softmax_common.h"
#include "core/providers/rocm/math/softmax_tunable_op.cuh"

#include <limits>

namespace onnxruntime {
namespace rocm {

template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return Status::OK();
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;

// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST;

// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = 1;
// use 256 threads per block to maximize GPU utilization
constexpr int threads_per_block = 256;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \
case L2E: \
softmax_warp_forward<InputT, OutputT, AccT, L2E, IsLogSoftmax> \
<<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, \
softmax_elements_stride, softmax_elements); \
break;
LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1
LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2
LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4
LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8
LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16
LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32
LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64
LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
default:
break;
}
Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements,
int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx) {
SoftmaxParams<InputT, OutputT> params(tuning_ctx, stream, dst, src, softmax_elements, softmax_elements_stride,
softmax_elements_stride, batch_count, IsLogSoftmax);
if (tuning_ctx->IsTunableOpEnabled()) {
static SoftmaxTunableOp<InputT, OutputT, AccT> op;
return op(&params);
}
return HIP_CALL(hipGetLastError());
return SoftmaxWarpwiseStaticSelection<InputT, OutputT, AccT>(&params);
}

#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT) \
template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, false>( \
hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \
int softmax_elements_stride, int batch_count); \
template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, true>( \
hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \
int softmax_elements_stride, int batch_count);
#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT) \
template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, false>( \
hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \
int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx); \
template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, true>( \
hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \
int softmax_elements_stride, int batch_count, RocmTuningContext* tuning_ctx);

SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_IMPL(double, double, double)
SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)

template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements,
int input_stride, int output_stride, int batch_count) {
dim3 grid(batch_count);
constexpr int ILP = sizeof(float4) / sizeof(InputT);
dim3 block = SoftMax_getBlockSize(ILP, softmax_elements);
if (IsLogSoftmax) {
softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input),
softmax_elements, input_stride, output_stride);
} else {
softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input),
softmax_elements, input_stride, output_stride);
Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output,
const InputT* input, int softmax_elements,
int input_stride, int output_stride,
int batch_count, RocmTuningContext* tuning_ctx) {
SoftmaxParams<InputT, OutputT> params(tuning_ctx, stream, output, input, softmax_elements, input_stride,
output_stride, batch_count, IsLogSoftmax);
if (tuning_ctx->IsTunableOpEnabled()) {
static SoftmaxTunableOp<InputT, OutputT, AccT> op;
return op(&params);
}
return HIP_CALL(hipGetLastError());
return SoftmaxBlockwiseStaticSelection<InputT, OutputT, AccT>(&params);
}

#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT) \
template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, false>( \
hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count); \
template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, true>( \
hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count);
#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT) \
template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, false>( \
hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count, RocmTuningContext* tuning_ctx); \
template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, true>( \
hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count, RocmTuningContext* tuning_ctx);

SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, half, float)
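Both dispatchers above keep their SoftmaxTunableOp as a function-local static, so each (InputT, OutputT, AccT) instantiation constructs the op once and reuses it on every call, which is presumably what lets any tuning state the op accumulates survive across dispatches. A small standalone sketch of that lifetime pattern; CountingOp and Dispatch are invented for the illustration and are not repository code.

#include <iostream>

// Each template instantiation gets exactly one long-lived instance,
// constructed on first use and reused on every later call.
template <typename T>
struct CountingOp {
  CountingOp() { std::cout << "constructed once per instantiation\n"; }
  int calls = 0;
};

template <typename T>
int Dispatch() {
  static CountingOp<T> op;  // same pattern as the dispatchers in softmax_impl.cu
  return ++op.calls;
}

int main() {
  Dispatch<float>();                        // constructs CountingOp<float>
  std::cout << Dispatch<float>() << "\n";   // prints 2: same instance reused
  std::cout << Dispatch<double>() << "\n";  // prints 1: a separate instantiation
  return 0;
}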
62 changes: 61 additions & 1 deletion onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh
@@ -33,6 +33,55 @@ Status SoftmaxBlockwiseOp(const SoftmaxParams<InputT, OutputT>* params) {
return HIP_CALL(hipGetLastError());
}

template <typename InputT, typename OutputT, typename AccT>
Status SoftmaxWarpwiseStaticSelection(const SoftmaxParams<InputT, OutputT>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!(params->input_stride <= 1024 && params->input_stride * sizeof(InputT) <= 4096));
if (params->softmax_elements == 0) {
return Status::OK();
} else {
int log2_elements = log2_ceil(params->softmax_elements);
const int next_power_of_two = 1 << log2_elements;

// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST;

// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = 1;
// use 256 threads per block to maximize GPU utilization
constexpr int threads_per_block = 256;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (params->batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \
case L2E: \
softmax_warp_forward<InputT, OutputT, AccT, L2E> \
<<<dim3(blocks), dim3(threads), 0, params->stream>>>(params->output, params->input, params->batch_count, \
params->input_stride, params->softmax_elements, \
params->is_log_softmax); \
break;
LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1
LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2
LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4
LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8
LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16
LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32
LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64
LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
default:
break;
}
}
return HIP_CALL(hipGetLastError());
}

template <typename InputT, typename OutputT, typename AccT>
Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<InputT, OutputT>* params) {
dim3 grid(params->batch_count);
@@ -53,9 +102,20 @@ Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<InputT, OutputT>* params) {
}

template <typename InputT, typename OutputT, typename AccT>
class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp<SoftmaxParams<InputT, OutputT>> {
Status SoftmaxStaticSelection(const SoftmaxParams<InputT, OutputT>* params) {
auto status = SoftmaxWarpwiseStaticSelection<InputT, OutputT, AccT>(params);
if (!status.IsOK()) {
status = SoftmaxBlockwiseStaticSelection<InputT, OutputT, AccT>(params);
}
return status;
}

template <typename InputT, typename OutputT, typename AccT>
class SoftmaxTunableOp : public tunable::TunableOp<SoftmaxParams<InputT, OutputT>> {
public:
SoftmaxTunableOp() {
this->RegisterOp(SoftmaxStaticSelection<InputT, OutputT, AccT>);
this->RegisterOp(SoftmaxWarpwiseStaticSelection<InputT, OutputT, AccT>);
this->RegisterOp(SoftmaxBlockwiseStaticSelection<InputT, OutputT, AccT>);
this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 1>);
this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 2>);
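SoftmaxTunableOp above registers the static-selection paths alongside the vectorized blockwise variants, presumably so the untuned default matches the previous behaviour while tuning can still discover a faster configuration. The sketch below is a deliberately simplified, self-contained model of that register-then-select pattern; it is not the onnxruntime::rocm::tunable::TunableOp implementation, which, as I understand it, additionally caches the winning candidate per problem signature and handles unsupported-argument reporting.

#include <chrono>
#include <cstddef>
#include <functional>
#include <limits>
#include <utility>
#include <vector>

// Toy model of a tunable op: candidates are registered up front; when tuning
// is off the first registered candidate runs, and when tuning is on every
// candidate is timed once and the fastest is chosen.
template <typename ParamsT>
class ToyTunableOp {
 public:
  using Candidate = std::function<int(const ParamsT*)>;  // returns 0 on success

  void RegisterOp(Candidate c) { candidates_.push_back(std::move(c)); }

  // Assumes at least one candidate has been registered.
  int operator()(const ParamsT* params, bool tuning_enabled) {
    if (!tuning_enabled) {
      return candidates_.front()(params);  // default path: first registered op
    }
    std::size_t best = 0;
    double best_ms = std::numeric_limits<double>::max();
    for (std::size_t i = 0; i < candidates_.size(); ++i) {
      const auto t0 = std::chrono::steady_clock::now();
      if (candidates_[i](params) != 0) {
        continue;  // candidate reported unsupported arguments or failed
      }
      const std::chrono::duration<double, std::milli> dt =
          std::chrono::steady_clock::now() - t0;
      if (dt.count() < best_ms) {
        best_ms = dt.count();
        best = i;
      }
    }
    return candidates_[best](params);  // rerun the winner (a real op would cache it)
  }

 private:
  std::vector<Candidate> candidates_;
};

A SoftmaxTunableOp-style user would instantiate ToyTunableOp once per problem type, register its candidates in the constructor, and invoke it with the packed parameters, much like the op(&params) calls in softmax_impl.cu above.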