Add a choice of how to end streaming from callback: STOP or CANCEL
sbalandi committed Jan 7, 2025
1 parent 34dc469 commit 1c55273
Showing 16 changed files with 158 additions and 56 deletions.
4 changes: 4 additions & 0 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -7,6 +7,7 @@
#include <unordered_map>

#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/visibility.hpp"

namespace ov::genai {
@@ -30,6 +31,9 @@ struct EncodedGenerationResult {

// Status of generation
GenerationStatus m_status = GenerationStatus::RUNNING;

// Status of streaming
CallbacWorkStatus m_streaming_status = ov::genai::CallbacWorkStatus::UNDEF;
};

enum class GenerationFinishReason {
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -19,7 +19,7 @@ namespace ov {
namespace genai {

// The returned flag indicates whether generation should be stopped: false means continue generation, true means stop.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<CallbacWorkStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
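For orientation, here is a minimal usage sketch (not part of this commit) of how a callback matching the new std::function<CallbacWorkStatus(std::string)> alternative could be passed to LLMPipeline::generate, assuming the lambda binds to that alternative of StreamerVariant; the model path and the 200-character cutoff are placeholders:

#include "openvino/genai/llm_pipeline.hpp"

#include <iostream>
#include <string>

int main() {
    ov::genai::LLMPipeline pipe("path/to/model_dir", "CPU");  // placeholder model path

    ov::genai::GenerationConfig config = pipe.get_generation_config();
    config.max_new_tokens = 256;

    size_t printed_chars = 0;
    auto callback = [&](std::string subword) -> ov::genai::CallbacWorkStatus {
        std::cout << subword << std::flush;
        printed_chars += subword.size();
        // RUNNING continues generation; STOP ends it while keeping the prompt and
        // generated tokens in history/KV cache; CANCEL would drop them instead.
        return printed_chars > 200 ? ov::genai::CallbacWorkStatus::STOP
                                   : ov::genai::CallbacWorkStatus::RUNNING;
    };

    pipe.generate("Explain how KV-cache trimming works", config, callback);
    return 0;
}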
16 changes: 16 additions & 0 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -4,16 +4,28 @@
#pragma once

#include "openvino/genai/tokenizer.hpp"
#include <variant>

namespace ov {
namespace genai {

enum class CallbacWorkStatus {
UNDEF = 0,   // Streaming is not running
RUNNING = 1, // Continue running inference
STOP = 2,    // Stop generation, keep history as is, the KV cache includes the last prompt and the generated tokens
CANCEL = 3   // Stop generation, drop the last prompt and all generated tokens from the history, the KV cache includes the history except the last step
};

using CallbackTypeVariant = std::variant<bool, CallbacWorkStatus>;

/**
* @brief Base class for streamers. To use it, inherit from this class and implement the put() and end() methods.
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS StreamerBase {
protected:
CallbacWorkStatus streaming_finish_status = CallbacWorkStatus::UNDEF;
public:
/// @brief put is called every time a new token is decoded
/// @return bool flag indicating whether generation should be stopped; if true is returned, generation stops
@@ -22,6 +34,10 @@ class OPENVINO_GENAI_EXPORTS StreamerBase {
/// @brief end is called at the end of generation. It can be used to flush the cache if your streamer has one
virtual void end() = 0;

virtual CallbacWorkStatus get_finish_streaming_reason() {
return streaming_finish_status;
}

virtual ~StreamerBase();
};

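A hedged sketch (not from the commit) of a custom streamer built on the new hooks: put() still returns the bool stop flag, while the protected streaming_finish_status member, read back through get_finish_streaming_reason(), tells the pipeline whether that stop should be treated as STOP or CANCEL. The class name and token limit below are illustrative:

#include "openvino/genai/streamer_base.hpp"

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

class CancellingStreamer : public ov::genai::StreamerBase {
public:
    CancellingStreamer(ov::genai::Tokenizer tokenizer, size_t token_limit)
        : m_tokenizer(std::move(tokenizer)), m_token_limit(token_limit) {}

    bool put(int64_t token) override {
        m_tokens.push_back(token);
        if (m_tokens.size() >= m_token_limit) {
            // Ask the pipeline to drop this prompt/answer pair from the chat history.
            streaming_finish_status = ov::genai::CallbacWorkStatus::CANCEL;
            return true;   // stop generation
        }
        return false;      // continue generation
    }

    void end() override {
        // Flush whatever was collected so far.
        std::cout << m_tokenizer.decode(m_tokens) << std::endl;
        m_tokens.clear();
    }

private:
    ov::genai::Tokenizer m_tokenizer;
    std::vector<int64_t> m_tokens;
    size_t m_token_limit;
};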
6 changes: 6 additions & 0 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -273,6 +273,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
},
[this](const std::function<CallbacWorkStatus(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);

@@ -345,6 +348,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
result.m_generation_ids.resize(num_outputs);
result.m_scores.resize(num_outputs);

if (streamer_ptr)
result.m_streaming_status = streamer_ptr->get_finish_streaming_reason();

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob();
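The dispatch above uses std::visit with one lambda per StreamerVariant alternative. A self-contained illustration of that pattern follows; the types are re-declared locally and PrintStreamer stands in for the internal TextCallbackStreamer, so this is a model of the idiom rather than the pipeline code:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <variant>

enum class CallbacWorkStatus { UNDEF, RUNNING, STOP, CANCEL };

struct StreamerBase { virtual ~StreamerBase() = default; };
struct PrintStreamer : StreamerBase {};  // stand-in for TextCallbackStreamer

using StreamerVariant = std::variant<std::function<bool(std::string)>,
                                     std::function<CallbacWorkStatus(std::string)>,
                                     std::shared_ptr<StreamerBase>,
                                     std::monostate>;

template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template <class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

std::shared_ptr<StreamerBase> to_streamer(const StreamerVariant& streamer) {
    return std::visit(overloaded{
        [](std::monostate) -> std::shared_ptr<StreamerBase> { return nullptr; },
        [](const std::shared_ptr<StreamerBase>& ptr) -> std::shared_ptr<StreamerBase> { return ptr; },
        [](const std::function<bool(std::string)>&) -> std::shared_ptr<StreamerBase> {
            return std::make_shared<PrintStreamer>();  // would wrap the bool callback
        },
        [](const std::function<CallbacWorkStatus(std::string)>&) -> std::shared_ptr<StreamerBase> {
            return std::make_shared<PrintStreamer>();  // same wrapper handles the new return type
        }
    }, streamer);
}

int main() {
    StreamerVariant v = std::function<CallbacWorkStatus(std::string)>(
        [](std::string) { return CallbacWorkStatus::RUNNING; });
    std::cout << std::boolalpha << (to_streamer(v) != nullptr) << std::endl;  // prints: true
}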
6 changes: 5 additions & 1 deletion src/cpp/src/icontinuous_batching.cpp
@@ -64,7 +64,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
generated.reserve(res.m_generation_ids.size());
for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
if (m_is_chat_conversation && 0 == idx) {
if (m_is_chat_conversation && 0 == idx && res.m_streaming_status != ov::genai::CallbacWorkStatus::CANCEL) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
@@ -77,6 +77,10 @@
});
}

// if streaming was cancelled, the prompt/answer of the current step shouldn't be present in the history, so remove the prompt from the history
if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_streaming_status == ov::genai::CallbacWorkStatus::CANCEL)
m_history.pop_back();

return decoded;
}
}
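A sketch of what this means for a chat session (same placeholder assumptions as the earlier snippet): a turn ended with STOP is appended to the chat history, while a turn ended with CANCEL is rolled back, so the next generate() call behaves as if that prompt was never sent.

#include "openvino/genai/llm_pipeline.hpp"

#include <iostream>
#include <string>

// Runs one chat turn and ends streaming with the given status after the second chunk.
// Sketch only: the model path and the "end after the second chunk" rule are illustrative.
void one_turn(ov::genai::LLMPipeline& pipe, const std::string& prompt,
              ov::genai::CallbacWorkStatus end_status) {
    bool first_chunk = true;
    auto callback = [&](std::string subword) -> ov::genai::CallbacWorkStatus {
        std::cout << subword << std::flush;
        if (first_chunk) {
            first_chunk = false;
            return ov::genai::CallbacWorkStatus::RUNNING;
        }
        return end_status;  // STOP keeps this turn in history, CANCEL discards it
    };
    pipe.generate(prompt, pipe.get_generation_config(), callback);
    std::cout << std::endl;
}

int main() {
    ov::genai::LLMPipeline pipe("path/to/model_dir", "CPU");  // placeholder model path
    pipe.start_chat();
    one_turn(pipe, "Summarize the plot of Hamlet", ov::genai::CallbacWorkStatus::STOP);   // kept in history
    one_turn(pipe, "Now translate it to French", ov::genai::CallbacWorkStatus::CANCEL);   // dropped from history
    pipe.finish_chat();
    return 0;
}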
75 changes: 51 additions & 24 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -9,6 +9,8 @@
#include "text_callback_streamer.hpp"
#include "utils.hpp"

#include "debug_utils.hpp"

namespace ov::genai {

StatefulLLMPipeline::StatefulLLMPipeline(
@@ -39,7 +41,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const ov::genai::GenerationConfig& generation_config)
: LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
utils::slice_matmul_stateful_model(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);
m_kv_history_manager.kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

ov::CompiledModel compiled_model;
if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
@@ -86,6 +88,9 @@ DecodedResults StatefulLLMPipeline::generate(

TokenizedInputs encoded_input;

std::string prev_templated_chat_history(m_templated_chat_history);
std::vector<int64_t> prev_tokenized_chat_history(m_tokenized_chat_history);

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
encoded_input = m_tokenizer.encode(*input_vector);
@@ -104,7 +109,7 @@

m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
@@ -116,21 +121,24 @@
if (!m_tokenized_chat_history.empty()) {
std::set<int64_t> stop_tokens = config.stop_token_ids;
trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = trusted_history_length == SIZE_MAX;
}

if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) {
// does_kv_cache_need_to_update will be true here if beam search is activated
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_history_cache_need_to_update()) {
// does_history_cache_need_to_update will be true here if beam search is activated
// in beam search mode we want to remove from the kv cache all history about the last model answer and add the best answer directly
// if there is a difference between the model answer and the decoded answer, it will in any case be smaller than the entire history, so use the data from m_kv_history_manager
if (m_kv_history_manager.does_kv_cache_need_to_update()) {
if (m_kv_history_manager.does_history_cache_need_to_update()) {
trusted_history_length = m_kv_history_manager.trusted_history_length;
} else {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
size_t num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
// if the previous generation finished because the max length was reached, the kv cache is missing the last token, so keep it
m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;
num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;

// if streaming was used and cancelled on the previous step, num_tokens_to_remove_from_kv_cache could already be set, and it will be larger since it includes the answer + prompt
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = num_tokens_to_remove_from_kv_cache > m_kv_history_manager.num_tokens_to_remove_from_kv_cache ?
num_tokens_to_remove_from_kv_cache : m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
}

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
@@ -169,11 +177,19 @@ DecodedResults StatefulLLMPipeline::generate(
auto decode_stop_time = std::chrono::steady_clock::now();

if (is_chat_conversation) {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
auto answer = decoded_results.texts[0];
m_templated_chat_history.append(answer);
m_history.push_back({{"role", "assistant"}, {"content", answer}});
if (m_chat_generation_finish_status == ov::genai::CallbacWorkStatus::CANCEL) {
// If the chat generation process was cancelled by the user, roll back to the previous state of the history
m_history.pop_back();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += m_tokenized_chat_history.size() - prev_tokenized_chat_history.size();
m_templated_chat_history = prev_templated_chat_history;
m_tokenized_chat_history = prev_tokenized_chat_history;
} else {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
auto answer = decoded_results.texts[0];
m_templated_chat_history.append(answer);
m_history.push_back({{"role", "assistant"}, {"content", answer}});
}
}

// generate_durations
@@ -218,6 +234,8 @@ EncodedResults StatefulLLMPipeline::generate(
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));

size_t real_input_ids_size = input_ids.get_shape().at(1);

// Tail of previous output in chat mode is missing in KV cache.
if (m_last_disappeared_token.has_value()) {
attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1);
@@ -241,6 +259,8 @@ EncodedResults StatefulLLMPipeline::generate(
streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
} else if (auto callback = std::get_if<std::function<ov::genai::CallbacWorkStatus(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

auto batch_size = input_ids.get_shape().at(0);
@@ -254,7 +274,8 @@ EncodedResults StatefulLLMPipeline::generate(
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller);
ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache,
m_kv_history_manager.kv_cache_seq_length_axis, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
@@ -292,8 +313,7 @@ EncodedResults StatefulLLMPipeline::generate(
m_adapter_controller->apply(m_model_runner, config.adapters);
}

if (is_chat_conversation && !m_trust_encoded_history) {
m_trust_encoded_history = true;
if (is_chat_conversation) {
m_kv_history_manager.reset();
}

@@ -321,26 +341,35 @@ EncodedResults StatefulLLMPipeline::generate(
m_sampler.set_seed(config.rng_seed);
}

ov::genai::EncodedResults result;
std::tie(result, m_last_disappeared_token) = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
ov::genai::utils::GenerationFinishInfo finish_info = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
streamer_ptr, m_sampler, requests, position_ids, std::nullopt);

ov::genai::EncodedResults result = finish_info.results;
m_last_disappeared_token = finish_info.probably_disappeared_token;
m_chat_generation_finish_status = finish_info.streaming_finish_status;

if (is_chat_conversation) {
// force removal of the last answer from the kv cache
if (config.is_beam_search() && m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_kv_history_manager.trusted_history_length = m_tokenized_chat_history.size();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}

std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
if (m_chat_generation_finish_status == ov::genai::CallbacWorkStatus::CANCEL) {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;

if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_tokenized_chat_history.resize(m_tokenized_chat_history.size() - real_input_ids_size);
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += real_input_ids_size;
}
} else {
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
}
} else {
reset_kv_state();
m_last_disappeared_token = std::nullopt;
}

if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));

auto stop_time = std::chrono::steady_clock::now();

// If called without tokenization, that stat will not be reported.
@@ -354,7 +383,6 @@ EncodedResults StatefulLLMPipeline::generate(

void StatefulLLMPipeline::start_chat(const std::string& system_message) {
is_chat_conversation = true;
m_trust_encoded_history = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
@@ -387,7 +415,6 @@ void StatefulLLMPipeline::reset_kv_state() {

void StatefulLLMPipeline::finish_chat() {
is_chat_conversation = false;
m_trust_encoded_history = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
5 changes: 3 additions & 2 deletions src/cpp/src/llm_pipeline_stateful.hpp
@@ -24,8 +24,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// If the sequence contains symbols that could be ambiguously encoded by the tokenizer, we need to trim the kv cache.
// If we use beam search sampling in chat mode, we need to remove the model's last answer from the kv cache and add the best answer to the history,
// so keep info about the number of tokens to trim from the kv cache and the number of tokens to keep in the history.
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0};
size_t m_kv_cache_seq_length_axis = 2;
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0, 2};
// Finish reason of last generation for chat scenario
ov::genai::CallbacWorkStatus m_chat_generation_finish_status = ov::genai::CallbacWorkStatus::UNDEF;

void reset_kv_state();
public:
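The HistoryRemoveManager definition itself is not part of this diff; judging only from the calls and the {0, 0, 2} initializer above, its shape is roughly the following (the predicate body is a guess):

#include <cstddef>

namespace ov::genai::utils {

struct HistoryRemoveManager {
    size_t num_tokens_to_remove_from_kv_cache = 0;
    size_t trusted_history_length = 0;
    size_t kv_cache_seq_length_axis = 2;  // absorbs the pipeline's former m_kv_cache_seq_length_axis

    bool does_history_cache_need_to_update() const {
        // assumption: an update is pending once a trusted history length has been recorded (e.g. by beam search)
        return trusted_history_length > 0;
    }

    void reset() {
        num_tokens_to_remove_from_kv_cache = 0;
        trusted_history_length = 0;
    }
};

}  // namespace ov::genai::utils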
19 changes: 10 additions & 9 deletions src/cpp/src/lm_encoding.cpp
@@ -13,6 +13,7 @@
#include "debug_utils.hpp"
#include "lm_encoding.hpp"
#include "openvino/genai/perf_metrics.hpp"
#include "openvino/genai/streamer_base.hpp"


namespace ov {
@@ -50,7 +51,7 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
}


std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
ov::InferRequest& m_llm,
const ov::Tensor& input_ids,
const ov::Tensor& attention_mask,
@@ -92,8 +93,8 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(

// Initialize results and performance metrics.

EncodedResults results;
auto& raw_perf_counters = results.perf_metrics.raw_metrics;
ov::genai::utils::GenerationFinishInfo finish_info;
auto& raw_perf_counters = finish_info.results.perf_metrics.raw_metrics;
raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }};

// Initialize inputs
@@ -211,6 +212,7 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(

if (streamer_ptr) { // push streamer's cache
streamer_ptr->end();
finish_info.streaming_finish_status = streamer_ptr->get_finish_streaming_reason();
}

for (auto& sequence_group : sequence_groups) {
@@ -222,20 +224,19 @@
const auto & sequence = sequences[seq_id];
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob();

results.tokens.push_back(sequence->get_generated_ids());
results.scores.push_back(score);
finish_info.results.tokens.push_back(sequence->get_generated_ids());
finish_info.results.scores.push_back(score);
}
}

for (SequenceGroup::Ptr sequence_group : sequence_groups)
sampler.clear_request_info(sequence_group->get_request_id());

// it is not saved in KV cache, we need to add it for some cases
std::optional<int64_t> last_token_of_best_sequence = std::nullopt;
// last generated token is not saved in KV cache, we need to add it for some cases
if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_dropped())
last_token_of_best_sequence = results.tokens[0].back();
finish_info.probably_disappeared_token = finish_info.results.tokens[0].back();

return {results, last_token_of_best_sequence};
return finish_info;
}

} // namespace genai
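get_lm_encoded_results now returns ov::genai::utils::GenerationFinishInfo instead of a std::pair. Its definition is not shown in this diff, but the fields used above imply roughly this shape (a sketch; the real utils header may differ):

#include <cstdint>
#include <optional>

#include "openvino/genai/llm_pipeline.hpp"    // EncodedResults
#include "openvino/genai/streamer_base.hpp"   // CallbacWorkStatus

namespace ov::genai::utils {

struct GenerationFinishInfo {
    EncodedResults results;                                                // tokens, scores, perf metrics
    std::optional<int64_t> probably_disappeared_token = std::nullopt;      // last token missing from the KV cache
    CallbacWorkStatus streaming_finish_status = CallbacWorkStatus::UNDEF;  // how streaming ended
};

}  // namespace ov::genai::utils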