Change the return type of softmax function to Status (#14559)
### Description
Change the return type of the softmax dispatch functions (`dispatch_warpwise_softmax_forward` and `dispatch_blockwise_softmax_forward`) from `void` to `Status`.

### Motivation and Context
The softmax functions will call TunableOp, which returns a `Status`, so the `Status` needs to be propagated from the inner functions to the outer callers.
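
Roughly, the change follows the pattern sketched below. This is a minimal illustration rather than the literal source: it assumes the ONNX Runtime headers defining `Status`, `CUDA_CALL`, and `ORT_RETURN_IF_ERROR` are on the include path, and `dispatch_softmax_forward_example`, `example_softmax_kernel`, and `caller_example` are placeholder names.

```cpp
// Sketch of the new error-propagation pattern (illustrative, not the actual source).
// Assumes ONNX Runtime CUDA provider headers providing Status, CUDA_CALL, ORT_RETURN_IF_ERROR.
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
Status dispatch_softmax_forward_example(cudaStream_t stream, output_t* dst, const input_t* src,
                                        int softmax_elements, int batch_count) {
  if (softmax_elements == 0) {
    return Status::OK();  // nothing to launch
  }
  // example_softmax_kernel<input_t, output_t, acc_t, is_log_softmax>
  //     <<<batch_count, 256, 0, stream>>>(dst, src, softmax_elements);
  // Convert any launch error into a Status instead of silently returning void.
  return CUDA_CALL(cudaGetLastError());
}

// Callers now propagate the Status instead of ignoring it. The extra parentheses keep
// the commas in the template argument list from being parsed as separate macro arguments.
template <typename T>
Status caller_example(cudaStream_t stream, T* dst, const T* src, int n, int batch) {
  ORT_RETURN_IF_ERROR((dispatch_softmax_forward_example<T, T, float, false>(
      stream, dst, src, n, batch)));
  return Status::OK();
}
```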
PeixuanZuo authored Feb 6, 2023
1 parent 3d75187 commit 4bb95d7
Showing 9 changed files with 69 additions and 71 deletions.
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
@@ -714,12 +714,12 @@ Status ComputeSoftmaxWithRawMask(cudaStream_t stream,
}

if (use_persistent_softmax) {
dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
output,
persistent_softmax_workspace,
all_sequence_length,
all_sequence_length,
batch_size * num_heads * sequence_length);
return dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
output,
persistent_softmax_workspace,
all_sequence_length,
all_sequence_length,
batch_size * num_heads * sequence_length);
}

return CUDA_CALL(cudaGetLastError());
@@ -337,11 +337,11 @@ Status ProcessLogits(const OrtValue& logits, //

const CudaT* X_data = is_reuse_logits_buffer ? logits_data : reinterpret_cast<const CudaT*>(next_token_logits.data());

dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
cuda_stream, Y_data, X_data, vocab_size,
is_reuse_logits_buffer ? padded_vocab_size : vocab_size,
vocab_size,
batch_size * num_beams);
batch_size * num_beams)));

#ifdef DEBUG_GENERATION
dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size);
30 changes: 15 additions & 15 deletions onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
@@ -88,14 +88,14 @@ Status Sample(AllocatorPtr& allocator,
#endif

gsl::span<float>& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score;
dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
d_sorted_softmaxed_score.data(),
reinterpret_cast<CudaT*>(d_sorted_score.data()),
parameters->vocab_size,
parameters->vocab_size,
parameters->vocab_size,
parameters->batch_size);

ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
d_sorted_softmaxed_score.data(),
reinterpret_cast<CudaT*>(d_sorted_score.data()),
parameters->vocab_size,
parameters->vocab_size,
parameters->vocab_size,
parameters->batch_size)));
#ifdef DEBUG_GENERATION
dumper->Print("d_sorted_softmaxed_score_buffer",
d_sorted_softmaxed_score.data(),
@@ -122,13 +122,13 @@ Status Sample(AllocatorPtr& allocator,
#endif

gsl::span<float>& d_softmaxed_score = sampling_state->d_softmaxed_score;
dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
d_softmaxed_score.data(),
reinterpret_cast<CudaT*>(next_token_scores.data()),
parameters->vocab_size,
parameters->vocab_size,
parameters->vocab_size,
parameters->batch_size);
ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
d_softmaxed_score.data(),
reinterpret_cast<CudaT*>(next_token_scores.data()),
parameters->vocab_size,
parameters->vocab_size,
parameters->vocab_size,
parameters->batch_size)));

#ifdef DEBUG_GENERATION
dumper->Print("d_softmaxed_score_buffer",
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
@@ -513,12 +513,12 @@ Status ComputeSoftmaxWithRawMask(hipStream_t stream,
}

if (use_persistent_softmax) {
dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
output,
persistent_softmax_workspace,
all_sequence_length,
all_sequence_length,
batch_size * num_heads * sequence_length);
return dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
output,
persistent_softmax_workspace,
all_sequence_length,
all_sequence_length,
batch_size * num_heads * sequence_length);
}

return HIP_CALL(hipPeekAtLastError());
11 changes: 4 additions & 7 deletions onnxruntime/core/providers/cuda/math/softmax.cc
@@ -26,15 +26,12 @@ Status SoftMaxComputeHelper(
auto X_data = reinterpret_cast<const CudaT*>(X);

if (D <= 1024 && D * sizeof(T) <= 4096) {
dispatch_warpwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
return dispatch_warpwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
} else {
dispatch_blockwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(N));
}

return Status::OK();
return dispatch_blockwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(N));
}

#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cuda/math/softmax.h
@@ -18,12 +18,12 @@ Status SoftMaxComputeHelper(
int64_t axis);

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src,
int softmax_elements, int softmax_elements_stride, int batch_count);
Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src,
int softmax_elements, int softmax_elements_stride, int batch_count);

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input,
int softmax_elements, int input_stride, int output_stride, int batch_count);
Status dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input,
int softmax_elements, int input_stride, int output_stride, int batch_count);

template <typename T>
class Softmax final : public CudaKernel {
28 changes: 15 additions & 13 deletions onnxruntime/core/providers/cuda/math/softmax_impl.cu
@@ -29,9 +29,9 @@ namespace onnxruntime {
namespace cuda {

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return;
return Status::OK();
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
@@ -99,24 +99,25 @@ void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const
break;
}
}
return CUDA_CALL(cudaGetLastError());
}

#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(cudaStream_t stream, output_t * dst, \
const input_t* src, int softmax_elements, \
int softmax_elements_stride, int batch_count); \
template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(cudaStream_t stream, output_t * dst, \
const input_t* src, int softmax_elements, \
int softmax_elements_stride, int batch_count);
#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(cudaStream_t stream, output_t * dst, \
const input_t* src, int softmax_elements, \
int softmax_elements_stride, int batch_count); \
template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(cudaStream_t stream, output_t * dst, \
const input_t* src, int softmax_elements, \
int softmax_elements_stride, int batch_count);

SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(double, double, double)
SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input, int softmax_elements,
int input_stride, int output_stride, int batch_count) {
Status dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input, int softmax_elements,
int input_stride, int output_stride, int batch_count) {
dim3 grid(batch_count);
constexpr int ILP = sizeof(float4) / sizeof(input_t);
dim3 block = SoftMax_getBlockSize(ILP, softmax_elements);
@@ -129,13 +130,14 @@ void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, c
<<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
softmax_elements, input_stride, output_stride);
}
return CUDA_CALL(cudaGetLastError());
}

#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
cudaStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
int input_stride, int output_stride, int batch_count); \
template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
cudaStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
int input_stride, int output_stride, int batch_count);

11 changes: 4 additions & 7 deletions onnxruntime/core/providers/rocm/math/softmax.cc
@@ -26,15 +26,12 @@ Status SoftMaxComputeHelper(
auto X_data = reinterpret_cast<const HipT*>(X);

if (D <= 1024 && D * sizeof(T) <= 4096) {
dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
return dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
} else {
dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(N));
}

return Status::OK();
return dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
gsl::narrow_cast<int>(N));
}

#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
24 changes: 13 additions & 11 deletions onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -30,9 +30,9 @@ namespace onnxruntime {
namespace rocm {

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return;
return Status::OK();
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
@@ -88,19 +88,20 @@ void dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const
break;
}
}
return HIP_CALL(hipGetLastError());
}

#define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);
template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);

SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_IMPL(double, double, double)
SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)

template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
int input_stride, int output_stride, int batch_count) {
dim3 grid(batch_count);
constexpr int ILP = sizeof(float4) / sizeof(input_t);
@@ -114,14 +115,15 @@ void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, co
<<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
softmax_elements, input_stride, output_stride);
}
return HIP_CALL(hipGetLastError());
}

#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
int input_stride, int output_stride, int batch_count); \
template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
int input_stride, int output_stride, int batch_count); \
template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
hipStream_t stream, output_t * output, const input_t* src, int softmax_elements, \
int input_stride, int output_stride, int batch_count);

SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
