Skip to content

Commit

Permalink
Whisper: address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eshiryae committed Jan 7, 2025
1 parent dfba871 commit 7e8bbfe
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions src/cpp/src/whisper_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
// attention_mask [1, 1, 1, 0]
auto input_ids_data = input_ids_tensor.data<int32_t>();
std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
// std::fill(input_ids_data + init_ids.size(),
// input_ids_data + input_ids_tensor.get_size(),
// static_cast<int32_t>(pad_token));

auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
std::fill_n(attention_mask_data, init_ids.size(), 1u);
Expand Down Expand Up @@ -489,7 +486,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
preprocessor.input("attention_mask").preprocess().convert_element_type();
} else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type();
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);
} else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
Expand Down

0 comments on commit 7e8bbfe

Please sign in to comment.