
Medusa performance degrades with batch size larger than 1 #2482

Open
SoundProvider opened this issue Nov 22, 2024 · 2 comments
Labels: Performance (Issue about performance number)

Comments


SoundProvider commented Nov 22, 2024

I'm trying to use Medusa with TensorRT-LLM, referencing this page.

It works fine with Vicuna 7B and its Medusa heads, as shown on the example page.

The example notes: "Increasing the batch size may have a negative impact on performance."
My understanding is that, as the batch size increases, each sequence has to wait for the other sequences in the batch to reach its position, which degrades performance.
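The lockstep effect described above can be illustrated with a toy model. It is a sketch under made-up assumptions, not TensorRT-LLM code: every sequence in the batch advances one decoding step at a time together, and each step accepts a uniformly random 1 to 5 tokens per sequence. The functions `steps_to_finish`, `lockstep_batch_steps`, and `random_accepts` are hypothetical names.

```python
import random


def steps_to_finish(accepted_per_step, target_tokens):
    """Decoding steps one sequence needs to emit target_tokens, given how many
    tokens it accepts at each step (always >= 1, since the base model's own
    token is kept even when every draft token is rejected)."""
    produced = 0
    for step, accepted in enumerate(accepted_per_step, start=1):
        produced += accepted
        if produced >= target_tokens:
            return step
    raise ValueError("acceptance sequence too short to reach target_tokens")


def lockstep_batch_steps(per_sequence_accepts, target_tokens):
    """With lockstep batching the batch keeps stepping until its slowest
    sequence has finished."""
    return max(steps_to_finish(a, target_tokens) for a in per_sequence_accepts)


def random_accepts(rng, n_steps, max_accept=5):
    # Hypothetical acceptance model: 1..max_accept tokens per step, uniformly.
    return [rng.randint(1, max_accept) for _ in range(n_steps)]


if __name__ == "__main__":
    rng = random.Random(0)
    target = 500  # hypothetical number of tokens to generate per sequence
    different = [random_accepts(rng, target) for _ in range(4)]
    identical = [different[0]] * 4
    print("single sequence:", steps_to_finish(different[0], target), "steps")
    print("batch of 4, different patterns:",
          lockstep_batch_steps(different, target), "steps")
    print("batch of 4, identical patterns:",
          lockstep_batch_steps(identical, target), "steps")
```

In this model, a batch whose sequences share the same acceptance pattern needs exactly as many steps as a single sequence. Lockstep waiting alone would therefore not explain a slowdown with identical inputs; any remaining difference would have to come from the cost of each step.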

However, when I tested Vicuna 7B with a batch size of 4, using the same input for every sequence, performance still dropped. This contradicts my understanding.

[Image: measured performance across batch sizes]
I varied the batch size using identical inputs (e.g., a batch of 4 in which every sequence has the same input).

What could be causing this? I'd really appreciate an explanation.

Thank you

hello-11 (Collaborator) commented

@SoundProvider could you describe how you evaluated performance?

hello-11 added the Performance label on Nov 22, 2024
SoundProvider (Author) commented Nov 22, 2024

@hello-11 Hello.
I used the run script from the Medusa example folder:

```bash
python /app/tensorrt_llm/examples/run.py --engine_dir /app/models/medusa_test_3b/tensorrt_llm/4-gpu \
    --tokenizer_dir /app/models/vicuna-33b-v1.3 \
    --max_output_len=500 \
    --medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]" \
    --temperature 1.0 \
    --input_text "Once upon" \
    --run_profiling
```
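The `--medusa_choices` value above defines the tree of candidate continuations that Medusa verifies at each decoding step. The sketch below parses that tree and estimates how many tokens go through the base model per step. The count of "one root token plus one token per tree node, for every sequence in the batch" is an assumption about how the engine lays out the tree, not something stated in this thread. `tree_stats` and `tokens_per_step` are hypothetical helpers.

```python
import ast

# The --medusa_choices tree from the command above, copied verbatim.
MEDUSA_CHOICES = (
    "[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], "
    "[4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], "
    "[1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], "
    "[1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], "
    "[0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], "
    "[0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], "
    "[1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]"
)


def tree_stats(choices_str):
    """Return (number of tree nodes, maximum depth) of a Medusa choice tree.
    Each inner list is a path of top-k indices, one index per Medusa head."""
    choices = ast.literal_eval(choices_str)
    return len(choices), max(len(path) for path in choices)


def tokens_per_step(choices_str, batch_size):
    """Estimated tokens processed per decoding step for the whole batch,
    ASSUMING each sequence verifies every tree node plus one root token."""
    num_nodes, _ = tree_stats(choices_str)
    return batch_size * (num_nodes + 1)


if __name__ == "__main__":
    nodes, depth = tree_stats(MEDUSA_CHOICES)
    print(f"tree nodes: {nodes}, max depth: {depth}")
    for bs in (1, 2, 4, 8):
        print(f"batch {bs}: {tokens_per_step(MEDUSA_CHOICES, bs)} tokens/step")
```

For this tree the script reports 63 nodes with a maximum depth of 4. Under the stated assumption, that is 64 tokens per sequence per step and 256 for a batch of 4, so the work done in each step grows with the batch size even when every sequence has the same input.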

I simply timed the `runner.generate` call and averaged over 10 runs:

```python
# Timing loop wrapped around the runner.generate(...) call in examples/run.py.
# runner, args, batch_input_ids, end_id, pad_id, stop_words_list, etc. are
# all defined earlier in run.py.
import time

import numpy as np

times = []  # wall-clock seconds per generate() call
for _ in range(10):
    # Note: the first iteration is included in the average, so any one-time
    # warm-up cost is counted as well.
    s_time = time.perf_counter()
    outputs = runner.generate(
        batch_input_ids=decoder_input_ids
        if is_enc_dec else batch_input_ids,
        encoder_input_ids=encoder_input_ids if is_enc_dec else None,
        encoder_input_features=encoder_input_features
        if is_enc_dec else None,
        encoder_output_lengths=encoder_output_lengths
        if is_enc_dec else None,
        max_new_tokens=args.max_output_len,
        max_attention_window_size=args.max_attention_window_size,
        sink_token_length=args.sink_token_length,
        end_id=end_id,
        pad_id=pad_id,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        num_beams=args.num_beams,
        num_return_sequences=args.num_return_sequences,
        length_penalty=args.length_penalty,
        early_stopping=args.early_stopping,
        repetition_penalty=args.repetition_penalty,
        presence_penalty=args.presence_penalty,
        frequency_penalty=args.frequency_penalty,
        stop_words_list=stop_words_list,
        bad_words_list=bad_words_list,
        output_cum_log_probs=(args.output_cum_log_probs_npy is not None),
        output_log_probs=(args.output_log_probs_npy is not None),
        random_seed=args.random_seed,
        lora_uids=args.lora_task_uids,
        prompt_table=args.prompt_table_path,
        prompt_tasks=args.prompt_tasks,
        streaming=args.streaming,
        output_sequence_lengths=True,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        return_dict=True,
        medusa_choices=args.medusa_choices,
        return_all_generated_tokens=args.return_all_generated_tokens,
        input_token_extra_ids=input_token_extra_ids)
    e_time = time.perf_counter()
    times.append(e_time - s_time)

print("[TIME] :: ", np.average(times))
```
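A somewhat more controlled version of the measurement above is sketched below. It is not part of run.py. `benchmark` and `tokens_per_second` are hypothetical helpers that do three things: run a few untimed warm-up calls, optionally synchronize the GPU around each timed call, and report throughput for the whole batch rather than only per-call latency.

```python
import statistics
import time
from typing import Callable, Dict, Optional, Sequence


def benchmark(fn: Callable[[], object], iters: int = 10, warmup: int = 2,
              sync: Optional[Callable[[], None]] = None) -> Dict[str, float]:
    """Time fn() over `iters` runs after `warmup` untimed calls.

    If `sync` is given (e.g. torch.cuda.synchronize), it is called right
    before starting and right after stopping the clock, so queued GPU work
    is included in each measurement.
    """
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iters):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return {
        "mean_s": statistics.mean(times),
        "median_s": statistics.median(times),
        "min_s": min(times),
    }


def tokens_per_second(new_tokens_per_sequence: Sequence[int],
                      seconds: float) -> float:
    """Aggregate throughput: tokens generated by the whole batch per second."""
    return sum(new_tokens_per_sequence) / seconds
```

In run.py this could be called as `benchmark(lambda: runner.generate(...), sync=torch.cuda.synchronize)`. Comparing `tokens_per_second` across batch sizes, not only the mean latency per call, separates a drop in per-request latency from a drop in overall throughput. For example, if a batch of 4 generates the same number of tokens per sequence as a batch of 1 and takes less than 4 times as long, it still delivers more tokens per second.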
