Whisper pipeline: support stateful decoder #1474
Changes from all commits: 68d3e48, ade7313, b2df4a6, 17e3ea7, 806b01a, e041a33, 9502d9b, 6c30fa4, 7600072, f870a4c, e38cf5c, acc656f, aa0f742, 3728884, 445ce5a, 5bdd695
@@ -0,0 +1,17 @@ logger.hpp (new file; filename inferred from includes elsewhere in this diff)

```cpp
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include <iostream>
#include <string>

namespace ov::genai {

class Logger {
public:
    static void warn(std::string message) {
        std::cout << "[WARN] " << message << '\n';
    };
};

}  // namespace ov::genai
```
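For context, a minimal usage sketch of this helper; the `main` wrapper is illustrative, not part of the PR:

```cpp
#include "logger.hpp"

int main() {
    // Prints "[WARN] decoder with past is deprecated" to stdout.
    ov::genai::Logger::warn("decoder with past is deprecated");
    return 0;
}
```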
@@ -0,0 +1,26 @@ decoder.cpp (new file; filename inferred from includes)

```cpp
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "decoder.hpp"

#include <filesystem>

#include "statefull_decoder.hpp"
#include "utils.hpp"
#include "with_past_decoder.hpp"

namespace ov::genai {
std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem::path& models_path,
                                                          const std::string& device,
                                                          const ov::AnyMap& properties) {
    bool has_decoder_with_past = std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml");

    if (has_decoder_with_past) {
        return std::make_shared<WhisperWithPastDecoder>(models_path, device, properties);
    }

    return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
}

WhisperDecoder::~WhisperDecoder() = default;
}  // namespace ov::genai
```
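A hedged sketch of how a caller might use the factory; the export directory name is hypothetical:

```cpp
#include <memory>

#include "decoder.hpp"

std::shared_ptr<ov::genai::WhisperDecoder> make_decoder() {
    // Picks WhisperWithPastDecoder when openvino_decoder_with_past_model.xml
    // exists in the directory, otherwise the stateful single-model decoder.
    return ov::genai::WhisperDecoder::from_path("whisper-tiny",  // hypothetical export dir
                                                "CPU",
                                                ov::AnyMap{});
}
```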
@@ -0,0 +1,29 @@ decoder.hpp (new file; filename inferred from includes)

```cpp
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <filesystem>

#include "openvino/genai/whisper_generation_config.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {
class WhisperDecoder {
public:
    static std::shared_ptr<WhisperDecoder> from_path(const std::filesystem::path& models_path,
                                                     const std::string& device,
                                                     const ov::AnyMap& properties);

    virtual std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
                                                      const int64_t decoder_start_token_id) = 0;

    virtual std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
                                                const std::vector<int64_t>& input_ids,
                                                const size_t cache_position) = 0;

    virtual void reset_state() = 0;

    virtual ~WhisperDecoder();
};
}  // namespace ov::genai
```
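To show how the interface composes, here is a hypothetical greedy loop over it; `argmax_last_token`, the token ids, and `max_new_tokens` are assumptions for illustration, not code from this PR:

```cpp
#include <cstdint>
#include <memory>
#include <vector>

#include "decoder.hpp"

int64_t argmax_last_token(const ov::Tensor& logits);  // assumed helper, defined elsewhere

std::vector<int64_t> greedy_decode(const std::shared_ptr<ov::genai::WhisperDecoder>& decoder,
                                   const ov::Tensor& encoder_hidden_state,
                                   int64_t start_token_id,
                                   int64_t eos_token_id,
                                   size_t max_new_tokens) {
    std::vector<int64_t> generated;
    std::vector<int64_t> step_input = {start_token_id};  // first step feeds the whole prompt
    size_t cache_position = 0;

    for (size_t step = 0; step < max_new_tokens; ++step) {
        auto [logits, infer_ms] = decoder->decode(encoder_hidden_state, step_input, cache_position);
        (void)infer_ms;                       // perf metric unused in this sketch
        cache_position += step_input.size();  // the KV cache grew by the tokens just fed
        const int64_t next_token = argmax_last_token(logits);
        if (next_token == eos_token_id) {
            break;
        }
        generated.push_back(next_token);
        step_input = {next_token};  // later steps feed one token at a time
    }

    decoder->reset_state();  // leave the decoder ready for the next audio chunk
    return generated;
}
```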
@@ -0,0 +1,60 @@ statefull_decoder.cpp (new file; filename inferred from includes)

```cpp
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "statefull_decoder.hpp"

#include "utils.hpp"

namespace ov::genai {
WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& models_path,
                                                 const std::string& device,
                                                 const ov::AnyMap& properties) {
    ov::Core core = utils::singleton_core();

    auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);

    utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
    m_request = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
                                                                   const int64_t decoder_start_token_id) {
    auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

    int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

    reset_state();

    return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const ov::Tensor& encoder_hidden_state,
                                                             const std::vector<int64_t>& input_ids,
                                                             const size_t cache_position) {
    m_request.set_tensor("encoder_hidden_states", encoder_hidden_state);

    ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
    m_request.set_tensor("input_ids", input_ids_tensor);

    ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position");
    cache_position_tensor.set_shape({input_ids.size()});

    auto cache_data = cache_position_tensor.data<int64_t>();
    std::iota(cache_data, cache_data + cache_position_tensor.get_size(), cache_position);

    m_request.get_tensor("beam_idx").set_shape({1});
    m_request.get_tensor("beam_idx").data<int32_t>()[0] = 0;

    const auto infer_start = std::chrono::steady_clock::now();
    m_request.infer();
    const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

    auto output_tensor = m_request.get_tensor("logits");

    return {output_tensor, infer_ms};
};

void WhisperStatefullDecoder::reset_state() {
    m_request.reset_state();
}
}  // namespace ov::genai
```
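For intuition about the `std::iota` call above: `cache_position` enumerates the absolute positions of the tokens fed in this step, so a four-token prompt at position 0 gets {0, 1, 2, 3} and the next single-token step gets {4}. The `beam_idx` tensor is pinned to a single zero since only one sequence is decoded here. A standalone sketch of the position fill with plain vectors (not OpenVINO tensors):

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
    // First step: 4 prompt tokens starting at cache_position 0 -> {0, 1, 2, 3}.
    std::vector<int64_t> first_step(4);
    std::iota(first_step.begin(), first_step.end(), int64_t{0});

    // Next step: 1 new token starting at cache_position 4 -> {4}.
    std::vector<int64_t> next_step(1);
    std::iota(next_step.begin(), next_step.end(), int64_t{4});
    return 0;
}
```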
@@ -0,0 +1,29 @@ statefull_decoder.hpp (new file; filename inferred from includes)

```cpp
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperStatefullDecoder : public WhisperDecoder {
public:
    WhisperStatefullDecoder(const std::filesystem::path& models_path,
                            const std::string& device,
                            const ov::AnyMap& properties);

    std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
                                              const int64_t decoder_start_token_id) override;

    std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
                                        const std::vector<int64_t>& input_ids,
                                        const size_t cache_position) override;

    void reset_state() override;

private:
    ov::InferRequest m_request;
};
}  // namespace ov::genai
```
@@ -0,0 +1,107 @@ with_past_decoder.cpp (new file; filename inferred from includes)

```cpp
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "with_past_decoder.hpp"

#include <regex>

#include "logger.hpp"
#include "utils.hpp"

namespace {
void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
    // source outputs:
    // present.0.decoder.key
    // present.0.decoder.value
    // present.0.encoder.key
    // present.0.encoder.value

    // dest inputs:
    // past_key_values.0.decoder.key
    // past_key_values.0.decoder.value
    // past_key_values.0.encoder.key
    // past_key_values.0.encoder.value

    for (auto& source_output : source.get_compiled_model().outputs()) {
        std::string source_output_name = source_output.get_any_name();
        if (source_output_name.find("logits") != std::string::npos) {
            continue;
        }

        std::string with_past_input_name =
            std::regex_replace(source_output_name, std::regex("present"), "past_key_values");

        auto kv_tensor = source.get_tensor(source_output_name);
        dest.set_tensor(with_past_input_name, ov::Tensor{kv_tensor});
    }
}
}  // namespace

namespace ov::genai {
WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& models_path,
                                               const std::string& device,
                                               const ov::AnyMap& properties) {
    Logger::warn("Whisper decoder models with past is deprecated. Support will be removed in 2026.0.0 release.\n"
                 "To obtain stateful decoder model use latest `optimum-intel` package:\n"
                 "pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git\n"
                 "optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny whisper-tiny");
    ov::Core core = utils::singleton_core();

    auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
    utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
    m_request_decoder = compiled_model.create_infer_request();

    compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties);
    utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
    m_request_decoder_with_past = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
                                                                  const int64_t decoder_start_token_id) {
    auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

    int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

    reset_state();

    return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperWithPastDecoder::decode(const ov::Tensor& encoder_hidden_state,
                                                            const std::vector<int64_t>& input_ids,
                                                            const size_t cache_position) {
    const bool initial_step = cache_position == 0;
    ov::InferRequest& request = initial_step ? m_request_decoder : m_request_decoder_with_past;

    request.set_tensor("encoder_hidden_states", encoder_hidden_state);

    const ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
    request.set_tensor("input_ids", input_ids_tensor);

    if (!initial_step) {
        ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
        cache_position_tensor.set_shape({1});
        cache_position_tensor.data<int64_t>()[0] = cache_position;
    }

    const auto infer_start = std::chrono::steady_clock::now();
    request.infer();
    const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

    auto output_tensor = request.get_tensor("logits");

    if (initial_step) {
        set_past_key_value(m_request_decoder, m_request_decoder_with_past);
    } else if (!m_decoder_with_past_kv_value_set) {
        set_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past);
        m_decoder_with_past_kv_value_set = true;
    }

    return {output_tensor, infer_ms};
}

void WhisperWithPastDecoder::reset_state() {
    m_request_decoder_with_past.reset_state();
    m_decoder_with_past_kv_value_set = false;
}
}  // namespace ov::genai
```

Review comments on this file:

- On the `WhisperWithPastDecoder` constructor: "@ilya-lavrenov @Wovchena I want to add deprecation note for this ctor. I saw"
  - Reply: "we don't have a better way than a message in runtime"
- On the deprecation warning text: "You can provide the minimum version containing stateful whisper by specifying"
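A standalone sketch of the `present.*` to `past_key_values.*` name mapping that `set_past_key_value` performs:

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
    // Mirrors set_past_key_value: present.0.decoder.key -> past_key_values.0.decoder.key
    const std::string source_output_name = "present.0.decoder.key";
    const std::string with_past_input_name =
        std::regex_replace(source_output_name, std::regex("present"), "past_key_values");
    std::cout << with_past_input_name << '\n';
    return 0;
}
```

Note the second branch in `decode`: after the first with-past inference it binds the with-past request's own `present.*` outputs to its `past_key_values.*` inputs exactly once, guarded by `m_decoder_with_past_kv_value_set`, so the wiring is not repeated on every step.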
@@ -0,0 +1,32 @@ with_past_decoder.hpp (new file; filename inferred from includes)

```cpp
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperWithPastDecoder : public WhisperDecoder {
public:
    WhisperWithPastDecoder(const std::filesystem::path& models_path,
                           const std::string& device,
                           const ov::AnyMap& properties);

    std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
                                              const int64_t decoder_start_token_id) override;

    std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
                                        const std::vector<int64_t>& input_ids,
                                        const size_t cache_position) override;

    void reset_state() override;

private:
    ov::InferRequest m_request_decoder;
    ov::InferRequest m_request_decoder_with_past;
    bool m_decoder_with_past_kv_value_set = false;
};

}  // namespace ov::genai
```