Commit

add tuning
PeixuanZuo committed Feb 9, 2023
1 parent ee0b632 commit d8d98f3
Showing 7 changed files with 209 additions and 87 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -124,6 +124,7 @@ set(provider_excluded_files
"math/softmax_common.cc"
"math/softmax_common.h"
"math/softmax.cc"
"math/softmax.h"
"nn/conv.cc"
"nn/conv.h"
"nn/conv_transpose.cc"
21 changes: 12 additions & 9 deletions onnxruntime/core/providers/rocm/math/softmax.cc
@@ -17,7 +17,8 @@ Status SoftMaxComputeHelper(
const T* X,
const TensorShape& input_shape,
T* Y,
int64_t axis) {
int64_t axis,
bool tuning) {
typedef typename ToHipType<T>::MappedType HipT;

int64_t N = input_shape.SizeToDimension(axis);
@@ -27,16 +28,16 @@

if (D <= 1024 && D * sizeof(T) <= 4096) {
return dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N), tuning);
}
return dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(N));
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N), tuning);
}

#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
template Status SoftMaxComputeHelper<T, false>(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, int64_t axis); \
template Status SoftMaxComputeHelper<T, true>(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, int64_t axis);
#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
template Status SoftMaxComputeHelper<T, false>(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, int64_t axis, bool tuning); \
template Status SoftMaxComputeHelper<T, true>(hipStream_t stream, const T* input, const TensorShape& shape, T* Y, int64_t axis, bool tuning);

SPECIALIZED_SOFTMAX_HELPER_IMPL(float)
// MIOpen double data type not supported
@@ -173,11 +174,13 @@ Status Softmax<T>::ComputeInternal(OpKernelContext* ctx) const {
if (log_softmax_) {
status = SoftMaxComputeHelper<T, true>(Stream(ctx), X_data, *compute_input_shape, Y_data,
is_transpose_required ? static_cast<int64_t>(rank) - 1
: static_cast<int64_t>(axis));
: static_cast<int64_t>(axis),
IsTunableOpEnabled());
} else {
status = SoftMaxComputeHelper<T, false>(Stream(ctx), X_data, *compute_input_shape, Y_data,
is_transpose_required ? static_cast<int64_t>(rank) - 1
: static_cast<int64_t>(axis));
: static_cast<int64_t>(axis),
IsTunableOpEnabled());
}

if (!status.IsOK())
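Note: the helper above keeps its existing size heuristic and simply forwards the new tuning flag to whichever dispatcher it picks. Below is a minimal host-side sketch of that heuristic only; SoftmaxPath and SelectSoftmaxPath are illustrative names, not the real ORT API.

// Minimal sketch of the dispatch rule in SoftMaxComputeHelper (illustrative names only).
#include <cstddef>
#include <cstdint>
#include <cstdio>

enum class SoftmaxPath { Warpwise, Blockwise };

// Mirrors the threshold above: the warpwise kernel covers rows of at most 1024 elements
// and at most 4096 bytes of input; anything larger falls back to the blockwise kernel.
SoftmaxPath SelectSoftmaxPath(std::int64_t elements_per_row, std::size_t element_size) {
  if (elements_per_row <= 1024 && elements_per_row * element_size <= 4096) {
    return SoftmaxPath::Warpwise;
  }
  return SoftmaxPath::Blockwise;
}

int main() {
  // 1024 floats per row is exactly 4096 bytes, so it still takes the warpwise path.
  std::printf("%d\n", SelectSoftmaxPath(1024, sizeof(float)) == SoftmaxPath::Warpwise);
  // 2048 floats per row exceeds both limits and goes blockwise.
  std::printf("%d\n", SelectSoftmaxPath(2048, sizeof(float)) == SoftmaxPath::Blockwise);
  return 0;
}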
70 changes: 70 additions & 0 deletions onnxruntime/core/providers/rocm/math/softmax.h
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/gsl.h"
#include "core/providers/rocm/rocm_kernel.h"

namespace onnxruntime {
namespace rocm {

template <typename T, bool is_log_softmax>
Status SoftMaxComputeHelper(
hipStream_t stream,
const T* input,
const TensorShape& shape,
T* Y,
int64_t axis,
bool tuning = false);

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src,
int softmax_elements, int softmax_elements_stride, int batch_count, bool tuning = false);

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input,
int softmax_elements, int input_stride, int output_stride, int batch_count, bool tuning = false);

template <typename T>
class Softmax final : public RocmKernel {
public:
Softmax(const OpKernelInfo& info) : RocmKernel{info} {
const auto& node = info.node();
opset_ = node.SinceVersion();

int64_t axis;
Status status = info.GetAttr<int64_t>("axis", &axis);

if (status.IsOK()) {
axis_ = gsl::narrow_cast<int>(axis);
} else {
if (opset_ < 13) {
axis_ = 1; // opset-12 and below, the default axis value is 1
} else {
axis_ = -1; // opset-13, the default axis value is -1
}
}

log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax";

// We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method
// TODO: Clean up the ROCMExecutionProvider interface to avoid this
rocm_ep_ = const_cast<ROCMExecutionProvider*>(
static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider()));
}

Status ComputeInternal(OpKernelContext* context) const override;

private:
int64_t axis_;
bool log_softmax_;
int opset_;

// We need access to the ROCM EP instance to get the rocblas handle to use
// for transposing (if applicable)
ROCMExecutionProvider* rocm_ep_;
};

} // namespace rocm
} // namespace onnxruntime
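Note: the constructor above resolves the softmax axis from the "axis" attribute when present and otherwise from the opset. A small standalone sketch of that rule follows; ResolveSoftmaxAxis is a hypothetical helper written only for illustration.

// Sketch of the axis-defaulting rule in the Softmax kernel constructor (hypothetical helper).
#include <cassert>
#include <cstdint>
#include <optional>

std::int64_t ResolveSoftmaxAxis(std::optional<std::int64_t> attr_axis, int opset) {
  if (attr_axis.has_value()) {
    return *attr_axis;  // an explicit attribute always wins
  }
  return opset < 13 ? 1 : -1;  // opset-12 and below default to 1, opset-13 defaults to -1
}

int main() {
  assert(ResolveSoftmaxAxis(std::nullopt, 12) == 1);
  assert(ResolveSoftmaxAxis(std::nullopt, 13) == -1);
  assert(ResolveSoftmaxAxis(2, 13) == 2);
  return 0;
}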
83 changes: 22 additions & 61 deletions onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -20,68 +20,34 @@
#include "hip/hip_runtime.h"

#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/math/softmax_warpwise_impl.cuh"
#include "core/providers/rocm/math/softmax_blockwise_impl.cuh"
#include "core/providers/rocm/math/softmax.h"
#include "core/providers/rocm/math/softmax_common.h"
#include "core/providers/rocm/math/softmax_tunable_op.cuh"

#include <limits>

namespace onnxruntime {
namespace rocm {

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return Status::OK();
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;

// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST;

// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = 1;
// use 256 threads per block to maximize gpu utilization
constexpr int threads_per_block = 256;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \
case L2E: \
softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax> \
<<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, \
softmax_elements_stride, softmax_elements); \
break;
LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1
LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2
LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4
LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8
LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16
LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32
LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64
LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
default:
break;
}
Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count, bool tuning) {
SoftmaxParams<input_t, output_t> params(stream, dst, src, softmax_elements, softmax_elements_stride,
softmax_elements_stride, batch_count, is_log_softmax);
if (tuning) {
static SoftmaxTunableOp<input_t, output_t, acc_t> op;
op.EnableTuning();
return op(&params);
}
return HIP_CALL(hipGetLastError());
return SoftmaxWarpwiseStaticSelection<input_t, output_t, acc_t>(&params);
}

#define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>( \
hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, \
int softmax_elements_stride, int batch_count); \
int softmax_elements_stride, int batch_count, bool tuning); \
template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>( \
hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, \
int softmax_elements_stride, int batch_count);
int softmax_elements_stride, int batch_count, bool tuning);

SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
@@ -90,29 +56,24 @@ SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
int input_stride, int output_stride, int batch_count) {
dim3 grid(batch_count);
constexpr int ILP = sizeof(float4) / sizeof(input_t);
dim3 block = SoftMax_getBlockSize(ILP, softmax_elements);
if (is_log_softmax) {
softmax_block_forward<ILP, input_t, acc_t, output_t, LogSoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
softmax_elements, input_stride, output_stride);
} else {
softmax_block_forward<ILP, input_t, acc_t, output_t, SoftMaxForwardEpilogue>
<<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
softmax_elements, input_stride, output_stride);
int input_stride, int output_stride, int batch_count, bool tuning) {
SoftmaxParams<input_t, output_t> params(stream, output, input, softmax_elements, input_stride,
output_stride, batch_count, is_log_softmax);
if (tuning) {
static SoftmaxTunableOp<input_t, output_t, acc_t> op;
op.EnableTuning();
return op(&params);
}
return HIP_CALL(hipGetLastError());
return SoftmaxBlockwiseStaticSelection<input_t, output_t, acc_t>(&params);
}

#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
hipStream_t stream, output_t * output, const input_t* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count); \
int input_stride, int output_stride, int batch_count, bool tuning); \
template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
hipStream_t stream, output_t * output, const input_t* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count);
int input_stride, int output_stride, int batch_count, bool tuning);

SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, half, float)
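Note: both dispatch functions now pack their arguments into a SoftmaxParams object and either run the static selection or a function-local SoftmaxTunableOp with tuning enabled. The sketch below models only the params-struct part of that pattern on the CPU; SoftmaxParamsSketch and ReferenceSoftmaxRows are illustrative stand-ins, not ORT types.

// Sharing one params struct lets every candidate implementation use the same signature.
#include <cmath>
#include <cstdio>
#include <vector>

template <typename InT, typename OutT>
struct SoftmaxParamsSketch {
  OutT* output;
  const InT* input;
  int softmax_elements;  // row length
  int input_stride;
  int output_stride;
  int batch_count;       // number of rows
  bool is_log_softmax;   // runtime flag rather than a template parameter
};

// A plain CPU reference stands in for the HIP kernels; all candidates share this shape.
template <typename InT, typename OutT>
int ReferenceSoftmaxRows(const SoftmaxParamsSketch<InT, OutT>* p) {
  for (int row = 0; row < p->batch_count; ++row) {
    const InT* in = p->input + row * p->input_stride;
    OutT* out = p->output + row * p->output_stride;
    double sum = 0.0;
    for (int i = 0; i < p->softmax_elements; ++i) sum += std::exp(static_cast<double>(in[i]));
    for (int i = 0; i < p->softmax_elements; ++i) {
      double e = std::exp(static_cast<double>(in[i]));
      out[i] = static_cast<OutT>(p->is_log_softmax ? std::log(e / sum) : e / sum);
    }
  }
  return 0;  // Status::OK stand-in
}

int main() {
  std::vector<float> in{0.f, 1.f, 2.f, 3.f}, out(4);
  SoftmaxParamsSketch<float, float> params{out.data(), in.data(), 4, 4, 4, 1, false};
  ReferenceSoftmaxRows(&params);
  std::printf("%.3f %.3f %.3f %.3f\n", out[0], out[1], out[2], out[3]);
  return 0;
}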
59 changes: 59 additions & 0 deletions onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh
@@ -33,6 +33,54 @@ Status SoftmaxBlockwiseOp(const SoftmaxParams<input_t, output_t>* params) {
return HIP_CALL(hipGetLastError());
}

template <typename input_t, typename output_t, typename acc_t>
Status SoftmaxWarpwiseStaticSelection(const SoftmaxParams<input_t, output_t>* params) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!(params->input_stride <= 1024 && params->input_stride * sizeof(input_t) <= 4096));
if (params->softmax_elements == 0) {
return Status::OK();
} else {
int log2_elements = log2_ceil(params->softmax_elements);
const int next_power_of_two = 1 << log2_elements;

// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < GPU_WARP_SIZE_HOST) ? next_power_of_two : GPU_WARP_SIZE_HOST;

// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = 1;
// use 256 threads per block to maximize gpu utilization
constexpr int threads_per_block = 256;

int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (params->batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \
case L2E: \
softmax_warp_forward<input_t, output_t, acc_t, L2E> \
<<<dim3(blocks), dim3(threads), 0, params->stream>>>(params->output, params->input, params->batch_count, \
params->input_stride, params->softmax_elements, params->is_log_softmax); \
break;
LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1
LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2
LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4
LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8
LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16
LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32
LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64
LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128
LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256
LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512
LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
default:
break;
}
}
return HIP_CALL(hipGetLastError());
}

template <typename input_t, typename output_t, typename acc_t>
Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<input_t, output_t>* params) {
dim3 grid(params->batch_count);
@@ -52,10 +100,21 @@ Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<input_t, output_t>* p
return HIP_CALL(hipGetLastError());
}

template <typename input_t, typename output_t, typename acc_t>
Status SoftmaxStaticSelection(const SoftmaxParams<input_t, output_t>* params) {
auto status = SoftmaxWarpwiseStaticSelection<input_t, output_t, acc_t>(params);
if (!status.IsOK()) {
status = SoftmaxBlockwiseStaticSelection<input_t, output_t, acc_t>(params);
}
return status;
}

template <typename input_t, typename output_t, typename acc_t>
class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp<SoftmaxParams<input_t, output_t>> {
public:
SoftmaxTunableOp() {
this->RegisterOp(SoftmaxStaticSelection<input_t, output_t, acc_t>);
this->RegisterOp(SoftmaxWarpwiseStaticSelection<input_t, output_t, acc_t>);
this->RegisterOp(SoftmaxBlockwiseStaticSelection<input_t, output_t, acc_t>);
this->RegisterOp(SoftmaxBlockwiseOp<input_t, output_t, acc_t, 1>);
this->RegisterOp(SoftmaxBlockwiseOp<input_t, output_t, acc_t, 2>);
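Note: SoftmaxTunableOp above registers the combined static selection, both standalone static selections, and the vectorized blockwise variants as candidates. A toy model of that tunable-op pattern is sketched below; TinyTunableOp is illustrative only, while the real framework in onnxruntime::rocm::tunable also handles candidates that report unsupported arguments (see TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF above).

// Toy tunable op: register candidates over a shared params type, time them once, cache the winner.
#include <chrono>
#include <cstdio>
#include <functional>
#include <vector>

struct Params { int n; };

class TinyTunableOp {
 public:
  using Candidate = std::function<int(const Params&)>;  // returns a Status-like int
  void RegisterOp(Candidate c) { candidates_.push_back(std::move(c)); }
  int operator()(const Params& p) {
    if (best_ < 0) best_ = FindFastest(p);  // tune on first use, then reuse the cached choice
    return candidates_[static_cast<size_t>(best_)](p);
  }

 private:
  int FindFastest(const Params& p) {
    int best = 0;
    double best_ms = 1e30;
    for (size_t i = 0; i < candidates_.size(); ++i) {
      auto t0 = std::chrono::steady_clock::now();
      candidates_[i](p);
      double ms = std::chrono::duration<double, std::milli>(
                      std::chrono::steady_clock::now() - t0).count();
      if (ms < best_ms) { best_ms = ms; best = static_cast<int>(i); }
    }
    return best;
  }
  std::vector<Candidate> candidates_;
  int best_ = -1;
};

int main() {
  TinyTunableOp op;
  op.RegisterOp([](const Params& p) { volatile long s = 0; for (int i = 0; i < p.n; ++i) s += i; return 0; });
  op.RegisterOp([](const Params& p) { volatile long s = 0; for (int i = 0; i < p.n / 2; ++i) s += i; return 0; });
  op(Params{1 << 20});  // first call measures both candidates and caches the faster one
  std::puts("tuned");
  return 0;
}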
33 changes: 16 additions & 17 deletions onnxruntime/core/providers/rocm/math/softmax_warpwise_impl.cuh
@@ -1,18 +1,18 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// The code below is mostly copied from Pytorch PersistentSoftmax.cuh

@@ -56,7 +56,6 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
}
}


// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
// The template arguments have the following meaning:
@@ -74,8 +73,8 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
// input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
// input_t=float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor.

template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax>
__global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batch_size, int stride, int element_count) {
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batch_size, int stride, int element_count, bool is_log_softmax) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < GPU_WARP_SIZE) ? next_power_of_two : GPU_WARP_SIZE;
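Note: the signature change above moves is_log_softmax from a template parameter to a runtime argument, so one softmax_warp_forward instantiation per log2_elements now serves both Softmax and LogSoftmax. A CPU-side sketch of the same signature change follows; both functions are hypothetical illustrations, not the HIP kernel.

// Before: one instantiation per (log2_elements, is_log_softmax) pair. After: a runtime flag.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

template <bool is_log_softmax>
void softmax_row_before(float* dst, const float* src, int n) {
  float mx = *std::max_element(src, src + n);
  float sum = 0.f;
  for (int i = 0; i < n; ++i) sum += std::exp(src[i] - mx);
  for (int i = 0; i < n; ++i)
    dst[i] = is_log_softmax ? src[i] - mx - std::log(sum) : std::exp(src[i] - mx) / sum;
}

// The flag becomes a plain argument, halving the number of instantiations to compile and register.
void softmax_row_after(float* dst, const float* src, int n, bool is_log_softmax) {
  float mx = *std::max_element(src, src + n);
  float sum = 0.f;
  for (int i = 0; i < n; ++i) sum += std::exp(src[i] - mx);
  for (int i = 0; i < n; ++i)
    dst[i] = is_log_softmax ? src[i] - mx - std::log(sum) : std::exp(src[i] - mx) / sum;
}

int main() {
  std::vector<float> x{1.f, 2.f, 3.f}, a(3), b(3);
  softmax_row_before<false>(a.data(), x.data(), 3);
  softmax_row_after(b.data(), x.data(), 3, /*is_log_softmax=*/false);
  for (int i = 0; i < 3; ++i) assert(std::abs(a[i] - b[i]) < 1e-6f);  // both paths agree
  return 0;
}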