diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
index 953a45e15b32e..16b3cf053b586 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h
@@ -714,12 +714,12 @@ Status ComputeSoftmaxWithRawMask(cudaStream_t stream,
   }

   if (use_persistent_softmax) {
-    dispatch_warpwise_softmax_forward(stream,
-                                      output,
-                                      persistent_softmax_workspace,
-                                      all_sequence_length,
-                                      all_sequence_length,
-                                      batch_size * num_heads * sequence_length);
+    return dispatch_warpwise_softmax_forward(stream,
+                                             output,
+                                             persistent_softmax_workspace,
+                                             all_sequence_length,
+                                             all_sequence_length,
+                                             batch_size * num_heads * sequence_length);
   }

   return CUDA_CALL(cudaGetLastError());
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index 1a5a9ac5d97b2..f6be2179dfdbf 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -337,11 +337,11 @@ Status ProcessLogits(const OrtValue& logits,  //
   const CudaT* X_data = is_reuse_logits_buffer ? logits_data : reinterpret_cast<const CudaT*>(next_token_logits.data());

-  dispatch_blockwise_softmax_forward(
+  ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(
       cuda_stream, Y_data, X_data, vocab_size,
       is_reuse_logits_buffer ? padded_vocab_size : vocab_size,
       vocab_size,
-      batch_size * num_beams);
+      batch_size * num_beams)));

 #ifdef DEBUG_GENERATION
   dumper->Print("next_token_scores after softmax", next_token_scores.data(),
                 batch_size, num_beams, vocab_size);
diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
index d82648890f94f..753aea9d38089 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
@@ -88,14 +88,14 @@ Status Sample(AllocatorPtr& allocator,
 #endif

   gsl::span<float>& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score;
-  dispatch_blockwise_softmax_forward(cuda_stream,
-                                     d_sorted_softmaxed_score.data(),
-                                     reinterpret_cast(d_sorted_score.data()),
-                                     parameters->vocab_size,
-                                     parameters->vocab_size,
-                                     parameters->vocab_size,
-                                     parameters->batch_size);
-
+  ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(cuda_stream,
+                                                          d_sorted_softmaxed_score.data(),
+                                                          reinterpret_cast(d_sorted_score.data()),
+                                                          parameters->vocab_size,
+                                                          parameters->vocab_size,
+                                                          parameters->vocab_size,
+                                                          parameters->batch_size)));
+
 #ifdef DEBUG_GENERATION
   dumper->Print("d_sorted_softmaxed_score_buffer",
                 d_sorted_softmaxed_score.data(),
@@ -122,13 +122,13 @@ Status Sample(AllocatorPtr& allocator,
 #endif

   gsl::span<float>& d_softmaxed_score = sampling_state->d_softmaxed_score;
-  dispatch_blockwise_softmax_forward(cuda_stream,
-                                     d_softmaxed_score.data(),
-                                     reinterpret_cast(next_token_scores.data()),
-                                     parameters->vocab_size,
-                                     parameters->vocab_size,
-                                     parameters->vocab_size,
-                                     parameters->batch_size);
+  ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(cuda_stream,
+                                                          d_softmaxed_score.data(),
+                                                          reinterpret_cast(next_token_scores.data()),
+                                                          parameters->vocab_size,
+                                                          parameters->vocab_size,
+                                                          parameters->vocab_size,
+                                                          parameters->batch_size)));

 #ifdef DEBUG_GENERATION
   dumper->Print("d_softmaxed_score_buffer",
diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
index 27ecdf253ecdb..7c99fc05ec9ee 100644
--- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
+++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
@@ -513,12 +513,12 @@ Status ComputeSoftmaxWithRawMask(hipStream_t stream,
   }

   if (use_persistent_softmax) {
-    dispatch_warpwise_softmax_forward(stream,
-                                      output,
-                                      persistent_softmax_workspace,
-                                      all_sequence_length,
-                                      all_sequence_length,
-                                      batch_size * num_heads * sequence_length);
+    return dispatch_warpwise_softmax_forward(stream,
+                                             output,
+                                             persistent_softmax_workspace,
+                                             all_sequence_length,
+                                             all_sequence_length,
+                                             batch_size * num_heads * sequence_length);
   }

   return HIP_CALL(hipPeekAtLastError());
diff --git a/onnxruntime/core/providers/cuda/math/softmax.cc b/onnxruntime/core/providers/cuda/math/softmax.cc
index dc1830a192945..5047a70242a5c 100644
--- a/onnxruntime/core/providers/cuda/math/softmax.cc
+++ b/onnxruntime/core/providers/cuda/math/softmax.cc
@@ -26,15 +26,12 @@ Status SoftMaxComputeHelper(
   auto X_data = reinterpret_cast<const CudaT*>(X);

   if (D <= 1024 && D * sizeof(T) <= 4096) {
-    dispatch_warpwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
+    return dispatch_warpwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
         stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
-  } else {
-    dispatch_blockwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
-        stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
-        gsl::narrow_cast<int>(N));
   }
-
-  return Status::OK();
+  return dispatch_blockwise_softmax_forward<CudaT, CudaT, AccumulationType_t<CudaT>, is_log_softmax>(
+      stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
+      gsl::narrow_cast<int>(N));
 }

 #define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
diff --git a/onnxruntime/core/providers/cuda/math/softmax.h b/onnxruntime/core/providers/cuda/math/softmax.h
index b2528bb0c8855..b66ad32517458 100644
--- a/onnxruntime/core/providers/cuda/math/softmax.h
+++ b/onnxruntime/core/providers/cuda/math/softmax.h
@@ -18,12 +18,12 @@ Status SoftMaxComputeHelper(
     int64_t axis);

 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src,
-                                       int softmax_elements, int softmax_elements_stride, int batch_count);
+Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src,
+                                         int softmax_elements, int softmax_elements_stride, int batch_count);

 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input,
-                                        int softmax_elements, int input_stride, int output_stride, int batch_count);
+Status dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input,
+                                          int softmax_elements, int input_stride, int output_stride, int batch_count);

 template <typename T>
 class Softmax final : public CudaKernel {
diff --git a/onnxruntime/core/providers/cuda/math/softmax_impl.cu b/onnxruntime/core/providers/cuda/math/softmax_impl.cu
index dafc3a17900ac..4c097f714beb9 100644
--- a/onnxruntime/core/providers/cuda/math/softmax_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/softmax_impl.cu
@@ -29,9 +29,9 @@ namespace onnxruntime {
 namespace cuda {

 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
+Status dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
   if (softmax_elements == 0) {
-    return;
+    return Status::OK();
   } else {
     int log2_elements = log2_ceil(softmax_elements);
     const int next_power_of_two = 1 << log2_elements;
@@ -99,15 +99,16 @@ void dispatch_warpwise_softmax_forward(cudaStream_t stream, output_t* dst, const
       break;
     }
   }
+  return CUDA_CALL(cudaGetLastError());
 }

-#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t)                                     \
-  template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(cudaStream_t stream, output_t * dst, \
-                                                                                   const input_t* src, int softmax_elements, \
-                                                                                   int softmax_elements_stride, int batch_count); \
-  template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(cudaStream_t stream, output_t * dst, \
-                                                                                  const input_t* src, int softmax_elements, \
-                                                                                  int softmax_elements_stride, int batch_count);
+#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t)                                     \
+  template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(cudaStream_t stream, output_t * dst, \
+                                                                                     const input_t* src, int softmax_elements, \
+                                                                                     int softmax_elements_stride, int batch_count); \
+  template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(cudaStream_t stream, output_t * dst, \
+                                                                                    const input_t* src, int softmax_elements, \
+                                                                                    int softmax_elements_stride, int batch_count);

 SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(float, float, float)
 SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(half, half, float)
@@ -115,8 +116,8 @@ SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(double, double, double)
 SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)

 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input, int softmax_elements,
-                                        int input_stride, int output_stride, int batch_count) {
+Status dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, const input_t* input, int softmax_elements,
+                                          int input_stride, int output_stride, int batch_count) {
   dim3 grid(batch_count);
   constexpr int ILP = sizeof(float4) / sizeof(input_t);
   dim3 block = SoftMax_getBlockSize(ILP, softmax_elements);
@@ -129,13 +130,14 @@ void dispatch_blockwise_softmax_forward(cudaStream_t stream, output_t* output, c
         <<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
                                                            softmax_elements, input_stride, output_stride);
   }
+  return CUDA_CALL(cudaGetLastError());
 }

 #define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t)                      \
-  template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>(      \
+  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>(    \
       cudaStream_t stream, output_t * output, const input_t* src, int softmax_elements,   \
       int input_stride, int output_stride, int batch_count);                              \
-  template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>(       \
+  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>(     \
       cudaStream_t stream, output_t * output, const input_t* src, int softmax_elements,   \
       int input_stride, int output_stride, int batch_count);

diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc
index 275c8ad3978f5..22bcaecf34f65 100644
--- a/onnxruntime/core/providers/rocm/math/softmax.cc
+++ b/onnxruntime/core/providers/rocm/math/softmax.cc
@@ -26,15 +26,12 @@ Status SoftMaxComputeHelper(
   auto X_data = reinterpret_cast<const HipT*>(X);

   if (D <= 1024 && D * sizeof(T) <= 4096) {
-    dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
+    return dispatch_warpwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
        stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(N));
-  } else {
-    dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
-        stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
-        gsl::narrow_cast<int>(N));
   }
-
-  return Status::OK();
+  return dispatch_blockwise_softmax_forward<HipT, HipT, AccumulationType_t<HipT>, is_log_softmax>(
+      stream, Y_data, X_data, gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D), gsl::narrow_cast<int>(D),
+      gsl::narrow_cast<int>(N));
 }

 #define SPECIALIZED_SOFTMAX_HELPER_IMPL(T) \
diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
index f5a26ef045881..d37235acfa0e1 100644
--- a/onnxruntime/core/providers/rocm/math/softmax_impl.cu
+++ b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -30,9 +30,9 @@ namespace onnxruntime {
 namespace rocm {

 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
+Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
   if (softmax_elements == 0) {
-    return;
+    return Status::OK();
   } else {
     int log2_elements = log2_ceil(softmax_elements);
     const int next_power_of_two = 1 << log2_elements;
@@ -88,11 +88,12 @@ void dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const
       break;
     }
   }
+  return HIP_CALL(hipGetLastError());
 }

 #define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \
-template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
-template void dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);
+template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); \
+template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>(hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count);

 SPECIALIZED_SOFTMAX_IMPL(float, float, float)
 SPECIALIZED_SOFTMAX_IMPL(half, half, float)
@@ -100,7 +101,7 @@ SPECIALIZED_SOFTMAX_IMPL(double, double, double)
 SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)

 template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
+Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
                                         int input_stride, int output_stride, int batch_count) {
   dim3 grid(batch_count);
   constexpr int ILP = sizeof(float4) / sizeof(input_t);
@@ -114,14 +115,15 @@ void dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, co
       <<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
                                                          softmax_elements, input_stride, output_stride);
   }
+  return HIP_CALL(hipGetLastError());
 }

-#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t)                     \
-  template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>(     \
-      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements,   \
-      int input_stride, int output_stride, int batch_count);                             \
-  template void dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>(      \
-      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements,   \
+#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t)                       \
+  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>(     \
+      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements,     \
+      int input_stride, int output_stride, int batch_count);                               \
+  template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>(      \
+      hipStream_t stream, output_t * output, const input_t* src, int softmax_elements,     \
       int input_stride, int output_stride, int batch_count);

 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
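Note on the call-site pattern in the hunks above: ORT_RETURN_IF_ERROR is a function-like macro, and the preprocessor splits a macro argument on top-level commas, so an expression containing a template argument list has to be wrapped in an extra pair of parentheses to stay a single argument; that is why the new calls read ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(...))). The snippet below is a minimal, self-contained sketch of that mechanism only; RETURN_IF_ERROR, dispatch_softmax, and Run are hypothetical stand-ins, not ONNX Runtime code.

// Minimal sketch (hypothetical stand-ins, not ONNX Runtime code).
#include <iostream>

struct Status {
  bool ok = true;
  static Status OK() { return {}; }
  bool IsOK() const { return ok; }
};

// Stand-in for ORT_RETURN_IF_ERROR: evaluate the expression once, bail out on failure.
#define RETURN_IF_ERROR(expr)            \
  do {                                   \
    Status _status = (expr);             \
    if (!_status.IsOK()) return _status; \
  } while (false)

// Stand-in for a Status-returning dispatcher in the spirit of the change above.
template <typename input_t, typename output_t>
Status dispatch_softmax(const input_t* /*src*/, output_t* /*dst*/, int batch_count) {
  return batch_count >= 0 ? Status::OK() : Status{false};
}

Status Run(const float* src, float* dst, int batch_count) {
  // The inner parentheses keep "dispatch_softmax<float, float>(...)" as one macro
  // argument; without them the comma in "<float, float>" would split it into two
  // arguments and the macro expansion would not compile.
  RETURN_IF_ERROR((dispatch_softmax<float, float>(src, dst, batch_count)));
  return Status::OK();
}

int main() {
  float x[2] = {0.5f, 1.5f};
  float y[2] = {};
  std::cout << (Run(x, y, 2).IsOK() ? "ok" : "error") << "\n";
  return 0;
}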