Add a choice of how to end streaming from callback: STOP or CANCEL #1476

Open · wants to merge 1 commit into master
6 changes: 4 additions & 2 deletions samples/cpp/chat_sample/chat_sample.cpp
@@ -15,11 +15,13 @@ int main(int argc, char* argv[]) try {

ov::genai::GenerationConfig config;
config.max_new_tokens = 100;
std::function<bool(std::string)> streamer = [](std::string word) {

std::function<ov::genai::StreamerRunningStatus(std::string)> streamer = [](std::string word) {
std::cout << word << std::flush;
// Return flag corresponds whether generation should be stopped.
// false means continue generation.
return false;

// Return the streaming status: RUNNING means generation continues.
return ov::genai::StreamerRunningStatus::RUNNING;
};

pipe.start_chat();
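The sample above keeps streaming unconditionally. For illustration, the same callback can also end generation early with either of the new statuses; the sketch below is not part of this PR, and the trigger strings are invented purely to show the choice between STOP (keep the turn in chat history) and CANCEL (drop it). It is a drop-in replacement for the sample's streamer lambda:

std::function<ov::genai::StreamerRunningStatus(std::string)> streamer = [](std::string word) {
    std::cout << word << std::flush;
    if (word.find("</answer>") != std::string::npos)
        return ov::genai::StreamerRunningStatus::STOP;    // stop, keep prompt and partial answer in history
    if (word.find("[cancel]") != std::string::npos)
        return ov::genai::StreamerRunningStatus::CANCEL;  // stop, roll the whole turn back
    return ov::genai::StreamerRunningStatus::RUNNING;     // keep generating
};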
6 changes: 2 additions & 4 deletions samples/python/chat_sample/chat_sample.py
@@ -6,12 +6,10 @@
import openvino_genai


def streamer(subword):
def streamer(subword) -> openvino_genai.StreamerRunningStatus:
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
return False

# Return the streaming status: RUNNING means generation continues.
return openvino_genai.StreamerRunningStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
4 changes: 4 additions & 0 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -7,6 +7,7 @@
#include <unordered_map>

#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/visibility.hpp"

namespace ov::genai {
@@ -30,6 +31,9 @@ struct EncodedGenerationResult {

// Status of generation
GenerationStatus m_status = GenerationStatus::RUNNING;

// Status of streaming
StreamerRunningStatus m_streaming_status = ov::genai::StreamerRunningStatus::UNDEF;
};

enum class GenerationFinishReason {
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -19,7 +19,7 @@ namespace ov {
namespace genai {

// The callback's return value indicates whether generation should be stopped: false (or StreamerRunningStatus::RUNNING) means continue; true, STOP, or CANCEL ends generation.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<StreamerRunningStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
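With the extra alternative in StreamerVariant, a status-returning lambda should be accepted anywhere a streamer is passed today. A minimal sketch, assuming the existing LLMPipeline::generate(prompt, config, streamer) overload and a hypothetical model directory:

#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>

int main() {
    ov::genai::LLMPipeline pipe("./model_dir", "CPU");  // hypothetical model path
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    // The lambda converts to the StreamerRunningStatus-returning alternative of StreamerVariant.
    pipe.generate("What is OpenVINO?", config, [](std::string word) {
        std::cout << word << std::flush;
        return ov::genai::StreamerRunningStatus::RUNNING;
    });
    std::cout << std::endl;
}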
16 changes: 16 additions & 0 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -4,16 +4,28 @@
#pragma once

#include "openvino/genai/tokenizer.hpp"
#include <variant>

namespace ov {
namespace genai {

enum class StreamerRunningStatus {
UNDEF = 0, // Streaming has not run
RUNNING = 1, // Continue running inference
STOP = 2, // Stop generation, keep history as is; the KV cache includes the last request and the generated tokens
CANCEL = 3 // Stop generation, drop the last prompt and all generated tokens from history; the KV cache keeps the history up to, but not including, the last step
};

using CallbackTypeVariant = std::variant<bool, StreamerRunningStatus>;

/**
* @brief Base class for streamers. To use it, inherit from this class and implement the put() and end() methods
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS StreamerBase {
protected:
StreamerRunningStatus streaming_finish_status = StreamerRunningStatus::UNDEF;
public:
/// @brief put is called every time a new token is decoded
/// @return bool flag indicating whether generation should be stopped; returning true stops generation
@@ -22,6 +34,10 @@ class OPENVINO_GENAI_EXPORTS StreamerBase {
/// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
virtual void end() = 0;

virtual StreamerRunningStatus get_finish_streaming_reason() {
return streaming_finish_status;
}

virtual ~StreamerBase();
};

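For custom streamers, the protected streaming_finish_status member looks like the intended way for a subclass to report why streaming ended: set it before returning true from put(), and the pipeline can query it through get_finish_streaming_reason(). A rough sketch under that assumption (the token budget is an invented trigger):

#include "openvino/genai/streamer_base.hpp"
#include <iostream>
#include <vector>

class BudgetStreamer : public ov::genai::StreamerBase {
    ov::genai::Tokenizer m_tokenizer;
    size_t m_tokens_seen = 0;
    size_t m_budget;
public:
    BudgetStreamer(const ov::genai::Tokenizer& tokenizer, size_t budget)
        : m_tokenizer(tokenizer), m_budget(budget) {}

    bool put(int64_t token) override {
        std::cout << m_tokenizer.decode(std::vector<int64_t>{token}) << std::flush;
        if (++m_tokens_seen >= m_budget) {
            // Ask the pipeline to cancel: the current prompt/answer should then be
            // dropped from chat history and trimmed from the KV cache.
            streaming_finish_status = ov::genai::StreamerRunningStatus::CANCEL;
            return true;  // true still means "stop streaming"
        }
        return false;
    }

    void end() override { std::cout << std::endl; }
};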
6 changes: 6 additions & 0 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -274,6 +274,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
},
[this](const std::function<ov::genai::StreamerRunningStatus(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);

@@ -346,6 +349,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
result.m_generation_ids.resize(num_outputs);
result.m_scores.resize(num_outputs);

if (streamer_ptr)
result.m_streaming_status = streamer_ptr->get_finish_streaming_reason();

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob();
6 changes: 5 additions & 1 deletion src/cpp/src/icontinuous_batching.cpp
@@ -64,7 +64,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
generated.reserve(res.m_generation_ids.size());
for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
if (m_is_chat_conversation && 0 == idx) {
if (m_is_chat_conversation && 0 == idx && res.m_streaming_status != ov::genai::StreamerRunningStatus::CANCEL) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
@@ -77,6 +77,10 @@
});
}

// if streaming was canceled, the prompt/answer of the current step shouldn't be kept in the history, so let's remove the prompt from the history
if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_streaming_status == ov::genai::StreamerRunningStatus::CANCEL)
m_history.pop_back();

return decoded;
}
}
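In a chat scenario, the practical effect of CANCEL should be that the canceled turn leaves no trace: neither the user prompt nor the partial answer stays in m_history, and the next request continues from the previous state. A hedged end-to-end sketch (the prompts and the unconditional cancel are invented for illustration):

#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>

int main() {
    ov::genai::LLMPipeline pipe("./model_dir", "CPU");  // hypothetical model path
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    pipe.start_chat();

    // First turn: cancel as soon as the first token arrives.
    pipe.generate("Tell me a very long story", config, [](std::string word) {
        std::cout << word << std::flush;
        return ov::genai::StreamerRunningStatus::CANCEL;  // this turn is rolled back
    });

    // Second turn: the history the model sees here should not contain the story
    // prompt or the partial story, only this question.
    pipe.generate("What is 2 + 2?", config, [](std::string word) {
        std::cout << word << std::flush;
        return ov::genai::StreamerRunningStatus::RUNNING;
    });

    pipe.finish_chat();
}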
2 changes: 2 additions & 0 deletions src/cpp/src/llm_pipeline.cpp
@@ -47,6 +47,8 @@ std::pair<ov::AnyMap, ov::genai::static_llm::ModelConfigDesc> split_model_descr(
std::pair<std::string, Any> streamer(StreamerVariant func) {
if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
} else if (auto streamer_obj = std::get_if<std::function<StreamerRunningStatus(std::string)>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::function<StreamerRunningStatus(std::string)>>(*streamer_obj)};
} else {
auto callback = std::get<std::function<bool(std::string)>>(func);
return {utils::STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
74 changes: 49 additions & 25 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -39,7 +39,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const ov::genai::GenerationConfig& generation_config)
: LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
utils::apply_slice_before_matmul_transformation(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);
m_kv_history_manager.kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

ov::CompiledModel compiled_model;
if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
@@ -86,6 +86,9 @@ DecodedResults StatefulLLMPipeline::generate(

TokenizedInputs encoded_input;

std::string prev_templated_chat_history(m_templated_chat_history);
std::vector<int64_t> prev_tokenized_chat_history(m_tokenized_chat_history);

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
encoded_input = m_tokenizer.encode(*input_vector);
@@ -104,7 +107,7 @@

m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
@@ -116,21 +119,24 @@
if (!m_tokenized_chat_history.empty()) {
std::set<int64_t> stop_tokens = config.stop_token_ids;
trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = trusted_history_length == SIZE_MAX;
}

if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) {
// does_kv_cache_need_to_update will be true here if beam search is activated
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_history_cache_need_to_update()) {
// does_history_cache_need_to_update will be true here if beam search is activated
// in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly
// if there is a difference between the model answer and the decoded answer, it will anyway be less than the entire history, so let's use data from m_kv_history_manager
if (m_kv_history_manager.does_kv_cache_need_to_update()) {
if (m_kv_history_manager.does_history_cache_need_to_update()) {
trusted_history_length = m_kv_history_manager.trusted_history_length;
} else {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
size_t num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
// if the previous generation finished because max length was reached, the kv cache is missing the last token, so let's keep it
m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;
num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;

// if streaming was used and canceled on the previous step, num_tokens_to_remove_from_kv_cache could already be set, and it will be bigger as it includes both the answer and the prompt
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = num_tokens_to_remove_from_kv_cache > m_kv_history_manager.num_tokens_to_remove_from_kv_cache ?
num_tokens_to_remove_from_kv_cache : m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
}

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
@@ -169,11 +175,19 @@
auto decode_stop_time = std::chrono::steady_clock::now();

if (is_chat_conversation) {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
auto answer = decoded_results.texts[0];
m_templated_chat_history.append(answer);
m_history.push_back({{"role", "assistant"}, {"content", answer}});
if (m_chat_generation_finish_status == ov::genai::StreamerRunningStatus::CANCEL) {
// If the chat generation process was canceled by the user, roll back to the previous state of the history
m_history.pop_back();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += m_tokenized_chat_history.size() - prev_tokenized_chat_history.size();
m_templated_chat_history = prev_templated_chat_history;
m_tokenized_chat_history = prev_tokenized_chat_history;
} else {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
auto answer = decoded_results.texts[0];
m_templated_chat_history.append(answer);
m_history.push_back({{"role", "assistant"}, {"content", answer}});
}
}

// generate_durations
@@ -218,6 +232,8 @@ EncodedResults StatefulLLMPipeline::generate(
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));

size_t real_input_ids_size = input_ids.get_shape().at(1);

// Tail of previous output in chat mode is missing in KV cache.
if (m_last_disappeared_token.has_value()) {
attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1);
@@ -239,7 +255,9 @@
streamer_ptr = nullptr;
} else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
} else if (auto callback = std::get_if<std::function<ov::genai::StreamerRunningStatus(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

@@ -254,7 +272,8 @@
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller);
ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache,
m_kv_history_manager.kv_cache_seq_length_axis, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
@@ -292,8 +311,7 @@
m_adapter_controller->apply(m_model_runner, config.adapters);
}

if (is_chat_conversation && !m_trust_encoded_history) {
m_trust_encoded_history = true;
if (is_chat_conversation) {
m_kv_history_manager.reset();
}

@@ -321,9 +339,11 @@
m_sampler.set_seed(config.rng_seed);
}

ov::genai::EncodedResults result;
std::tie(result, m_last_disappeared_token) = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
ov::genai::utils::GenerationFinishInfo finish_info = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
streamer_ptr, m_sampler, requests, position_ids, std::nullopt);
ov::genai::EncodedResults result = finish_info.results;
m_last_disappeared_token = finish_info.probably_disappeared_token;
m_chat_generation_finish_status = finish_info.streaming_finish_status;

if (is_chat_conversation) {
// force remove from kv_cache last answer
@@ -332,15 +352,21 @@
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}

std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
if (m_chat_generation_finish_status == ov::genai::StreamerRunningStatus::CANCEL) {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;

if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_tokenized_chat_history.resize(m_tokenized_chat_history.size() - real_input_ids_size);
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += real_input_ids_size;
}
} else {
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
}
} else {
reset_kv_state();
m_last_disappeared_token = std::nullopt;
}

if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));

auto stop_time = std::chrono::steady_clock::now();

// If is called without tokenization then that stat will not be reported.
@@ -354,7 +380,6 @@

void StatefulLLMPipeline::start_chat(const std::string& system_message) {
is_chat_conversation = true;
m_trust_encoded_history = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
@@ -387,7 +412,6 @@ void StatefulLLMPipeline::reset_kv_state() {

void StatefulLLMPipeline::finish_chat() {
is_chat_conversation = false;
m_trust_encoded_history = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
5 changes: 3 additions & 2 deletions src/cpp/src/llm_pipeline_stateful.hpp
@@ -24,8 +24,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache
// If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
// so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0};
size_t m_kv_cache_seq_length_axis = 2;
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0, 2};
// Finish reason of last generation for chat scenario
ov::genai::StreamerRunningStatus m_chat_generation_finish_status = ov::genai::StreamerRunningStatus::UNDEF;

void reset_kv_state();
public: