Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisper: Fix decoder inputs for static pipeline #1469

Merged
merged 5 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 40 additions & 66 deletions src/cpp/src/whisper_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,28 +121,15 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
}
}

void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids,
const int64_t pad_token) {
void set_decoder_input_ids(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids) {
auto input_ids_tensor = decoder.get_tensor("input_ids");
auto attention_mask_tensor = decoder.get_tensor("attention_mask");

const size_t seq_length = input_ids_tensor.get_shape()[1];

OPENVINO_ASSERT(seq_length >= init_ids.size());

// pad right
// input_ids [token, token, token, pad_token]
// attention_mask [1, 1, 1, 0]
auto input_ids_data = input_ids_tensor.data<int32_t>();
std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
std::fill(input_ids_data + init_ids.size(),
input_ids_data + input_ids_tensor.get_size(),
static_cast<int32_t>(pad_token));

auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
std::fill_n(attention_mask_data, init_ids.size(), 1u);
std::fill(attention_mask_data + init_ids.size(), attention_mask_data + attention_mask_tensor.get_size(), 0u);
}

int64_t decode(ov::Tensor& encoder_hidden_state,
Expand All @@ -154,7 +141,7 @@ int64_t decode(ov::Tensor& encoder_hidden_state,
const bool return_timestamps = false) {
// NB: Fill decoder inputs
encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
set_decoder_input_ids(decoder, init_ids);

ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);

Expand Down Expand Up @@ -210,13 +197,13 @@ void zero_past_key_values(ov::InferRequest& request) {
}
}

void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder) {
void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder, const size_t init_ids_size) {
// NB: Prepare attetion mask to be in a format [0, 0, 0, 1, 1, 1, 1, ..., 0, 1]
// Mask should be inverted for decoder_with_past
auto attention_mask = decoder_with_past.get_tensor("attention_mask");
auto* attention_mask_ptr = attention_mask.data<ov::float16>();
std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
TolyaTalamanov marked this conversation as resolved.
Show resolved Hide resolved
attention_mask_ptr[attention_mask.get_size() - 2] = 0;
attention_mask_ptr[attention_mask.get_size() - 1] = 1;
// NB: Zero past_key_values.*.decoder.value tensors
Expand All @@ -227,13 +214,15 @@ void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferReq
};

int64_t detect_language(ov::Tensor& encoder_hidden_state,
ov::InferRequest decoder,
ov::genai::DecoderCache& decoder_cache,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::RawPerfMetrics& raw_metrics) {
auto decoder = decoder_cache.get_model(1);

decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
set_decoder_input_ids(decoder, init_ids);

const auto infer_start = std::chrono::steady_clock::now();
decoder.infer();
Expand All @@ -259,7 +248,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
}

std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
ov::InferRequest& decoder,
ov::genai::DecoderCache& decoder_cache,
const ov::genai::WhisperGenerationConfig& config,
const bool return_timestamps,
ov::genai::RawPerfMetrics& raw_metrics) {
Expand All @@ -279,7 +268,7 @@ std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
language_token_id = static_cast<int32_t>(config.lang_to_id.at(language));
}
} else {
language_token_id = detect_language(encoder_hidden_state, decoder, config, raw_metrics);
language_token_id = detect_language(encoder_hidden_state, decoder_cache, config, raw_metrics);
}

int32_t task_token_id = static_cast<int32_t>(config.transcribe_token_id);
Expand Down Expand Up @@ -318,7 +307,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
return {false, output_tokens};
}

prepare_decoder_with_past(models.decoder_with_past, models.decoder);
prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());

