Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
Signed-off-by: Liqun Fu <[email protected]>
  • Loading branch information
liqunfu committed Nov 16, 2024
1 parent ada93ea commit e73eaf4
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_;

if constexpr (std::is_same_v<T, MLFloat16>) {
if (skip == nullptr) {
std::cout << "missing skip";
}
const int64_t total_data_size = input->Shape().Size();

AllocatorPtr alloc;
Expand All @@ -188,19 +185,19 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {

const size_t num_elems = static_cast<size_t>(hidden_size);

input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(total_data_size));
MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size);
input_data_f = input_fp32.get();

output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(total_data_size));
output_data_f = output_fp32.get();

skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(total_data_size));
skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get();

if (skip_data) {
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, skip_size);
MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), skip_size);
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(skip_size));
MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), static_cast<size_t>(skip_size));
skip_data_f = skip_fp32.get();
} else if (prepacked_skip_fp32_data_) {
skip_data_f = prepacked_skip_fp32_data_.get();
Expand Down Expand Up @@ -237,9 +234,9 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f);
},
0);
MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size);
MlasConvertFloatToHalfBuffer(output_data_f, output_data, static_cast<size_t>(total_data_size));
if (skip_input_bias_add_output_data != nullptr)
MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size);
MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, static_cast<size_t>(total_data_size));
} else {
concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
Expand Down

0 comments on commit e73eaf4

Please sign in to comment.