Assertion failed: Must set crossKvCacheFraction for encoder-decoder model #2419

Open
Saeedmatt3r opened this issue Nov 6, 2024 · 2 comments
Labels: bug (Something isn't working), triaged (Issue has been triaged by maintainers)

Comments


Saeedmatt3r commented Nov 6, 2024

System Info

GPU: A10
Base Image: FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
TensorRT-LLM versions tested:

  • 0.12.0: works, but I can't use it because of a version mismatch between TensorRT and trt-llm-backend
  • 0.13.0: works, but I can't use it because of a version mismatch between TensorRT and trt-llm-backend
  • 0.14.0: not working: Assertion failed: Must set crossKvCacheFraction for encoder-decoder model
  • 0.15.0.dev2024110500: not working

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Steps to reproduce the problem:

FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

RUN apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev

RUN pip3 install tensorrt_llm==0.14.0 -U --pre --extra-index-url https://pypi.nvidia.com

Then run the official Whisper example:

INFERENCE_PRECISION=float16
WEIGHT_ONLY_PRECISION=int8
MAX_BEAM_WIDTH=4
MAX_BATCH_SIZE=8
checkpoint_dir=weights/whisper_large_v3_weights_${WEIGHT_ONLY_PRECISION}
output_dir=weights/whisper_large_v3_${WEIGHT_ONLY_PRECISION}

# Convert the large-v3 model weights into TensorRT-LLM format.
python3 convert_checkpoint.py \
                --use_weight_only \
                --weight_only_precision $WEIGHT_ONLY_PRECISION \
                --output_dir $checkpoint_dir

# Build the large-v3 model using trtllm-build
trtllm-build  --checkpoint_dir ${checkpoint_dir}/encoder \
              --output_dir ${output_dir}/encoder \
              --moe_plugin disable \
              --enable_xqa disable \
              --max_batch_size ${MAX_BATCH_SIZE} \
              --gemm_plugin disable \
              --bert_attention_plugin ${INFERENCE_PRECISION} \
              --max_input_len 3000 --max_seq_len=3000

trtllm-build  --checkpoint_dir ${checkpoint_dir}/decoder \
              --output_dir ${output_dir}/decoder \
              --moe_plugin disable \
              --enable_xqa disable \
              --max_beam_width ${MAX_BEAM_WIDTH} \
              --max_batch_size ${MAX_BATCH_SIZE} \
              --max_seq_len 114 \
              --max_input_len 14 \
              --max_encoder_input_len 3000 \
              --gemm_plugin ${INFERENCE_PRECISION} \
              --bert_attention_plugin ${INFERENCE_PRECISION} \
              --gpt_attention_plugin ${INFERENCE_PRECISION}

python3 run.py --engine_dir "$output_dir" --dataset hf-internal-testing/librispeech_asr_dummy --num_beams $MAX_BEAM_WIDTH --batch_size $MAX_BATCH_SIZE --enable_warmup --name librispeech_dummy_large_v3

Expected behavior

It should run on the dataset without any problem.

Actual behavior

[TensorRT-LLM] TensorRT-LLM version: 0.14.0
[TensorRT-LLM][INFO] Engine version 0.14.0 found in the config file, assuming engine(s) built by new builder API.
[TensorRT-LLM][INFO] Setting encoder max input length and hidden size for accepting visual features.
[TensorRT-LLM][INFO] Engine version 0.14.0 found in the config file, assuming engine(s) built by new builder API.
[TensorRT-LLM][INFO] Engine version 0.14.0 found in the config file, assuming engine(s) built by new builder API.
[TensorRT-LLM][INFO] Setting encoder max input length and hidden size for accepting visual features.
[TensorRT-LLM][INFO] MPI size: 1, MPI local size: 1, rank: 0
[TensorRT-LLM][INFO] Engine version 0.14.0 found in the config file, assuming engine(s) built by new builder API.
[TensorRT-LLM][INFO] Setting encoder max input length and hidden size for accepting visual features.
[TensorRT-LLM][INFO] Refreshed the MPI local session
[TensorRT-LLM][INFO] Engine version 0.14.0 found in the config file, assuming engine(s) built by new builder API.
[TensorRT-LLM][INFO] MPI size: 1, MPI local size: 1, rank: 0
[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 8
[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 8
[TensorRT-LLM][INFO] TRTGptModel maxBeamWidth: 1
[TensorRT-LLM][INFO] TRTGptModel maxSequenceLen: 3000
[TensorRT-LLM][INFO] TRTGptModel maxDraftLen: 0
[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: (3000) * 32
[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0
[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 1
[TensorRT-LLM][INFO] TRTGptModel maxNumTokens: 8192
[TensorRT-LLM][INFO] TRTGptModel maxInputLen: 2999 = min(maxSequenceLen - 1, maxNumTokens) since context FMHA and usePackedInput are enabled
[TensorRT-LLM][INFO] TRTGptModel If model type is encoder, maxInputLen would be reset in trtEncoderModel to maxInputLen: min(maxSequenceLen, maxNumTokens).
[TensorRT-LLM][INFO] Capacity Scheduler Policy: GUARANTEED_NO_EVICT
[TensorRT-LLM][INFO] Context Chunking Scheduler Policy: None
[TensorRT-LLM][INFO] Loaded engine size: 622 MiB
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 292.97 MiB for execution context memory.
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 615 (MiB)
[TensorRT-LLM][INFO] TRTEncoderModel mMaxInputLen: reset to 3000 from build config.
[TensorRT-LLM][INFO] MPI size: 1, MPI local size: 1, rank: 0
[TensorRT-LLM][INFO] Rank 0 is using GPU 0
[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 8
[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 8
[TensorRT-LLM][INFO] TRTGptModel maxBeamWidth: 4
[TensorRT-LLM][INFO] TRTGptModel maxSequenceLen: 114
[TensorRT-LLM][INFO] TRTGptModel maxDraftLen: 0
[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: (114) * 32
[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0
[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 1
[TensorRT-LLM][INFO] TRTGptModel maxNumTokens: 912
[TensorRT-LLM][INFO] TRTGptModel maxInputLen: 113 = min(maxSequenceLen - 1, maxNumTokens) since context FMHA and usePackedInput are enabled
[TensorRT-LLM][INFO] TRTGptModel If model type is encoder, maxInputLen would be reset in trtEncoderModel to maxInputLen: min(maxSequenceLen, maxNumTokens).
[TensorRT-LLM][INFO] Capacity Scheduler Policy: GUARANTEED_NO_EVICT
[TensorRT-LLM][INFO] Context Chunking Scheduler Policy: None
[TensorRT-LLM][INFO] The logger passed into createInferRuntime differs from one already provided for an existing builder, runtime, or refitter. Uses of the global logger, returned by nvinfer1::getLogger(), will return the existing value.
[TensorRT-LLM][INFO] Loaded engine size: 1066 MiB
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 833.73 MiB for execution context memory.
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 1672 (MiB)
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 3.54 MB GPU memory for runtime buffers.
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 16.82 MB GPU memory for decoder.
[TensorRT-LLM][INFO] Memory usage when calculating max tokens in paged kv cache: total: 21.98 GiB, available: 18.88 GiB
[TensorRT-LLM][INFO] Number of blocks in KV cache primary pool: 1740
[TensorRT-LLM][INFO] Number of blocks in KV cache secondary pool: 0, onboard blocks to primary memory before reuse: true
Traceback (most recent call last):
  File "/app/TensorRT-LLM/examples/whisper/run.py", line 479, in <module>
    model = WhisperTRTLLM(args.engine_dir, args.debug, args.assets_dir,
  File "/app/TensorRT-LLM/examples/whisper/run.py", line 327, in __init__
    self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/model_runner_cpp.py", line 206, in from_dir
    executor = trtllm.Executor(
RuntimeError: [TensorRT-LLM][ERROR] Assertion failed: Must set crossKvCacheFraction for encoder-decoder model (/home/jenkins/agent/workspace/LLM/release-0.14/L0_PostMerge/llm/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp:207)
1       0x7493ec8cccc7 tensorrt_llm::common::throwRuntimeError(char const*, int, std::string const&) + 82
2       0x7493ec914e4a /usr/local/lib/python3.10/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x769e4a) [0x7493ec914e4a]
3       0x7493eea3bac4 tensorrt_llm::batch_manager::TrtGptModelFactory::create(tensorrt_llm::runtime::RawEngine const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&, tensorrt_llm::batch_manager::TrtGptModelType, tensorrt_llm::batch_manager::TrtGptModelOptionalParams const&) + 756
4       0x7493eeac1d7d tensorrt_llm::executor::Executor::Impl::createModel(tensorrt_llm::runtime::RawEngine const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&, tensorrt_llm::executor::ExecutorConfig const&) + 125
5       0x7493eeac233a tensorrt_llm::executor::Executor::Impl::loadModel(std::optional<std::filesystem::path> const&, std::optional<std::basic_string_view<unsigned char, std::char_traits<unsigned char> > > const&, tensorrt_llm::runtime::GptJsonConfig const&, tensorrt_llm::executor::ExecutorConfig const&, bool, std::optional<std::map<std::string, tensorrt_llm::executor::Tensor, std::less<std::string>, std::allocator<std::pair<std::string const, tensorrt_llm::executor::Tensor> > > > const&) + 954
6       0x7493eeac3377 tensorrt_llm::executor::Executor::Impl::Impl(std::filesystem::path const&, std::optional<std::filesystem::path> const&, tensorrt_llm::executor::ModelType, tensorrt_llm::executor::ExecutorConfig const&) + 2135
7       0x7493eeaaeb33 tensorrt_llm::executor::Executor::Executor(std::filesystem::path const&, std::filesystem::path const&, tensorrt_llm::executor::ModelType, tensorrt_llm::executor::ExecutorConfig const&) + 99
8       0x749518615609 /usr/local/lib/python3.10/dist-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0xdd609) [0x749518615609]
9       0x749518598a8f /usr/local/lib/python3.10/dist-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x60a8f) [0x749518598a8f]
10      0x61696fec6b2e python3(+0x15cb2e) [0x61696fec6b2e]
11      0x61696febd2db _PyObject_MakeTpCall + 603
12      0x61696fed56b0 python3(+0x16b6b0) [0x61696fed56b0]
13      0x61696fed1ad7 python3(+0x167ad7) [0x61696fed1ad7]
14      0x61696febd68b python3(+0x15368b) [0x61696febd68b]
15      0x7495185928bb /usr/local/lib/python3.10/dist-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x5a8bb) [0x7495185928bb]
16      0x61696febd2db _PyObject_MakeTpCall + 603
17      0x61696feb5d27 _PyEval_EvalFrameDefault + 27415
18      0x61696fed5281 python3(+0x16b281) [0x61696fed5281]
19      0x61696fed5f22 PyObject_Call + 290
20      0x61696feb1a6e _PyEval_EvalFrameDefault + 10334
21      0x61696febc474 _PyObject_FastCallDictTstate + 196
22      0x61696fed14b4 python3(+0x1674b4) [0x61696fed14b4]
23      0x61696febd27c _PyObject_MakeTpCall + 508
24      0x61696feb56e6 _PyEval_EvalFrameDefault + 25814
25      0x61696feac016 python3(+0x142016) [0x61696feac016]
26      0x61696ffa18b6 PyEval_EvalCode + 134
27      0x61696ffcc918 python3(+0x262918) [0x61696ffcc918]
28      0x61696ffc61db python3(+0x25c1db) [0x61696ffc61db]
29      0x61696ffcc665 python3(+0x262665) [0x61696ffcc665]
30      0x61696ffcbb48 _PyRun_SimpleFileObject + 424
31      0x61696ffcb793 _PyRun_AnyFileObject + 67
32      0x61696ffbe2ce Py_RunMain + 702
33      0x61696ff9470d Py_BytesMain + 45
34      0x7496408cfd90 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x29d90) [0x7496408cfd90]
35      0x7496408cfe40 __libc_start_main + 128
36      0x61696ff94605 _start + 37

With the latest available pip package (0.15.0.dev2024110500), the failure happens earlier, during the encoder build:

+ trtllm-build --checkpoint_dir weights/whisper_large_v3_weights_int8/encoder --output_dir weights/whisper_large_v3_int8/encoder --moe_plugin disable --enable_xqa disable --max_batch_size 8 --gemm_plugin disable --bert_attention_plugin float16 --max_input_len 3000 --max_seq_len=3000
[TensorRT-LLM] TensorRT-LLM version: 0.15.0.dev2024110500
[11/06/2024-12:28:13] [TRT-LLM] [I] Set bert_attention_plugin to float16.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set gpt_attention_plugin to auto.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set gemm_plugin to None.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set gemm_swiglu_plugin to None.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set fp8_rowwise_gemm_plugin to None.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set nccl_plugin to auto.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set lora_plugin to None.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set moe_plugin to None.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set mamba_conv1d_plugin to auto.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set low_latency_gemm_plugin to None.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set low_latency_gemm_swiglu_plugin to None.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set context_fmha to True.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set bert_context_fmha_fp32_acc to False.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set remove_input_padding to True.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set reduce_fusion to False.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set enable_xqa to False.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set tokens_per_block to 64.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set use_paged_context_fmha to False.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set use_fp8_context_fmha to False.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set multiple_profiles to False.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set paged_state to True.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set streamingllm to False.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set use_fused_mlp to True.
[11/06/2024-12:28:13] [TRT-LLM] [I] Set pp_reduce_scatter to False.
[11/06/2024-12:28:13] [TRT-LLM] [W] Implicitly setting PretrainedConfig.n_mels = 128
[11/06/2024-12:28:13] [TRT-LLM] [W] Implicitly setting PretrainedConfig.n_audio_ctx = 1500
[11/06/2024-12:28:13] [TRT-LLM] [W] Implicitly setting PretrainedConfig.num_languages = 100
Traceback (most recent call last):
  File "/usr/local/bin/trtllm-build", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 602, in main
    parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 425, in parallel_build
    passed = build_and_save(rank, rank % workers, ckpt_dir,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 390, in build_and_save
    engine = build_model(build_config,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 360, in build_model
    model = model_cls.from_checkpoint(ckpt_dir, config=rank_config)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/modeling_utils.py", line 635, in from_checkpoint
    model = cls(config)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/modeling_utils.py", line 564, in __call__
    obj = type.__call__(cls, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/enc_dec/model.py", line 1939, in __init__
    self.position_embedding = Embedding(self.config.max_position_embeddings,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/layers/embedding.py", line 63, in __init__
    shape = (math.ceil(self.num_embeddings / self.tp_size),
TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'
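The TypeError above shows that self.config.max_position_embeddings is None when the enc-dec model is instantiated, so math.ceil(None / tp_size) fails inside the Embedding layer. A purely hypothetical workaround sketch, assuming the field is simply missing from the converted checkpoint config (the key name and the value 1500, Whisper's audio context length, are assumptions and not verified against convert_checkpoint.py):

import json

# Hypothetical patch: add the missing field to the converted encoder checkpoint
# config before re-running trtllm-build. Key name and value are assumptions.
cfg_path = "weights/whisper_large_v3_weights_int8/encoder/config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg.setdefault("max_position_embeddings", 1500)  # 1500 = Whisper's n_audio_ctx

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)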

Additional notes

I also checked the Whisper example with trt-llm-backend, and that was not working either, failing with the following error:

c_python_backend_utils.TritonModelException: [TensorRT-LLM][ERROR] Assertion failed: input tokens tensor not provided (/tmp/tritonbuild/tensorrtllm/inflight_batcher_llm/src/utils.cc:107)
Saeedmatt3r added the bug label on Nov 6, 2024
hello-11 added the triaged and Investigating labels on Nov 6, 2024
yuekaizhang commented

@Saeedmatt3r Thanks for reporting the issue. The fix will be synced to GitHub next week. As a quick workaround, you need to modify this line: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/whisper/run.py#L370.

[Screenshot of the suggested change to run.py]
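For reference, a minimal sketch of what that workaround might look like, assuming ModelRunnerCpp.from_dir in 0.14+ accepts a cross_kv_cache_fraction keyword for encoder-decoder engines (the argument name and the 0.5 value are assumptions taken from the assertion message and the linked line, not a confirmed API; other keyword names are illustrative and may differ between releases):

from tensorrt_llm.runtime import ModelRunnerCpp

engine_dir = "weights/whisper_large_v3_int8"  # built by the trtllm-build steps above

runner_kwargs = dict(
    engine_dir=engine_dir,
    is_enc_dec=True,                 # Whisper is an encoder-decoder model
    max_batch_size=8,
    max_beam_width=4,
    # Fraction of the KV-cache memory reserved for the cross-attention cache;
    # omitting it triggers "Must set crossKvCacheFraction for encoder-decoder model".
    cross_kv_cache_fraction=0.5,
)
model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)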


Saeedmatt3r commented Nov 7, 2024

@yuekaizhang Thanks. To be honest, I had already done that; I just wanted to report the issues in 0.14 and 0.15. I also think the official Whisper setup in trt-llm-backend is not working as expected: I used 0.15 to build the engines and it still failed. I will open a separate ticket in that repo.
