Skip to content

Commit

Permalink
Added decoder cache, removed decoder's attn mask
Browse files Browse the repository at this point in the history
  • Loading branch information
eshiryae committed Jan 9, 2025
1 parent 7e8bbfe commit 3c901d7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 58 deletions.
95 changes: 38 additions & 57 deletions src/cpp/src/whisper_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,15 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
}
}

void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids,
const int64_t pad_token) {
void set_decoder_input_ids(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids) {
auto input_ids_tensor = decoder.get_tensor("input_ids");
auto attention_mask_tensor = decoder.get_tensor("attention_mask");

const size_t seq_length = input_ids_tensor.get_shape()[1];

OPENVINO_ASSERT(seq_length >= init_ids.size());

// pad right
// input_ids [token, token, token, pad_token]
// attention_mask [1, 1, 1, 0]
auto input_ids_data = input_ids_tensor.data<int32_t>();
std::copy(init_ids.begin(), init_ids.end(), input_ids_data);

auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
std::fill_n(attention_mask_data, init_ids.size(), 1u);
std::fill(attention_mask_data + init_ids.size(), attention_mask_data + attention_mask_tensor.get_size(), 0u);
}

int64_t decode(ov::Tensor& encoder_hidden_state,
Expand All @@ -151,7 +141,7 @@ int64_t decode(ov::Tensor& encoder_hidden_state,
const bool return_timestamps = false) {
// NB: Fill decoder inputs
encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
set_decoder_input_ids(decoder, init_ids);

ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);

Expand Down Expand Up @@ -230,7 +220,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
set_decoder_input_ids(decoder, init_ids);

const auto infer_start = std::chrono::steady_clock::now();
decoder.infer();
Expand Down Expand Up @@ -350,36 +340,6 @@ bool check_decoder_model_compatibility(const std::shared_ptr<ov::Model>& decoder
return false;
}

void add_attention_mask_input_for_decoder(std::shared_ptr<ov::Model> model) {
using namespace ov::pass::pattern;
using namespace ov::op;
class AttentionMaskInput : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("AttentionMaskInput");

AttentionMaskInput(std::shared_ptr<ov::Model> model) {
auto range = wrap_type<v4::Range>();
auto convert = wrap_type<v0::Convert>({range});
auto convert1 = wrap_type<v0::Convert>({convert});
auto greater = wrap_type<v1::Greater>({convert1, any_input()});
auto convert2 = wrap_type<v0::Convert>({greater});

register_matcher(std::make_shared<Matcher>(convert2, this->get_type_info().name), [model](Matcher& m) {
auto node = m.get_match_root();
auto attention_mask = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1});
attention_mask->get_output_tensor(0).set_names({"attention_mask"});
model->add_parameters({attention_mask});
ov::replace_node(node, attention_mask);
return false;
});
}
};

ov::pass::Manager pm;
pm.register_pass<AttentionMaskInput>(model);
pm.run_passes(model);
}

void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
using namespace ov::pass::pattern;
using namespace ov::op;
Expand Down Expand Up @@ -464,6 +424,24 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
model->reshape(new_shapes);
}

void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
std::map<std::string, ov::PartialShape> new_shapes;
for (auto input : model->inputs()) {
const auto& input_name = input.get_any_name();
ov::PartialShape new_shape;
if (input_name.find("input_ids") != std::string::npos) {
new_shape = ov::PartialShape({1, input_size});
} else if (input_name.find("attention_mask") != std::string::npos) {
new_shape = ov::PartialShape({1, input_size + 1});
} else {
new_shape = input.get_partial_shape();
}
new_shapes.emplace(input_name, new_shape);
}

model->reshape(new_shapes);
}

void preprocess_encoder(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor preprocessor(model);

Expand Down Expand Up @@ -538,6 +516,17 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
namespace ov {
namespace genai {

ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
if (m_cache.find(input_ids_size) == m_cache.cend()) {
reshape_input_ids(m_decoder_model, input_ids_size);

ov::Core core = utils::singleton_core();
m_cache.insert({input_ids_size, core.compile_model(m_decoder_model, "NPU")});
}

return m_cache.at(input_ids_size);
}

WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesystem::path& models_path,
const ov::AnyMap& properties)
: WhisperPipelineImplBase{models_path} {
Expand All @@ -547,14 +536,8 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
auto decoder_model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
auto decoder_with_past_model = core.read_model(models_path / "openvino_decoder_with_past_model.xml", {}, properties);

add_attention_mask_input_for_decoder(decoder_model);
add_attention_mask_input(decoder_with_past_model);

// TODO: Support models produced by optimum-cli
if (!check_decoder_model_compatibility(decoder_model)) {
OPENVINO_THROW("StaticWhisperPipeline expects decoder model has \"attention_mask\" input!");
}

size_t max_sequence_length = 448;

reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
Expand All @@ -575,8 +558,8 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();

m_decoder_model = decoder_model; // for reshape in generate() when we get number of input tokens
compiled_model = core.compile_model(decoder_model, "NPU");
m_decoder_cache = DecoderCache(decoder_model);
compiled_model = m_decoder_cache.get_model(1);
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_models.decoder = compiled_model.create_infer_request();

Expand Down Expand Up @@ -654,13 +637,11 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

// prepare init_ids just once for whole input
if (init_ids.empty()) {
OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_shape().back() == 1);
OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_size() == 1);
init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);

// Reshape decoder model for the number of input tokens
ov::Core core = utils::singleton_core();
reshape_to_static(m_decoder_model, init_ids.size(), init_ids.size(), hidden_state_tensor.get_shape());
m_models.decoder = core.compile_model(m_decoder_model, "NPU").create_infer_request();
// Get decoder with size of input_ids
m_models.decoder = m_decoder_cache.get_model(init_ids.size()).create_infer_request();
}

auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
Expand Down
13 changes: 12 additions & 1 deletion src/cpp/src/whisper_pipeline_static.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
namespace ov {
namespace genai {

class DecoderCache {
public:
DecoderCache() {}
DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}

ov::CompiledModel get_model(uint8_t input_ids_size);
private:
std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
std::shared_ptr<ov::Model> m_decoder_model;
};

class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
public:
StaticWhisperPipeline(const std::filesystem::path& model_path, const ov::AnyMap& properties);
Expand All @@ -25,7 +36,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi

private:
WhisperInitializedModels m_models;
std::shared_ptr<ov::Model> m_decoder_model;
DecoderCache m_decoder_cache;
};

} // namespace genai
Expand Down

0 comments on commit 3c901d7

Please sign in to comment.