Skip to content

Commit

Permalink
add TuningContext
Browse files Browse the repository at this point in the history
  • Loading branch information
PeixuanZuo committed Feb 10, 2023
1 parent 10a7350 commit 0b6c029
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/rocm/math/softmax_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ namespace onnxruntime {
namespace rocm {

template <typename InputT, typename OutputT>
struct SoftmaxParams : onnxruntime::rocm::tunable::OpParams {
SoftmaxParams(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements,
int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: OpParams(stream), output(output), input(input), softmax_elements(softmax_elements), input_stride(input_stride), output_stride(output_stride), batch_count(batch_count), is_log_softmax(is_log_softmax) {}
struct SoftmaxParams : tunable::OpParams {
SoftmaxParams(tunable::RocmTuningContext* tuning_ctx, hipStream_t stream, OutputT* output, const InputT* input,
int softmax_elements, int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: OpParams(tuning_ctx, stream), output(output), input(input), softmax_elements(softmax_elements), input_stride(input_stride), output_stride(output_stride), batch_count(batch_count), is_log_softmax(is_log_softmax) {}

std::string Signature() const override {
std::string sig = std::to_string(batch_count) + "_" + std::to_string(softmax_elements);
Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/python/tools/kernel_explorer/kernels/rocm/softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class SoftmaxBlockwise : public IKernelExplorer {
public:
SoftmaxBlockwise(DeviceArray& output, DeviceArray& input, int softmax_elements,
int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: params_(this->Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {
type_string_ = "SoftmaxBlockwise_" + std::to_string(VecSize);
}
Expand Down Expand Up @@ -55,7 +55,7 @@ class SoftmaxBlockwiseStaticSelection : public IKernelExplorer {
public:
SoftmaxBlockwiseStaticSelection(DeviceArray& output, DeviceArray& input, int softmax_elements,
int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: params_(this->Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {}

void Run() override {
Expand All @@ -80,9 +80,9 @@ class SoftmaxTunable : public IKernelExplorer {
public:
SoftmaxTunable(DeviceArray& output, DeviceArray& input, int softmax_elements,
int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: params_(this->Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {
op_.EnableTuning();
params_.TuningContext()->EnableTunableOp();
}

void Run() override {
Expand All @@ -109,7 +109,7 @@ class CKSoftmax : public IKernelExplorer {
public:
CKSoftmax(DeviceArray& output, DeviceArray& input, int softmax_elements,
int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: params_(this->Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
softmax_elements, input_stride, output_stride, batch_count, is_log_softmax) {
for (auto&& [type_string, op] : rocm::GetCKSoftmaxTypeStringAndOps<T, T, rocm::AccumulationType_t<T>>()) {
type_strings_.emplace_back(std::move(type_string));
Expand Down

0 comments on commit 0b6c029

Please sign in to comment.