
Whisper: Fix decoder inputs for static pipeline #1469

Merged
merged 5 commits into from Jan 10, 2025
Changes from 4 commits
100 changes: 38 additions & 62 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -121,28 +121,15 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
}
}

void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids,
const int64_t pad_token) {
void set_decoder_input_ids(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids) {
auto input_ids_tensor = decoder.get_tensor("input_ids");
auto attention_mask_tensor = decoder.get_tensor("attention_mask");

const size_t seq_length = input_ids_tensor.get_shape()[1];

OPENVINO_ASSERT(seq_length >= init_ids.size());

// pad right
// input_ids [token, token, token, pad_token]
// attention_mask [1, 1, 1, 0]
auto input_ids_data = input_ids_tensor.data<int32_t>();
std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
std::fill(input_ids_data + init_ids.size(),
input_ids_data + input_ids_tensor.get_size(),
static_cast<int32_t>(pad_token));

auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
std::fill_n(attention_mask_data, init_ids.size(), 1u);
std::fill(attention_mask_data + init_ids.size(), attention_mask_data + attention_mask_tensor.get_size(), 0u);
}
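For reference, this is roughly what the simplified helper amounts to after the change, reconstructed from the retained lines above; a sketch rather than the exact merged code:

void set_decoder_input_ids(ov::InferRequest& decoder,
                           const std::vector<int32_t>& init_ids) {
    auto input_ids_tensor = decoder.get_tensor("input_ids");
    const size_t seq_length = input_ids_tensor.get_shape()[1];

    // The decoder is now reshaped so that seq_length matches init_ids.size(),
    // so no right-padding and no attention_mask handling are needed here.
    OPENVINO_ASSERT(seq_length >= init_ids.size());

    auto input_ids_data = input_ids_tensor.data<int32_t>();
    std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
}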

int64_t decode(ov::Tensor& encoder_hidden_state,
@@ -154,7 +141,7 @@ int64_t decode(ov::Tensor& encoder_hidden_state,
const bool return_timestamps = false) {
// NB: Fill decoder inputs
encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
set_decoder_input_ids(decoder, init_ids);

ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);

@@ -210,13 +197,13 @@ void zero_past_key_values(ov::InferRequest& request) {
}
}

void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder) {
void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder, const size_t init_ids_size) {
// NB: Prepare attention mask to be in a format [0, 0, 0, 1, 1, 1, 1, ..., 0, 1]
// Mask should be inverted for decoder_with_past
auto attention_mask = decoder_with_past.get_tensor("attention_mask");
auto* attention_mask_ptr = attention_mask.data<ov::float16>();
std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
attention_mask_ptr[attention_mask.get_size() - 2] = 0;
attention_mask_ptr[attention_mask.get_size() - 1] = 1;
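To make the layout concrete, an illustrative example (not part of the diff): with init_ids_size = 3 and attention_mask.get_size() = 8, the fills above produce

// attention_mask = [0, 0, 0, 1, 1, 1, 0, 1]
// zeros over the first init_ids_size positions, ones up to size - 2, then the fixed 0, 1 tail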
// NB: Zero past_key_values.*.decoder.value tensors
@@ -233,7 +220,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
set_decoder_input_ids(decoder, init_ids);

const auto infer_start = std::chrono::steady_clock::now();
decoder.infer();
@@ -318,7 +305,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
return {false, output_tokens};
}

prepare_decoder_with_past(models.decoder_with_past, models.decoder);
prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());

for (size_t i = 0; i < max_new_tokens - 1; i++) {
auto output_token = decode_with_past(models.decoder_with_past,
@@ -353,36 +340,6 @@ bool check_decoder_model_compatibility(const std::shared_ptr<ov::Model>& decoder
return false;
}

void add_attention_mask_input_for_decoder(std::shared_ptr<ov::Model> model) {
using namespace ov::pass::pattern;
using namespace ov::op;
class AttentionMaskInput : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("AttentionMaskInput");

AttentionMaskInput(std::shared_ptr<ov::Model> model) {
auto range = wrap_type<v4::Range>();
auto convert = wrap_type<v0::Convert>({range});
auto convert1 = wrap_type<v0::Convert>({convert});
auto greater = wrap_type<v1::Greater>({convert1, any_input()});
auto convert2 = wrap_type<v0::Convert>({greater});

register_matcher(std::make_shared<Matcher>(convert2, this->get_type_info().name), [model](Matcher& m) {
auto node = m.get_match_root();
auto attention_mask = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1});
attention_mask->get_output_tensor(0).set_names({"attention_mask"});
model->add_parameters({attention_mask});
ov::replace_node(node, attention_mask);
return false;
});
}
};

ov::pass::Manager pm;
pm.register_pass<AttentionMaskInput>(model);
pm.run_passes(model);
}

void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
using namespace ov::pass::pattern;
using namespace ov::op;
@@ -467,6 +424,10 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
model->reshape(new_shapes);
}

void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
model->reshape({{"input_ids", ov::PartialShape({1, input_size})}});
}

void preprocess_encoder(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor preprocessor(model);

@@ -489,7 +450,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
preprocessor.input("attention_mask").preprocess().convert_element_type();
} else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32); // ()
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);
} else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
@@ -541,6 +502,23 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
namespace ov {
namespace genai {

ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
if (m_cache.find(input_ids_size) == m_cache.cend()) {
if (m_decoder_model->is_dynamic()) { // model is dynamic, reshaping it to static
Collaborator:

Weird, I don't expect we need this check...

Let's discuss locally

reshape_to_static(m_decoder_model, input_ids_size, input_ids_size, m_lhs_shape);
} else {
reshape_input_ids(m_decoder_model, input_ids_size);
}

ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_cache.emplace(input_ids_size, compiled_model);
}

return m_cache.at(input_ids_size);
}
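Usage sketch (hypothetical call site, mirroring how the pipeline below uses the cache): the cache compiles one static decoder per prompt length on demand and returns the compiled model, so callers create an infer request from it:

// decoder_model and last_hidden_state_shape as in the pipeline constructor below
DecoderCache cache(decoder_model, last_hidden_state_shape);
auto decoder = cache.get_model(init_ids.size()).create_infer_request();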

WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesystem::path& models_path,
const ov::AnyMap& properties)
: WhisperPipelineImplBase{models_path} {
Expand All @@ -550,20 +528,13 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
auto decoder_model = core.read_model(models_path / "openvino_decoder_model.xml", {}, properties);
auto decoder_with_past_model = core.read_model(models_path / "openvino_decoder_with_past_model.xml", {}, properties);

add_attention_mask_input_for_decoder(decoder_model);
add_attention_mask_input(decoder_with_past_model);

// TODO: Support models produced by optimum-cli
if (!check_decoder_model_compatibility(decoder_model)) {
OPENVINO_THROW("StaticWhisperPipeline expects decoder model has \"attention_mask\" input!");
}

size_t max_sequence_length = 448;

reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);

auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);

// Replace KV-tensors for the entire cache to tensors only for new token
@@ -577,9 +548,10 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
compiled_model = core.compile_model(encoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();
compiled_model = core.compile_model(decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_models.decoder = compiled_model.create_infer_request();

// Will compile decoder model when it's needed
m_decoder_cache = DecoderCache(decoder_model, last_hidden_state_shape);

compiled_model = core.compile_model(decoder_with_past_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
m_models.decoder_with_past = compiled_model.create_infer_request();
@@ -654,7 +626,11 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

// prepare init_ids just once for whole input
if (init_ids.empty()) {
m_models.decoder = m_decoder_cache.get_model(1).create_infer_request(); // for detect_language()
Collaborator:

Here the model is compiled upfront again. What if it's not even needed and the language is already known?

I believe you need it somewhere here:

language_token_id = detect_language(encoder_hidden_state, decoder, config, raw_metrics);

init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);

// Get decoder with size of input_ids
m_models.decoder = m_decoder_cache.get_model(init_ids.size()).create_infer_request();
}
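A hedged sketch of the reviewer's suggestion above: fetch the single-token decoder only when language detection actually has to run (hypothetical restructuring; config.language and the exact call site are assumptions, not part of this PR):

// somewhere around the detect_language() call inside prepare_init_ids():
if (!config.language.has_value()) {
    auto lang_decoder = m_decoder_cache.get_model(1).create_infer_request();
    language_token_id = detect_language(hidden_state_tensor, lang_decoder, config, raw_metrics);
}
// if the language is already known, the 1-token decoder never gets compiled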

auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
15 changes: 15 additions & 0 deletions src/cpp/src/whisper_pipeline_static.hpp
@@ -15,6 +15,20 @@
namespace ov {
namespace genai {

class DecoderCache {
public:
DecoderCache() = default;
DecoderCache(std::shared_ptr<ov::Model> model, ov::PartialShape shape)
: m_decoder_model(model)
, m_lhs_shape(shape) {}
Collaborator:

Well, the decoder model has two input layers:

  1. encoder_hidden_states - does it change? I believe it depends on the encoder model, and once we know it, it no longer changes?
  2. decoder_input_ids - this is what we track in the hash map. It changes depending on the case.

Can we set encoder_hidden_states in the ctor? If so, I believe we don't need to check whether the model is dynamic or not in the get() method.

Collaborator:

I believe the size of encoder_hidden_states depends only on feature_size, and it's set in StaticWhisperPipeline and no longer changes after that:

reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);

Can you confirm this? If so, I believe we definitely can set the shape for encoder_hidden_states once.

// Also, it should probably be `ov::Shape` rather than ov::PartialShape as we don't have dynamism (input is fixed)
DecoderCache(std::shared_ptr<ov::Model> model, ov::PartialShape encoder_hidden_shape) {
    model->reshape({{"encoder_hidden_states", encoder_hidden_shape}});
}

ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
    if (m_cache.count(input_ids_size) == 0) {
        m_model->reshape({{"decoder_input_ids", ov::Shape({1, input_ids_size})}});
        m_cache.emplace(input_ids_size, core.compile_model(m_model));
    }
    return m_cache.at(input_ids_size);
}


ov::CompiledModel get_model(uint8_t input_ids_size);
private:
std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
std::shared_ptr<ov::Model> m_decoder_model;
ov::PartialShape m_lhs_shape;
};

class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
public:
StaticWhisperPipeline(const std::filesystem::path& model_path, const ov::AnyMap& properties);
@@ -25,6 +39,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi

private:
WhisperInitializedModels m_models;
DecoderCache m_decoder_cache;
};

} // namespace genai