Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Qwen2Audio] handle input ids expansion during processing #35534

Merged
merged 12 commits into from
Jan 7, 2025
Prev Previous commit
Next Next commit
expand input_ids
eustlb committed Jan 6, 2025
commit 4de1294bb0ce98e428aa30ad8a4cf92d6fb4bebc
23 changes: 22 additions & 1 deletion src/transformers/models/qwen2_audio/processing_qwen2_audio.py
Original file line number Diff line number Diff line change
@@ -91,7 +91,8 @@ def __call__(

if text is None:
raise ValueError("You need to specify either a `text` input to process.")
inputs = self.tokenizer(text, padding=padding, **kwargs)
elif isinstance(text, str):
text = [text]

if audios is not None:
audio_inputs = self.feature_extractor(
@@ -100,6 +101,26 @@ def __call__(
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask"
) # rename attention_mask to prevent conflicts later on

expanded_text = []
eustlb marked this conversation as resolved.
Show resolved Hide resolved
audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
for sample in text:
replace_str = []
while self.audio_token in sample:
audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1
num_audio_tokens = (input_length - 2) // 2 + 1
replace_str.append(self.audio_token * num_audio_tokens)
sample = sample.replace(self.audio_token, "<placeholder>", 1)

while "<placeholder>" in sample:
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
expanded_text.append(sample)
text = expanded_text
eustlb marked this conversation as resolved.
Show resolved Hide resolved

inputs = self.tokenizer(text, padding=padding, **kwargs)

if audios is not None:
inputs.update(audio_inputs)

return BatchFeature(data={**inputs})