System Info
Who can help?
@kaiyux (I'm mentioning you because you were the editor of the related commit.)
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Steps to reproduce the behavior:
1. Followed steps 1 through 3 (build the TensorRT-LLM engine) in examples/qwenvl with the exact same commands; a sketch of that flow follows below.
2. Ran into the error message shown under "actual behavior".
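For reference, the flow those steps describe looks roughly like the sketch below. This is a hedged reconstruction, not the verbatim content of build_qwenVL.sh: the converter path and the directory names are assumptions (the visual-encoder step is omitted), while the trtllm-build plugin flags and values mirror the log output further down.

# Step 2 (assumed): convert the HF checkpoint to a TensorRT-LLM checkpoint.
# --dtype bfloat16 matches "Set dtype to bfloat16" in the log below.
python3 ../qwen/convert_checkpoint.py \
    --model_dir ./Qwen-VL-Chat \
    --output_dir ./tllm_checkpoint \
    --dtype bfloat16

# Step 3: build the engine. These plugin flags appear verbatim in the log
# and are what trigger the BFloat16/Half mismatch reported below.
trtllm-build \
    --checkpoint_dir ./tllm_checkpoint \
    --output_dir ./trt_engines/qwenvl \
    --gemm_plugin float16 \
    --gpt_attention_plugin float16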
Expected behavior
The build completes successfully.
actual behavior
Running ./build_qwenVL.sh (the exact same script as in examples/qwenvl) produces the output below:
(venv) admin@inference_l4x4_admin:~/tensorrt$ ./build_qwenVL.sh --verbose
Authorization required, but no authorization protocol specified
Authorization required, but no authorization protocol specified
Authorization required, but no authorization protocol specified
[TensorRT-LLM] TensorRT-LLM version: 0.16.0.dev2024111900
[11/22/2024-06:24:54] [TRT-LLM] [I] Set bert_attention_plugin to auto.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set gpt_attention_plugin to float16.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set gemm_plugin to float16.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set gemm_swiglu_plugin to None.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set fp8_rowwise_gemm_plugin to None.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set nccl_plugin to auto.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set lora_plugin to None.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set moe_plugin to auto.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set mamba_conv1d_plugin to auto.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set low_latency_gemm_plugin to None.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set low_latency_gemm_swiglu_plugin to None.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set context_fmha to True.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set bert_context_fmha_fp32_acc to False.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set remove_input_padding to True.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set reduce_fusion to False.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set enable_xqa to True.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set tokens_per_block to 64.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set use_paged_context_fmha to False.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set use_fp8_context_fmha to False.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set multiple_profiles to False.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set paged_state to True.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set streamingllm to False.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set use_fused_mlp to True.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set pp_reduce_scatter to False.
[11/22/2024-06:24:54] [TRT-LLM] [W] Implicitly setting QWenConfig.qwen_type = qwen
[11/22/2024-06:24:54] [TRT-LLM] [W] Implicitly setting QWenConfig.moe_intermediate_size = 0
[11/22/2024-06:24:54] [TRT-LLM] [W] Implicitly setting QWenConfig.moe_shared_expert_intermediate_size = 0
[11/22/2024-06:24:54] [TRT-LLM] [W] Implicitly setting QWenConfig.tie_word_embeddings = False
[11/22/2024-06:24:54] [TRT-LLM] [I] Set dtype to bfloat16.
[11/22/2024-06:24:54] [TRT-LLM] [I] Set paged_kv_cache to True.
[11/22/2024-06:24:54] [TRT-LLM] [W] Overriding paged_state to False
[11/22/2024-06:24:54] [TRT-LLM] [I] Set paged_state to False.
[11/22/2024-06:24:54] [TRT-LLM] [W] remove_input_padding is enabled, while opt_num_tokens is not set, setting to max_batch_size*max_beam_width.
[11/22/2024-06:24:54] [TRT-LLM] [W] padding removal and fMHA are both enabled, max_input_len is not required and will be ignored
[11/22/2024-06:24:57] [TRT] [I] [MemUsageChange] Init CUDA: CPU +14, GPU +0, now: CPU 5665, GPU 322 (MiB)
[11/22/2024-06:25:00] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +2276, GPU +440, now: CPU 8097, GPU 762 (MiB)
[11/22/2024-06:25:00] [TRT-LLM] [I] Set nccl_plugin to None.
[11/22/2024-06:25:00] [TRT] [W] IElementWiseLayer with inputs QWenForCausalLM/transformer/vocab_embedding/where_L2963/SELECT_2_output_0 and QWenForCausalLM/transformer/layers/0/attention/dense/multiply_collect_L272/multiply_and_lora_L238/_gemm_plugin_L129/PLUGIN_V2_Gemm_0_output_0: first input has type BFloat16 but second input has type Half.
[11/22/2024-06:25:01] [TRT] [E] ITensor::getDimensions: Error Code 4: API Usage Error (QWenForCausalLM/transformer/layers/0/__add___L322/elementwise_binary_L2877/ELEMENTWISE_SUM_0: ElementWiseOperation SUM must have same input types. But they are of types BFloat16 and Half.)
Traceback (most recent call last):
  File "/home/admin/tensorrt/venv/bin/trtllm-build", line 8, in <module>
    sys.exit(main())
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 627, in main
    parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 425, in parallel_build
    passed = build_and_save(rank, rank % workers, ckpt_dir,
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 390, in build_and_save
    engine = build_model(build_config,
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/commands/build.py", line 383, in build_model
    return build(model, build_config)
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/builder.py", line 1216, in build
    model(**inputs)
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/module.py", line 52, in __call__
    output = self.forward(*args, **kwargs)
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/models/modeling_utils.py", line 962, in forward
    hidden_states = self.transformer.forward(**kwargs)
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/models/qwen/model.py", line 193, in forward
    hidden_states = self.layers.forward(
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/models/modeling_utils.py", line 550, in forward
    hidden_states = layer(
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/module.py", line 52, in __call__
    output = self.forward(*args, **kwargs)
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/models/qwen/model.py", line 137, in forward
    hidden_states = residual + attention_output
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 322, in __add__
    return add(self, b)
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 2877, in elementwise_binary
    return _create_tensor(layer.get_output(0), layer)
  File "/home/admin/tensorrt/venv/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 608, in _create_tensor
    assert trt_tensor.shape.__len__(
AssertionError: tensor QWenForCausalLM/transformer/layers/0/__add___L322/elementwise_binary_L2877/ELEMENTWISE_SUM_0_output_0 has an invalid shape
additional notes
No clue.
Setting both --gemm_plugin and --gpt_attention_plugin to auto resolved the problem, but I still want to know why the data types weren't matching. I've also checked that the compiled model works fine afterwards in TensorRT.
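A plausible reading of the log, offered as an inference rather than a confirmed diagnosis: the checkpoint was converted with dtype bfloat16 ("Set dtype to bfloat16"), but --gemm_plugin and --gpt_attention_plugin were pinned to float16, so the GEMM plugin emits Half tensors that then meet the BFloat16 residual at the first layer's elementwise SUM, which TensorRT rejects because IElementWiseLayer requires matching input types. With auto, both plugins inherit the model dtype and the whole graph stays in bfloat16. A minimal sketch of the corrected build command (directory paths are illustrative; the plugin flags are the real trtllm-build options from the log):

# Rebuild with the plugins set to auto so they follow the checkpoint dtype.
# --checkpoint_dir / --output_dir are placeholder paths.
trtllm-build \
    --checkpoint_dir ./tllm_checkpoint \
    --output_dir ./trt_engines/qwenvl \
    --gemm_plugin auto \
    --gpt_attention_plugin auto

The equivalent alternative, if float16 plugins are actually desired, would be to convert the checkpoint with --dtype float16 so that every tensor in the graph shares a single type.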