Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eshiryae committed Jan 9, 2025
1 parent ac16558 commit 614e6d9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 23 deletions.
33 changes: 12 additions & 21 deletions src/cpp/src/whisper_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,7 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
}

// Reshapes the model input named "input_ids" to the static shape {1, input_size}.
//
// Only "input_ids" is handed to ov::Model::reshape; the shapes of the other
// inputs are not restated. The previous body additionally rebuilt a complete
// name -> shape map (copying every other input's current partial shape) and
// reshaped with it right before this call, which was redundant work.
//
// NOTE(review): the removed loop matched any input whose name *contains*
// "input_ids"; this call requires a tensor named exactly "input_ids".
// Confirm that no decoder input relies on the substring match.
//
// @param model       Model to reshape in place.
// @param input_size  New length of the second dimension of "input_ids".
void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
    model->reshape({{"input_ids", ov::PartialShape({1, input_size})}});
}

void preprocess_encoder(std::shared_ptr<ov::Model> model) {
Expand Down Expand Up @@ -516,10 +504,16 @@ namespace genai {

// Returns the decoder compiled for "NPU" with an input_ids length of
// `input_ids_size`, compiling it on the first request for that size and
// serving it from m_cache afterwards.
//
// The shared m_decoder_model is reshaped in place before compilation:
//   * if it is still dynamic, reshape_to_static() is applied, passing
//     input_ids_size twice together with m_lhs_shape;
//   * otherwise only "input_ids" is resized via reshape_input_ids().
//
// Each cache miss reshapes once and compiles once. The earlier body reshaped
// input_ids unconditionally before the dynamic check and compiled the model
// twice per miss (the second result was dropped because the key had already
// been inserted).
ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
    auto cached = m_cache.find(input_ids_size);
    if (cached == m_cache.end()) {
        if (m_decoder_model->is_dynamic()) {  // model is dynamic, reshaping it to static
            reshape_to_static(m_decoder_model, input_ids_size, input_ids_size, m_lhs_shape);
        } else {
            reshape_input_ids(m_decoder_model, input_ids_size);
        }

        ov::Core core = utils::singleton_core();
        ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
        ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
        // Keep the iterator returned by emplace so the result needs no second lookup.
        cached = m_cache.emplace(input_ids_size, compiled_model).first;
    }

    return cached->second;
}
reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);

auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape); // for detect_language()
reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);

// Replace KV-tensors for the entire cache to tensors only for new token
Expand All @@ -556,10 +549,8 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();

m_decoder_cache = DecoderCache(decoder_model);
compiled_model = m_decoder_cache.get_model(1);
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_models.decoder = compiled_model.create_infer_request();
// Will compile decoder model when it's needed
m_decoder_cache = DecoderCache(decoder_model, last_hidden_state_shape);

compiled_model = core.compile_model(decoder_with_past_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
Expand Down Expand Up @@ -635,7 +626,7 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

// prepare init_ids just once for whole input
if (init_ids.empty()) {
OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_size() == 1);
m_models.decoder = m_decoder_cache.get_model(1).create_infer_request(); // for detect_language()
init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);

// Get decoder with size of input_ids
Expand Down
7 changes: 5 additions & 2 deletions src/cpp/src/whisper_pipeline_static.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ namespace genai {

// Cache of compiled decoder models keyed by input_ids length.
//
// Holds the decoder ov::Model and hands out ov::CompiledModel instances
// through get_model().
class DecoderCache {
public:
    // Empty cache with no model attached. Declared once; a second
    // `DecoderCache() {}` definition would be a redeclaration error.
    DecoderCache() = default;

    // Stores only the decoder model and leaves m_lhs_shape default-constructed.
    // NOTE(review): get_model() forwards m_lhs_shape to reshape_to_static() when
    // the model is dynamic, so prefer the two-argument constructor for dynamic
    // models; confirm whether this overload is still needed.
    DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}

    // Stores the decoder model together with the shape later passed to
    // reshape_to_static() (presumably the encoder's last hidden state shape,
    // verify against the caller).
    DecoderCache(std::shared_ptr<ov::Model> model, ov::PartialShape shape)
        : m_decoder_model(model)
        , m_lhs_shape(shape) {}

    // Returns the compiled decoder for the given input_ids length.
    ov::CompiledModel get_model(uint8_t input_ids_size);
private:
    // Compiled models by input_ids length.
    std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
    // Decoder model that get_model() reshapes and compiles.
    std::shared_ptr<ov::Model> m_decoder_model;
    // Shape forwarded to reshape_to_static() for a dynamic decoder model.
    ov::PartialShape m_lhs_shape;
};

class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
Expand Down

0 comments on commit 614e6d9

Please sign in to comment.