-
Notifications
You must be signed in to change notification settings - Fork 201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Whisper pipeline: support stateful decoder #1474
Merged
Wovchena
merged 16 commits into
openvinotoolkit:master
from
as-suvorov:as/statefull_whisper
Jan 13, 2025
Merged
Changes from 14 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
68d3e48
use decoder interface
as-suvorov ade7313
Merge remote-tracking branch 'upstream/master' into as/statefull_whisper
as-suvorov b2df4a6
remove reshape
as-suvorov 17e3ea7
use stateful seq2seq barnch
as-suvorov 806b01a
Address review comments
as-suvorov e041a33
Rename
as-suvorov 9502d9b
Use commit
as-suvorov 6c30fa4
Set tests reqs
as-suvorov 7600072
Merge remote-tracking branch 'upstream/master' into as/statefull_whisper
as-suvorov f870a4c
Merge remote-tracking branch 'upstream/master' into as/statefull_whisper
as-suvorov e38cf5c
Add with_past model tests
as-suvorov acc656f
remove comment
as-suvorov aa0f742
Merge remote-tracking branch 'upstream/master' into as/statefull_whisper
as-suvorov 3728884
bump tokenizers
as-suvorov 445ce5a
Fix typo
as-suvorov 5bdd695
Add deprecation message
as-suvorov File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "decoder.hpp" | ||
|
||
#include <filesystem> | ||
|
||
#include "statefull_decoder.hpp" | ||
#include "utils.hpp" | ||
#include "with_past_decoder.hpp" | ||
|
||
namespace ov::genai { | ||
std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem::path& models_path, | ||
const std::string& device, | ||
const ov::AnyMap& properties) { | ||
bool has_decoder_with_past = std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml"); | ||
|
||
if (has_decoder_with_past) { | ||
return std::make_shared<WhisperWithPastDecoder>(models_path, device, properties); | ||
} | ||
|
||
return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties); | ||
} | ||
|
||
WhisperDecoder::~WhisperDecoder() = default; | ||
} // namespace ov::genai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <filesystem> | ||
|
||
#include "openvino/genai/whisper_generation_config.hpp" | ||
#include "openvino/runtime/core.hpp" | ||
|
||
namespace ov::genai { | ||
class WhisperDecoder { | ||
public: | ||
static std::shared_ptr<WhisperDecoder> from_path(const std::filesystem::path& models_path, | ||
const std::string& device, | ||
const ov::AnyMap& properties); | ||
|
||
virtual std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state, | ||
const int64_t decoder_start_token_id) = 0; | ||
|
||
virtual std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state, | ||
const std::vector<int64_t>& input_ids, | ||
const size_t cache_position) = 0; | ||
|
||
virtual void reset_state() = 0; | ||
|
||
virtual ~WhisperDecoder(); | ||
}; | ||
} // namespace ov::genai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "statefull_decoder.hpp" | ||
|
||
#include "utils.hpp" | ||
|
||
namespace ov::genai { | ||
WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& models_path, | ||
const std::string& device, | ||
const ov::AnyMap& properties) { | ||
ov::Core core = utils::singleton_core(); | ||
|
||
auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties); | ||
|
||
utils::print_compiled_model_properties(compiled_model, "whisper decoder model"); | ||
m_request = compiled_model.create_infer_request(); | ||
} | ||
|
||
std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state, | ||
const int64_t decoder_start_token_id) { | ||
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0); | ||
|
||
int64_t output_token = ov::genai::utils::argmax(output_tensor, 0); | ||
|
||
reset_state(); | ||
|
||
return {output_token, infer_ms}; | ||
} | ||
|
||
std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const ov::Tensor& encoder_hidden_state, | ||
const std::vector<int64_t>& input_ids, | ||
const size_t cache_position) { | ||
m_request.set_tensor("encoder_hidden_states", encoder_hidden_state); | ||
|
||
ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data()); | ||
m_request.set_tensor("input_ids", input_ids_tensor); | ||
|
||
ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position"); | ||
cache_position_tensor.set_shape({input_ids.size()}); | ||
|
||
auto cache_data = cache_position_tensor.data<int64_t>(); | ||
std::iota(cache_data, cache_data + cache_position_tensor.get_size(), cache_position); | ||
|
||
m_request.get_tensor("beam_idx").set_shape({1}); | ||
m_request.get_tensor("beam_idx").data<int32_t>()[0] = 0; | ||
|
||
const auto infer_start = std::chrono::steady_clock::now(); | ||
m_request.infer(); | ||
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); | ||
|
||
auto output_tensor = m_request.get_tensor("logits"); | ||
|
||
return {output_tensor, infer_ms}; | ||
}; | ||
|
||
void WhisperStatefullDecoder::reset_state() { | ||
m_request.reset_state(); | ||
} | ||
} // namespace ov::genai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "decoder.hpp" | ||
#include "openvino/runtime/core.hpp" | ||
|
||
namespace ov::genai { | ||
|
||
class WhisperStatefullDecoder : public WhisperDecoder { | ||
public: | ||
WhisperStatefullDecoder(const std::filesystem::path& models_path, | ||
const std::string& device, | ||
const ov::AnyMap& properties); | ||
|
||
std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state, | ||
const int64_t decoder_start_token_id) override; | ||
|
||
std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state, | ||
const std::vector<int64_t>& input_ids, | ||
const size_t cache_position) override; | ||
|
||
void reset_state() override; | ||
|
||
private: | ||
ov::InferRequest m_request; | ||
}; | ||
} // namespace ov::genai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "with_past_decoder.hpp" | ||
|
||
#include <regex> | ||
|
||
#include "utils.hpp" | ||
|
||
namespace { | ||
void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) { | ||
// source outputs: | ||
// present.0.decoder.key | ||
// present.0.decoder.value | ||
// present.0.encoder.key | ||
// present.0.encoder.value | ||
|
||
// dest inputs: | ||
// past_key_values.0.decoder.key | ||
// past_key_values.0.decoder.value | ||
// past_key_values.0.encoder.key | ||
// past_key_values.0.encoder.value | ||
|
||
for (auto& source_output : source.get_compiled_model().outputs()) { | ||
std::string source_output_name = source_output.get_any_name(); | ||
if (source_output_name.find("logits") != std::string::npos) { | ||
continue; | ||
} | ||
|
||
std::string with_past_input_name = | ||
std::regex_replace(source_output_name, std::regex("present"), "past_key_values"); | ||
|
||
auto kv_tensor = source.get_tensor(source_output_name); | ||
dest.set_tensor(with_past_input_name, ov::Tensor{kv_tensor}); | ||
} | ||
} | ||
} // namespace | ||
|
||
namespace ov::genai { | ||
WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& models_path, | ||
const std::string& device, | ||
const ov::AnyMap& properties) { | ||
ov::Core core = utils::singleton_core(); | ||
|
||
auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties); | ||
utils::print_compiled_model_properties(compiled_model, "whisper decoder model"); | ||
m_request_decoder = compiled_model.create_infer_request(); | ||
|
||
compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties); | ||
utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model"); | ||
m_request_decoder_with_past = compiled_model.create_infer_request(); | ||
} | ||
|
||
std::pair<int64_t, float> WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state, | ||
const int64_t decoder_start_token_id) { | ||
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0); | ||
|
||
int64_t output_token = ov::genai::utils::argmax(output_tensor, 0); | ||
|
||
reset_state(); | ||
|
||
return {output_token, infer_ms}; | ||
} | ||
|
||
std::pair<ov::Tensor, float> WhisperWithPastDecoder::decode(const ov::Tensor& encoder_hidden_state, | ||
const std::vector<int64_t>& input_ids, | ||
const size_t cache_position) { | ||
const bool initial_step = cache_position == 0; | ||
ov::InferRequest& request = initial_step ? m_request_decoder : m_request_decoder_with_past; | ||
|
||
request.set_tensor("encoder_hidden_states", encoder_hidden_state); | ||
|
||
const ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data()); | ||
request.set_tensor("input_ids", input_ids_tensor); | ||
|
||
if (!initial_step) { | ||
ov::Tensor cache_position_tensor = request.get_tensor("cache_position"); | ||
cache_position_tensor.set_shape({1}); | ||
cache_position_tensor.data<int64_t>()[0] = cache_position; | ||
} | ||
|
||
const auto infer_start = std::chrono::steady_clock::now(); | ||
request.infer(); | ||
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); | ||
|
||
auto output_tensor = request.get_tensor("logits"); | ||
|
||
if (initial_step) { | ||
set_past_key_value(m_request_decoder, m_request_decoder_with_past); | ||
} else if (!m_decoder_with_past_kv_value_set) { | ||
set_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past); | ||
m_decoder_with_past_kv_value_set = true; | ||
} | ||
|
||
return {output_tensor, infer_ms}; | ||
} | ||
|
||
void WhisperWithPastDecoder::reset_state() { | ||
m_request_decoder_with_past.reset_state(); | ||
m_decoder_with_past_kv_value_set = false; | ||
} | ||
} // namespace ov::genai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "decoder.hpp" | ||
#include "openvino/runtime/core.hpp" | ||
|
||
namespace ov::genai { | ||
|
||
class WhisperWithPastDecoder : public WhisperDecoder { | ||
public: | ||
WhisperWithPastDecoder(const std::filesystem::path& models_path, | ||
const std::string& device, | ||
const ov::AnyMap& properties); | ||
|
||
std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state, | ||
const int64_t decoder_start_token_id) override; | ||
|
||
std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state, | ||
const std::vector<int64_t>& input_ids, | ||
const size_t cache_position) override; | ||
|
||
void reset_state() override; | ||
|
||
private: | ||
ov::InferRequest m_request_decoder; | ||
ov::InferRequest m_request_decoder_with_past; | ||
bool m_decoder_with_past_kv_value_set = false; | ||
}; | ||
|
||
} // namespace ov::genai |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ilya-lavrenov @Wovchena I want to add deprecation note for this ctor. I saw
OPENVINO_DEPRECATED
macros but I think it doesn't fit here as I want to warn user at runtime.Can I add just
std::cout << "[Warning] Whisper decoder with past deprecated ..."
? Does OV have logging utilities we can reuse? Or there is a better way?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we don't have a better way than a message in runtime