Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SkipLayerNorm and handle beta properly #22862

Merged
merged 8 commits into from
Nov 17, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 84 additions & 107 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,79 +96,6 @@ void ComputeJob(
}
}

void ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const float* prepacked_skip_fp32_data,
const float* gamma_float_ptr,
const float* beta_float_ptr,
const float* bias_float_ptr,
float* output_float_ptr,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
MLFloat16* output_data,
MLFloat16* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
auto offset = task_idx * hidden_size;
const MLFloat16* p_input = input_data + offset;
MLFloat16* p_output = output_data + offset;
MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;

float mean(0.0f);
float mean_square(0.0f);
const size_t num_elems = static_cast<size_t>(hidden_size);

IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);

IAllocatorUniquePtr<float> skip_float_uptr = nullptr;
if (prepacked_skip_fp32_data == nullptr && skip_data) {
const MLFloat16* p_skip = skip_data + (offset % skip_size);
skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
}

const float* input_float_ptr = input_float_uptr.get();
const float* skip_float_ptr = prepacked_skip_fp32_data ? prepacked_skip_fp32_data : skip_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
float val = input_float_ptr[h] + skip_float_ptr[h];

if (bias_float_ptr) {
val += bias_float_ptr[h];
}

output_float_ptr[h] = val;
mean += val;
mean_square += val * val;
}

if (nullptr != p_skip_input_bias_add_output) {
MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems);
}

mean = mean / hidden_size;
if (simplified) {
mean_square = sqrt(mean_square / hidden_size + epsilon);
} else {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
} else if (nullptr == beta_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
}
}

MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems);
}

void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
auto tensor_data_ptr = tensor.Data<MLFloat16>();
Expand Down Expand Up @@ -200,8 +127,8 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const Tensor* input = p_ctx->Input<Tensor>(0);
const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
const Tensor* beta = simplified ? nullptr : (prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3));
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(simplified ? 3 : 4);
Tensor* output = p_ctx->Output(0, input->Shape());
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
Expand Down Expand Up @@ -232,56 +159,96 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {

// For inferencing, we support one more optional output which is the sum of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_;

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

IAllocatorUniquePtr<float> output_fp32;
IAllocatorUniquePtr<float> gamma_fp32;
IAllocatorUniquePtr<float> beta_fp32;
IAllocatorUniquePtr<float> bias_fp32;

if constexpr (std::is_same_v<T, MLFloat16>) {
if (skip == nullptr) {
std::cout << "missing skip";
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
}
const int64_t total_data_size = input->Shape().Size();

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

IAllocatorUniquePtr<float> input_fp32;
IAllocatorUniquePtr<float> output_fp32;
IAllocatorUniquePtr<float> skip_input_bias_add_output_fp32;
IAllocatorUniquePtr<float> skip_fp32;
IAllocatorUniquePtr<float> gamma_fp32;
IAllocatorUniquePtr<float> beta_fp32;
IAllocatorUniquePtr<float> bias_fp32;

const float* input_data_f = nullptr;
const float* skip_data_f = nullptr;
const float* gamma_data_f = nullptr;
const float* beta_data_f = nullptr;
const float* bias_data_f = nullptr;
float* output_data_f = nullptr;
float* skip_input_bias_add_output_data_f = nullptr;

const size_t num_elems = static_cast<size_t>(hidden_size);

output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size);
input_data_f = input_fp32.get();

output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
output_data_f = output_fp32.get();

skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, total_data_size);
skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get();

if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) {
if (skip_data) {
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, skip_size);
MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), skip_size);
skip_data_f = skip_fp32.get();
} else if (prepacked_skip_fp32_data_) {
skip_data_f = prepacked_skip_fp32_data_.get();
}

if (gamma_data) {
gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
gamma_data_f = gamma_fp32.get();
} else if (prepacked_gamma_fp32_data_) {
gamma_data_f = prepacked_gamma_fp32_data_.get();
}

if (prepacked_beta_fp32_data_ == nullptr && beta_data) {
if (beta_data) {
beta_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems);
beta_data_f = beta_fp32.get();
} else if (prepacked_beta_fp32_data_) {
beta_data_f = prepacked_beta_fp32_data_.get();
}

if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
if (bias_data) {
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
bias_data_f = bias_fp32.get();
} else if (prepacked_bias_fp32_data_) {
bias_data_f = prepacked_bias_fp32_data_.get();
}
}

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
if constexpr (std::is_same_v<T, MLFloat16>) {
ComputeJob(input_data, skip_data,
prepacked_skip_fp32_data_.get(),
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
output_fp32.get(),
task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data, alloc);
} else {
concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
ComputeJob(input_data_f, skip_data_f, gamma_data_f, beta_data_f, bias_data_f, task_idx, hidden_size, skip_size,
epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f);
},
0);
MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size);
if (skip_input_bias_add_output_data != nullptr)
MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size);
} else {
concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size,
epsilon_, simplified, output_data, skip_input_bias_add_output_data);
}
},
0);
},
0);
}

return Status::OK();
}
Expand All @@ -290,16 +257,26 @@ template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool& is_packed, PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);
ORT_UNUSED_PARAMETER(tensor);
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
ORT_UNUSED_PARAMETER(input_idx);
ORT_UNUSED_PARAMETER(alloc);

is_packed = false;
if (input_idx == 1) { // skip
prepacked_skip_fp32_size_ = tensor.Shape().Size();
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed);
} else if (input_idx == 2) { // gamma
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
} else if (input_idx == 3) { // beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
} else if (input_idx == 3) {
if (simplified) {
// bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
} else {
// beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
}
} else if (input_idx == 4) { // bias
ORT_ENFORCE(!simplified, "SkipSimplifiedLayerNormalization should only has 4 inputs (input, skip, gamma, and beta). Got 5.");
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}

Expand Down
Loading