Refactor SkipLayerNorm and handle beta properly (#22862)
Signed-off-by: Liqun Fu <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2 people authored and guschmue committed Dec 2, 2024
1 parent af896cb commit d9dd611
Showing 1 changed file with 78 additions and 108 deletions.
186 changes: 78 additions & 108 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
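Reader's note (not part of the commit): SkipLayerNorm fuses the residual add with layer normalization. Below is a minimal per-row sketch of the math the kernel implements, reconstructed from the diff that follows; the free-function form and all names are illustrative, not the kernel's actual interface.

#include <cmath>

// Sketch only: normalize one row of hidden_size elements.
// simplified == true corresponds to SkipSimplifiedLayerNormalization (RMS-style, no mean/beta).
void SkipLayerNormRowSketch(const float* x, const float* skip, const float* bias,
                            const float* gamma, const float* beta, float epsilon,
                            bool simplified, int hidden_size, float* out) {
  float mean = 0.0f, mean_square = 0.0f;
  for (int h = 0; h < hidden_size; ++h) {
    float v = x[h] + skip[h] + (bias ? bias[h] : 0.0f);  // residual add plus optional bias
    out[h] = v;
    mean += v;
    mean_square += v * v;
  }
  mean /= hidden_size;
  const float denom = simplified
                          ? std::sqrt(mean_square / hidden_size + epsilon)
                          : std::sqrt(mean_square / hidden_size - mean * mean + epsilon);
  for (int h = 0; h < hidden_size; ++h) {
    const float centered = simplified ? out[h] : (out[h] - mean);
    out[h] = centered / denom * gamma[h] + ((beta && !simplified) ? beta[h] : 0.0f);
  }
}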
@@ -96,79 +96,6 @@ void ComputeJob(
}
}

void ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const float* prepacked_skip_fp32_data,
const float* gamma_float_ptr,
const float* beta_float_ptr,
const float* bias_float_ptr,
float* output_float_ptr,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
MLFloat16* output_data,
MLFloat16* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
auto offset = task_idx * hidden_size;
const MLFloat16* p_input = input_data + offset;
MLFloat16* p_output = output_data + offset;
MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;

float mean(0.0f);
float mean_square(0.0f);
const size_t num_elems = static_cast<size_t>(hidden_size);

IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);

IAllocatorUniquePtr<float> skip_float_uptr = nullptr;
if (prepacked_skip_fp32_data == nullptr && skip_data) {
const MLFloat16* p_skip = skip_data + (offset % skip_size);
skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
}

const float* input_float_ptr = input_float_uptr.get();
const float* skip_float_ptr = prepacked_skip_fp32_data ? prepacked_skip_fp32_data : skip_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
float val = input_float_ptr[h] + skip_float_ptr[h];

if (bias_float_ptr) {
val += bias_float_ptr[h];
}

output_float_ptr[h] = val;
mean += val;
mean_square += val * val;
}

if (nullptr != p_skip_input_bias_add_output) {
MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems);
}

mean = mean / hidden_size;
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
} else if (nullptr == beta_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
}
}

MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems);
}

void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
auto tensor_data_ptr = tensor.Data<MLFloat16>();
@@ -200,8 +127,8 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const Tensor* input = p_ctx->Input<Tensor>(0);
const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
const Tensor* beta = simplified ? nullptr : (prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3));
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(simplified ? 3 : 4);
Tensor* output = p_ctx->Output(0, input->Shape());
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
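Reader's note: the replaced lines above are the behavioral core of this commit. SkipSimplifiedLayerNormalization declares only four inputs, so when `simplified` is true node input 3 is the bias and there is no beta; for the full SkipLayerNormalization, input 3 is beta and input 4 is bias. A hedged illustration of that mapping (the helper names are mine, not from the source):

// Hypothetical helpers, for illustration only.
// simplified (SkipSimplifiedLayerNormalization): 0=input, 1=skip, 2=gamma, 3=bias
// full       (SkipLayerNormalization):           0=input, 1=skip, 2=gamma, 3=beta, 4=bias
constexpr int BetaInputIndex(bool simplified) { return simplified ? -1 : 3; }  // -1: no beta input
constexpr int BiasInputIndex(bool simplified) { return simplified ? 3 : 4; }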
Expand Down Expand Up @@ -232,56 +159,93 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {

// For inferencing, we support one more optional output which is the sum of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_;

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

IAllocatorUniquePtr<float> output_fp32;
IAllocatorUniquePtr<float> gamma_fp32;
IAllocatorUniquePtr<float> beta_fp32;
IAllocatorUniquePtr<float> bias_fp32;

if constexpr (std::is_same_v<T, MLFloat16>) {
const size_t total_data_size = static_cast<size_t>(input->Shape().Size());

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

IAllocatorUniquePtr<float> input_fp32;
IAllocatorUniquePtr<float> output_fp32;
IAllocatorUniquePtr<float> skip_input_bias_add_output_fp32;
IAllocatorUniquePtr<float> skip_fp32;
IAllocatorUniquePtr<float> gamma_fp32;
IAllocatorUniquePtr<float> beta_fp32;
IAllocatorUniquePtr<float> bias_fp32;

const float* input_data_f = nullptr;
const float* skip_data_f = nullptr;
const float* gamma_data_f = nullptr;
const float* beta_data_f = nullptr;
const float* bias_data_f = nullptr;
float* output_data_f = nullptr;
float* skip_input_bias_add_output_data_f = nullptr;

const size_t num_elems = static_cast<size_t>(hidden_size);

output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size);
input_data_f = input_fp32.get();

output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
output_data_f = output_fp32.get();

skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get();

if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) {
if (skip_data) {
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, static_cast<size_t>(skip_size));
MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), static_cast<size_t>(skip_size));
skip_data_f = skip_fp32.get();
} else if (prepacked_skip_fp32_data_) {
skip_data_f = prepacked_skip_fp32_data_.get();
}

if (gamma_data) {
gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
gamma_data_f = gamma_fp32.get();
} else if (prepacked_gamma_fp32_data_) {
gamma_data_f = prepacked_gamma_fp32_data_.get();
}

if (prepacked_beta_fp32_data_ == nullptr && beta_data) {
if (beta_data) {
beta_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems);
beta_data_f = beta_fp32.get();
} else if (prepacked_beta_fp32_data_) {
beta_data_f = prepacked_beta_fp32_data_.get();
}

if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
if (bias_data) {
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
bias_data_f = bias_fp32.get();
} else if (prepacked_bias_fp32_data_) {
bias_data_f = prepacked_bias_fp32_data_.get();
}
}

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
if constexpr (std::is_same_v<T, MLFloat16>) {
ComputeJob(input_data, skip_data,
prepacked_skip_fp32_data_.get(),
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
output_fp32.get(),
task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data, alloc);
} else {
concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
ComputeJob(input_data_f, skip_data_f, gamma_data_f, beta_data_f, bias_data_f, task_idx, hidden_size, skip_size,
epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f);
},
0);
MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size);
if (skip_input_bias_add_output_data != nullptr)
MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size);
} else {
concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size,
epsilon_, simplified, output_data, skip_input_bias_add_output_data);
}
},
0);
},
0);
}

return Status::OK();
}
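Reader's note on the MLFloat16 path after this refactor: instead of converting each row inside a dedicated fp16 ComputeJob (the overload deleted at the top of the diff), Compute now up-converts whole tensors once with MlasConvertHalfToFloatBuffer, runs the float ComputeJob across rows, and down-converts the results once with MlasConvertFloatToHalfBuffer. A minimal sketch of that round-trip under those assumptions; the wrapper and the row-kernel declaration are illustrative, not the committed code:

// Illustrative only. Assumes some float row kernel exists; conversion calls are used as in the diff.
static void NormalizeRowFp32(const float* row, float* out, int hidden_size);

void RunFp16ViaFp32Sketch(const MLFloat16* in, MLFloat16* out,
                          int rows, int hidden_size, AllocatorPtr alloc) {
  const size_t total = static_cast<size_t>(rows) * static_cast<size_t>(hidden_size);
  auto in_f = IAllocator::MakeUniquePtr<float>(alloc, total);
  auto out_f = IAllocator::MakeUniquePtr<float>(alloc, total);
  MlasConvertHalfToFloatBuffer(in, in_f.get(), total);    // one up-conversion for the whole tensor
  for (int r = 0; r < rows; ++r) {
    const size_t offset = static_cast<size_t>(r) * hidden_size;
    NormalizeRowFp32(in_f.get() + offset, out_f.get() + offset, hidden_size);
  }
  MlasConvertFloatToHalfBuffer(out_f.get(), out, total);  // one down-conversion for the whole tensor
}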
Expand All @@ -290,16 +254,22 @@ template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool& is_packed, PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);

is_packed = false;
if (input_idx == 1) { // skip
prepacked_skip_fp32_size_ = tensor.Shape().Size();
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed);
} else if (input_idx == 2) { // gamma
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
} else if (input_idx == 3) { // beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
} else if (input_idx == 3) {
if constexpr (simplified) {
// bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
} else {
// beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
}
} else if (input_idx == 4) { // bias
    ORT_ENFORCE(!simplified, "SkipSimplifiedLayerNormalization should only have 4 inputs (input, skip, gamma, and bias). Got 5.");
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}
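Reader's note: the prepacked_*_fp32_data_ members consumed in Compute are populated here by ConvertMLFloat16ToFloatIfNeeded, whose fragment appears earlier in the diff. A sketch of its likely behavior, assuming it matches that fragment: only MLFloat16 initializers are converted and marked packed, while float initializers are left for Compute to read directly.

// Sketch based on the fragment shown above; not verbatim from the file.
void ConvertMLFloat16ToFloatIfNeededSketch(const Tensor& tensor, AllocatorPtr alloc,
                                           IAllocatorUniquePtr<float>& dest, bool& is_packed) {
  if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
    const MLFloat16* src = tensor.Data<MLFloat16>();
    const size_t count = static_cast<size_t>(tensor.Shape().Size());
    dest = IAllocator::MakeUniquePtr<float>(alloc, count);
    MlasConvertHalfToFloatBuffer(src, dest.get(), count);
    is_packed = true;  // the fp16 initializer is now held as an fp32 copy
  }
}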
