Whisper: Fix decoder inputs for static pipeline (#1469)
eshiryae authored Jan 10, 2025
1 parent 8e56a36 commit 77611da
Showing 2 changed files with 52 additions and 66 deletions.
106 changes: 40 additions & 66 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -121,28 +121,15 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
}
}

void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids,
const int64_t pad_token) {
void set_decoder_input_ids(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids) {
auto input_ids_tensor = decoder.get_tensor("input_ids");
auto attention_mask_tensor = decoder.get_tensor("attention_mask");

const size_t seq_length = input_ids_tensor.get_shape()[1];

OPENVINO_ASSERT(seq_length >= init_ids.size());

// pad right
// input_ids [token, token, token, pad_token]
// attention_mask [1, 1, 1, 0]
auto input_ids_data = input_ids_tensor.data<int32_t>();
std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
std::fill(input_ids_data + init_ids.size(),
input_ids_data + input_ids_tensor.get_size(),
static_cast<int32_t>(pad_token));

auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
std::fill_n(attention_mask_data, init_ids.size(), 1u);
std::fill(attention_mask_data + init_ids.size(), attention_mask_data + attention_mask_tensor.get_size(), 0u);
}
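With the decoder now reshaped so that input_ids holds exactly init_ids.size() tokens, the helper reduces to a plain copy: the right-padding with pad_token and the matching attention-mask fill are no longer needed. A minimal sketch of the new contract, assuming the static shape [1, init_ids.size()]:

#include <openvino/runtime/tensor.hpp>

#include <algorithm>
#include <cstdint>
#include <vector>

// Build an input_ids tensor whose static shape matches the prompt exactly,
// so no pad tokens and no attention mask are required.
ov::Tensor make_input_ids(const std::vector<int32_t>& init_ids) {
    ov::Tensor ids{ov::element::i32, ov::Shape{1, init_ids.size()}};
    std::copy(init_ids.begin(), init_ids.end(), ids.data<int32_t>());
    return ids;
}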

int64_t decode(ov::Tensor& encoder_hidden_state,
@@ -154,7 +141,7 @@ int64_t decode(ov::Tensor& encoder_hidden_state,
const bool return_timestamps = false) {
// NB: Fill decoder inputs
encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
set_decoder_input_ids(decoder, init_ids);

ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);

@@ -210,13 +197,13 @@ void zero_past_key_values(ov::InferRequest& request) {
}
}

void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder) {
void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder, const size_t init_ids_size) {
// NB: Prepare attention mask to be in a format [0, 0, 0, 1, 1, 1, 1, ..., 0, 1]
// Mask should be inverted for decoder_with_past
auto attention_mask = decoder_with_past.get_tensor("attention_mask");
auto* attention_mask_ptr = attention_mask.data<ov::float16>();
std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
attention_mask_ptr[attention_mask.get_size() - 2] = 0;
attention_mask_ptr[attention_mask.get_size() - 1] = 1;
// NB: Zero past_key_values.*.decoder.value tensors
@@ -227,13 +214,15 @@ void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferReq
};
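prepare_decoder_with_past previously zeroed a hard-coded prefix of three positions; it now takes init_ids_size so the zeroed prefix matches the actual prompt length. A standalone sketch of the resulting layout (plain float here; the real tensor is ov::float16), assuming total length 8 and init_ids_size 3, which yields 0 0 0 1 1 1 0 1:

#include <algorithm>
#include <cstddef>
#include <vector>

// Inverted mask: 0 marks positions the model may attend to, 1 marks
// masked ones (hence "inverted" relative to the usual convention).
std::vector<float> inverted_mask(size_t total, size_t init_ids_size) {
    std::vector<float> mask(total, 1.0f);
    std::fill_n(mask.begin(), init_ids_size, 0.0f); // init prompt positions
    mask[total - 2] = 0.0f;                         // slot for the newly generated token
    mask[total - 1] = 1.0f;                         // stays masked
    return mask;
}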

int64_t detect_language(ov::Tensor& encoder_hidden_state,
ov::InferRequest decoder,
ov::genai::DecoderCache& decoder_cache,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::RawPerfMetrics& raw_metrics) {
auto decoder = decoder_cache.get_model(1);

decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
set_decoder_input_ids(decoder, init_ids);

const auto infer_start = std::chrono::steady_clock::now();
decoder.infer();
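Because each NPU decoder is compiled for one static input_ids length, every distinct prompt length needs its own compiled variant; language detection always feeds the single start token, hence get_model(1). A simplified restatement of the flow above, using only symbols from this file:

// Not new API -- just the detect_language sequence condensed.
ov::InferRequest lang_decoder = decoder_cache.get_model(1); // input_ids shape [1, 1]
lang_decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
set_decoder_input_ids(lang_decoder, {static_cast<int32_t>(config.decoder_start_token_id)});
lang_decoder.infer();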
@@ -259,7 +248,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
}

std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
ov::InferRequest& decoder,
ov::genai::DecoderCache& decoder_cache,
const ov::genai::WhisperGenerationConfig& config,
const bool return_timestamps,
ov::genai::RawPerfMetrics& raw_metrics) {
@@ -279,7 +268,7 @@ std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
language_token_id = static_cast<int32_t>(config.lang_to_id.at(language));
}
} else {
language_token_id = detect_language(encoder_hidden_state, decoder, config, raw_metrics);
language_token_id = detect_language(encoder_hidden_state, decoder_cache, config, raw_metrics);
}

int32_t task_token_id = static_cast<int32_t>(config.transcribe_token_id);
@@ -318,7 +307,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
return {false, output_tokens};
}

prepare_decoder_with_past(models.decoder_with_past, models.decoder);
prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());

for (size_t i = 0; i < max_new_tokens - 1; i++) {
auto output_token = decode_with_past(models.decoder_with_past,
@@ -353,36 +342,6 @@ bool check_decoder_model_compatibility(const std::shared_ptr<ov::Model>& decoder
return false;
}

void add_attention_mask_input_for_decoder(std::shared_ptr<ov::Model> model) {
using namespace ov::pass::pattern;
using namespace ov::op;
class AttentionMaskInput : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("AttentionMaskInput");

AttentionMaskInput(std::shared_ptr<ov::Model> model) {
auto range = wrap_type<v4::Range>();
auto convert = wrap_type<v0::Convert>({range});
auto convert1 = wrap_type<v0::Convert>({convert});
auto greater = wrap_type<v1::Greater>({convert1, any_input()});
auto convert2 = wrap_type<v0::Convert>({greater});

register_matcher(std::make_shared<Matcher>(convert2, this->get_type_info().name), [model](Matcher& m) {
auto node = m.get_match_root();
auto attention_mask = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1});
attention_mask->get_output_tensor(0).set_names({"attention_mask"});
model->add_parameters({attention_mask});
ov::replace_node(node, attention_mask);
return false;
});
}
};

ov::pass::Manager pm;
pm.register_pass<AttentionMaskInput>(model);
pm.run_passes(model);
}

void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
using namespace ov::pass::pattern;
using namespace ov::op;
@@ -467,6 +426,10 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
model->reshape(new_shapes);
}

void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
model->reshape({{"input_ids", ov::PartialShape({1, input_size})}});
}
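reshape_input_ids pins the decoder's input_ids to a fixed token count before compilation, which is what lets the pipeline drop padding entirely. A self-contained usage sketch, assuming a hypothetical model path and the NPU device used elsewhere in this file:

#include <openvino/openvino.hpp>

#include <string>

// Reshape the decoder to a fixed 4-token prompt, then compile for NPU.
// The path and prompt length are illustrative, not taken from this commit.
ov::InferRequest make_static_decoder(ov::Core& core, const std::string& xml_path) {
    auto model = core.read_model(xml_path);
    model->reshape({{"input_ids", ov::PartialShape({1, 4})}});
    return core.compile_model(model, "NPU").create_infer_request();
}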

void preprocess_encoder(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor preprocessor(model);

@@ -489,7 +452,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
preprocessor.input("attention_mask").preprocess().convert_element_type();
} else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32); // ()
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);
} else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
@@ -541,6 +504,19 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
namespace ov {
namespace genai {

ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {
if (m_cache.find(input_ids_size) == m_cache.cend()) {
reshape_input_ids(m_decoder_model, input_ids_size);

ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_cache.emplace(input_ids_size, compiled_model.create_infer_request());
}

return m_cache.at(input_ids_size);
}
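get_model compiles a decoder on first use for a given prompt length and reuses it afterwards, so the pipeline only pays compile time for the lengths it actually encounters (typically 1 for language detection plus one full init-prompt length). A usage sketch with illustrative lengths:

// decoder_model is the ov::Model read earlier in the pipeline constructor.
DecoderCache cache(decoder_model);
ov::InferRequest lang = cache.get_model(1);  // first call: reshape + compile
ov::InferRequest init = cache.get_model(4);  // different length: compile again
ov::InferRequest hit  = cache.get_model(4);  // same length: returned from cache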

WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesystem::path& models_path,
const ov::AnyMap& properties)
: WhisperPipelineImplBase{models_path} {
@@ -550,20 +526,14 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
auto decoder_model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
auto decoder_with_past_model = core.read_model(models_path / "openvino_decoder_with_past_model.xml", {}, properties);

add_attention_mask_input_for_decoder(decoder_model);
add_attention_mask_input(decoder_with_past_model);

// TODO: Support models produced by optimum-cli
if (!check_decoder_model_compatibility(decoder_model)) {
OPENVINO_THROW("StaticWhisperPipeline expects decoder model has \"attention_mask\" input!");
}

size_t max_sequence_length = 448;

reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);

auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape);
reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);

// Replace KV-tensors for the entire cache to tensors only for new token
@@ -577,9 +547,10 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
compiled_model = core.compile_model(encoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();
compiled_model = core.compile_model(decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_models.decoder = compiled_model.create_infer_request();

// Will compile decoder model when it's needed
m_decoder_cache = DecoderCache(decoder_model);

compiled_model = core.compile_model(decoder_with_past_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
m_models.decoder_with_past = compiled_model.create_infer_request();
Expand Down Expand Up @@ -654,7 +625,10 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

// prepare init_ids just once for whole input
if (init_ids.empty()) {
init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
init_ids = prepare_init_ids(hidden_state_tensor, m_decoder_cache, config, return_timestamps, raw_metrics);

// Get decoder with size of input_ids
m_models.decoder = m_decoder_cache.get_model(init_ids.size());
}
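The lengths passed to get_model here are small and bounded: a Whisper init prompt is the start token plus optional language, task, and no-timestamps tokens. An illustrative example with the usual multilingual tokenizer ids (shown for orientation only, not taken from this commit):

#include <cstdint>
#include <vector>

int main() {
    const int32_t sot = 50258;        // <|startoftranscript|>
    const int32_t lang_en = 50259;    // <|en|>
    const int32_t transcribe = 50359; // <|transcribe|>
    const int32_t no_ts = 50363;      // <|notimestamps|>
    std::vector<int32_t> detect_ids{sot};                            // get_model(1)
    std::vector<int32_t> init_ids{sot, lang_en, transcribe, no_ts};  // get_model(4)
    return 0;
}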

auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
12 changes: 12 additions & 0 deletions src/cpp/src/whisper_pipeline_static.hpp
@@ -15,6 +15,17 @@
namespace ov {
namespace genai {

class DecoderCache {
public:
DecoderCache() = default;
DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}

ov::InferRequest get_model(uint8_t input_ids_size);
private:
std::unordered_map<uint8_t, ov::InferRequest> m_cache;
std::shared_ptr<ov::Model> m_decoder_model;
};

class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
public:
StaticWhisperPipeline(const std::filesystem::path& model_path, const ov::AnyMap& properties);
@@ -25,6 +36,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi

private:
WhisperInitializedModels m_models;
DecoderCache m_decoder_cache;
};

} // namespace genai