-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ee0b632
commit d8d98f3
Showing
7 changed files
with
209 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/common/gsl.h" | ||
#include "core/providers/rocm/rocm_kernel.h" | ||
|
||
namespace onnxruntime { | ||
namespace rocm { | ||
|
||
template <typename T, bool is_log_softmax> | ||
Status SoftMaxComputeHelper( | ||
hipStream_t stream, | ||
const T* input, | ||
const TensorShape& shape, | ||
T* Y, | ||
int64_t axis, | ||
bool tuning = false); | ||
|
||
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax> | ||
Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, | ||
int softmax_elements, int softmax_elements_stride, int batch_count, bool tuning = false); | ||
|
||
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax> | ||
Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, | ||
int softmax_elements, int input_stride, int output_stride, int batch_count, bool tuning = false); | ||
|
||
template <typename T> | ||
class Softmax final : public RocmKernel { | ||
public: | ||
Softmax(const OpKernelInfo& info) : RocmKernel{info} { | ||
const auto& node = info.node(); | ||
opset_ = node.SinceVersion(); | ||
|
||
int64_t axis; | ||
Status status = info.GetAttr<int64_t>("axis", &axis); | ||
|
||
if (status.IsOK()) { | ||
axis_ = gsl::narrow_cast<int>(axis); | ||
} else { | ||
if (opset_ < 13) { | ||
axis_ = 1; // opset-12 and below, the default axis value is 1 | ||
} else { | ||
axis_ = -1; // opset-13, the default axis value is -1 | ||
} | ||
} | ||
|
||
log_softmax_ = info.GetKernelDef().OpName() == "LogSoftmax"; | ||
|
||
// We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method | ||
// TODO: Clean up the ROCMExecutionProvider interface to avoid this | ||
rocm_ep_ = const_cast<ROCMExecutionProvider*>( | ||
static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider())); | ||
} | ||
|
||
Status ComputeInternal(OpKernelContext* context) const override; | ||
|
||
private: | ||
int64_t axis_; | ||
bool log_softmax_; | ||
int opset_; | ||
|
||
// We need to access to the ROCM EP instance to get the rocblas handle to use | ||
// for transposing(if applicable) | ||
ROCMExecutionProvider* rocm_ep_; | ||
}; | ||
|
||
} // namespace rocm | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.