for (size_t i = 0; i < max_new_tokens - 1; i++) {
auto output_token = decode_with_past(models.decoder_with_past,
Expand Down Expand Up @@ -353,36 +342,6 @@ bool check_decoder_model_compatibility(const std::shared_ptr<ov::Model>& decoder
return false;
}

void add_attention_mask_input_for_decoder(std::shared_ptr<ov::Model> model) {
using namespace ov::pass::pattern;
using namespace ov::op;
class AttentionMaskInput : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("AttentionMaskInput");

AttentionMaskInput(std::shared_ptr<ov::Model> model) {
auto range = wrap_type<v4::Range>();
auto convert = wrap_type<v0::Convert>({range});
auto convert1 = wrap_type<v0::Convert>({convert});
auto greater = wrap_type<v1::Greater>({convert1, any_input()});
auto convert2 = wrap_type<v0::Convert>({greater});

register_matcher(std::make_shared<Matcher>(convert2, this->get_type_info().name), [model](Matcher& m) {
auto node = m.get_match_root();
auto attention_mask = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1});
attention_mask->get_output_tensor(0).set_names({"attention_mask"});
model->add_parameters({attention_mask});
ov::replace_node(node, attention_mask);
return false;
});
}
};

ov::pass::Manager pm;
pm.register_pass<AttentionMaskInput>(model);
pm.run_passes(model);
}

void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
using namespace ov::pass::pattern;
using namespace ov::op;
Expand Down Expand Up @@ -467,6 +426,10 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
model->reshape(new_shapes);
}

void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
model->reshape({{"input_ids", ov::PartialShape({1, input_size})}});
}

void preprocess_encoder(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor preprocessor(model);

Expand All @@ -489,7 +452,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
preprocessor.input("attention_mask").preprocess().convert_element_type();
} else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32); // ()
TolyaTalamanov marked this conversation as resolved.
Show resolved Hide resolved
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);
} else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
Expand Down Expand Up @@ -541,6 +504,19 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
namespace ov {
namespace genai {

ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {
if (m_cache.find(input_ids_size) == m_cache.cend()) {
reshape_input_ids(m_decoder_model, input_ids_size);

ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_cache.emplace(input_ids_size, compiled_model.create_infer_request());
}

return m_cache.at(input_ids_size);
}

WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesystem::path& models_path,
const ov::AnyMap& properties)
: WhisperPipelineImplBase{models_path} {
Expand All @@ -550,20 +526,14 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
auto decoder_model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
auto decoder_with_past_model = core.read_model(models_path / "openvino_decoder_with_past_model.xml", {}, properties);

add_attention_mask_input_for_decoder(decoder_model);
add_attention_mask_input(decoder_with_past_model);

// TODO: Support models produced by optimum-cli
if (!check_decoder_model_compatibility(decoder_model)) {
OPENVINO_THROW("StaticWhisperPipeline expects decoder model has \"attention_mask\" input!");
}

size_t max_sequence_length = 448;

reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);

auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape);
reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);

// Replace KV-tensors for the entire cache to tensors only for new token
Expand All @@ -577,9 +547,10 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
compiled_model = core.compile_model(encoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();
compiled_model = core.compile_model(decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_models.decoder = compiled_model.create_infer_request();

// Will compile decoder model when it's needed
m_decoder_cache = DecoderCache(decoder_model);

compiled_model = core.compile_model(decoder_with_past_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
m_models.decoder_with_past = compiled_model.create_infer_request();
Expand Down Expand Up @@ -654,7 +625,10 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

// prepare init_ids just once for whole input
if (init_ids.empty()) {
init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
init_ids = prepare_init_ids(hidden_state_tensor, m_decoder_cache, config, return_timestamps, raw_metrics);

// Get decoder with size of input_ids
m_models.decoder = m_decoder_cache.get_model(init_ids.size());
}

auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
Expand Down
12 changes: 12 additions & 0 deletions src/cpp/src/whisper_pipeline_static.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
namespace ov {
namespace genai {

class DecoderCache {
public:
DecoderCache() = default;
DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}

ov::InferRequest get_model(uint8_t input_ids_size);
private:
std::unordered_map<uint8_t, ov::InferRequest> m_cache;
std::shared_ptr<ov::Model> m_decoder_model;
};

class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
public:
StaticWhisperPipeline(const std::filesystem::path& model_path, const ov::AnyMap& properties);
Expand All @@ -25,6 +36,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi

private:
WhisperInitializedModels m_models;
DecoderCache m_decoder_cache;
};

} // namespace genai
Expand Down
Loading