
llama 3.1 70B Instruct would not build engine "TypeError: set_shape(): incompatible function arguments." #2018

Closed
2 of 4 tasks
christian-ci opened this issue Jul 24, 2024 · 7 comments
Labels
bug, functionality issue

Comments


System Info

  • x86_64
  • GPU Mem: 640GB
  • CPU Mem: 1.5TB
  • 8 * NVIDIA H100
  • TensorRT-LLM Version: 0.12.0.dev2024072301
  • TensorRT-LLM Commit 5fa9436e17c2f9aeace070f49aa645d2577f676b
  • TensorRT-LLM Backend Commit: a6aa8eb6ce9371521df166c480e10262cd9c0cf4
  • NVIDIA Driver Version (Node): 535.183.01
  • CUDA Version (Node): 12.2
  • NVIDIA Driver Version (tensorrt-llm-backend-image): 535.183.01
  • CUDA Version (tensorrt-llm-backend-image): 12.4
  • OS: Ubuntu 22.04

Who can help?

@byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Any of the scripts below will yield the same error:

# Build LLaMA 70B using 8-way tensor parallelism.
python convert_checkpoint.py --model_dir ./tmp/llama/70B/hf/ \
                            --output_dir ./tllm_checkpoint_8gpu_tp8 \
                            --dtype float16 \
                            --tp_size 8

trtllm-build --checkpoint_dir ./tllm_checkpoint_8gpu_tp8 \
            --output_dir ./tmp/llama/70B/trt_engines/fp16/8-gpu/ \
            --gemm_plugin auto

# Build LLaMA 70B using 4-way tensor parallelism and 2-way pipeline parallelism.
python convert_checkpoint.py --model_dir ./tmp/llama/70B/hf/ \
                            --output_dir ./tllm_checkpoint_8gpu_tp4_pp2 \
                            --dtype float16 \
                            --tp_size 4 \
                            --pp_size 2

trtllm-build --checkpoint_dir ./tllm_checkpoint_8gpu_tp4_pp2 \
            --output_dir ./tmp/llama/70B/trt_engines/fp16/8-gpu/ \
            --gemm_plugin auto

# Quantize HF LLaMA 7B into FP8 and export trtllm checkpoint
python ../quantization/quantize.py --model_dir /tmp/llama-7b-hf \
                                   --dtype float16 \
                                   --qformat fp8 \
                                   --kv_cache_dtype fp8 \
                                   --output_dir ./tllm_checkpoint_1gpu_fp8 \
                                   --calib_size 512

# Build trtllm engines from the trtllm checkpoint
# Enable fp8 gemm plugin to get acceleration in small-batch-size cases
trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp8 \
             --output_dir ./engine_outputs \
             --gemm_plugin fp8

Expected behavior

It builds the engine

Actual behavior

root@dev-station:/app/tensorrt_llm/examples/llama# trtllm-build --checkpoint_dir <checkpoint-dir> \
            --output_dir <engine-dir> \
            --gemm_plugin auto
[TensorRT-LLM] TensorRT-LLM version: 0.12.0.dev2024072301
[07/24/2024-18:43:28] [TRT-LLM] [I] Set bert_attention_plugin to auto.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set gpt_attention_plugin to auto.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set gemm_plugin to auto.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set gemm_swiglu_plugin to None.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set fp8_rowwise_gemm_plugin to None.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set nccl_plugin to auto.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set lookup_plugin to None.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set lora_plugin to None.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set moe_plugin to auto.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set mamba_conv1d_plugin to auto.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set context_fmha to True.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set context_fmha_fp32_acc to False.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set paged_kv_cache to True.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set remove_input_padding to True.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set reduce_fusion to False.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set enable_xqa to True.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set tokens_per_block to 64.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set use_paged_context_fmha to False.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set use_fp8_context_fmha to False.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set multiple_profiles to False.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set paged_state to True.
[07/24/2024-18:43:28] [TRT-LLM] [I] Set streamingllm to False.
[07/24/2024-18:43:29] [TRT-LLM] [W] max_seq_len is scaled to 1048576.0 by rotary scaling 8.0
[07/24/2024-18:43:29] [TRT-LLM] [I] max_seq_len is not specified, using value 1048576.0
[07/24/2024-18:43:29] [TRT-LLM] [W] remove_input_padding is enabled, while opt_num_tokens is not set, setting to max_batch_size*max_beam_width. 

[07/24/2024-18:43:29] [TRT-LLM] [W] padding removal and fMHA are both enabled, max_input_len is not required and will be ignored
[07/24/2024-18:43:33] [TRT-LLM] [I] Set dtype to float16.
[07/24/2024-18:43:34] [TRT] [I] [MemUsageChange] Init CUDA: CPU +16, GPU +0, now: CPU 146, GPU 529 (MiB)
[07/24/2024-18:43:40] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +4518, GPU +1250, now: CPU 4817, GPU 1779 (MiB)
[07/24/2024-18:43:40] [TRT] [W] profileSharing0806 is on by default in TensorRT 10.0. This flag is deprecated and has no effect.
[07/24/2024-18:43:40] [TRT-LLM] [I] Set nccl_plugin to float16.
Traceback (most recent call last):
  File "/usr/local/bin/trtllm-build", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 535, in main
    parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 371, in parallel_build
    passed = build_and_save(rank, rank % workers, ckpt_dir,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 338, in build_and_save
    engine = build_model(build_config,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 331, in build_model
    return build(model, build_config)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/builder.py", line 923, in build
    engine = None if build_config.dry_run else builder.build_engine(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/_common.py", line 201, in decorated
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/builder.py", line 371, in build_engine
    self._add_optimization_profile(network, builder_config)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/builder.py", line 263, in _add_optimization_profile
    profile.set_shape(input_name, min_shape, opt_shape, max_shape)
TypeError: set_shape(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.IOptimizationProfile, input: str, min: tensorrt.tensorrt.Dims, opt: tensorrt.tensorrt.Dims, max: tensorrt.tensorrt.Dims) -> None

Invoked with: <tensorrt.tensorrt.IOptimizationProfile object at 0x7a6e36336230>, 'cache_indirection', [1, 1, 1], [128, 1, 524288.0], [256, 1, 1048576.0]
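
For reference, the rejected opt/max shapes contain float entries (524288.0 and 1048576.0), which matches the earlier warning that max_seq_len was scaled to 1048576.0 by rotary scaling 8.0; per the error message, set_shape() only accepts integer dimensions (tensorrt.Dims). A minimal standalone sketch (my illustration, assuming only a local tensorrt Python package, not the actual TensorRT-LLM builder code) that should reproduce the same TypeError:

import tensorrt as trt

# Create a bare optimization profile, independent of any network.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
profile = builder.create_optimization_profile()

# Integer dimensions convert cleanly to tensorrt.Dims and are accepted.
profile.set_shape("cache_indirection", [1, 1, 1], [128, 1, 524288], [256, 1, 1048576])

# A float dimension (1048576.0 = 131072 * 8.0 from the rotary scaling warning)
# cannot be converted to tensorrt.Dims and raises the TypeError shown above.
profile.set_shape("cache_indirection", [1, 1, 1], [128, 1, 524288.0], [256, 1, 1048576.0])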

Additional notes

No additional notes.

christian-ci added the bug label on Jul 24, 2024
QiJune (Collaborator) commented Jul 25, 2024

Hi @christian-ci, could you please share the config.json in the <checkpoint-dir>? Thanks

christian-ci (Author) commented Jul 25, 2024

Hi @QiJune, here it is.
For tp_size 8:

{
    "mlp_bias": false,
    "attn_bias": false,
    "rotary_base": 500000.0,
    "rotary_scaling": {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
    },
    "residual_mlp": false,
    "disable_weight_only_quant_plugin": false,
    "moe": {
        "num_experts": 0,
        "top_k": 0,
        "normalization_mode": null,
        "tp_mode": 0
    },
    "architecture": "LlamaForCausalLM",
    "dtype": "float16",
    "vocab_size": 128256,
    "hidden_size": 8192,
    "num_hidden_layers": 80,
    "num_attention_heads": 64,
    "hidden_act": "silu",
    "logits_dtype": "float32",
    "norm_epsilon": 1e-05,
    "position_embedding_type": "rope_gpt_neox",
    "max_position_embeddings": 131072,
    "num_key_value_heads": 8,
    "intermediate_size": 28672,
    "mapping": {
        "world_size": 8,
        "gpus_per_node": 8,
        "tp_size": 8,
        "pp_size": 1,
        "moe_tp_size": 8,
        "moe_ep_size": 1
    },
    "quantization": {
        "quant_algo": null,
        "kv_cache_quant_algo": null,
        "group_size": 128,
        "smoothquant_val": null,
        "clamp_val": null,
        "has_zero_point": false,
        "pre_quant_scale": false,
        "exclude_modules": null
    },
    "use_parallel_embedding": false,
    "embedding_sharding_dim": 0,
    "share_embedding_table": false,
    "head_size": 128,
    "qk_layernorm": false
}

QiJune (Collaborator) commented Jul 25, 2024

Hi @christian-ci, I can reproduce this error. As a quick workaround, you can change this line: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/commands/build.py#L465

from

deduced_max_seq_len *= rotary_factor

to

deduced_max_seq_len = int(deduced_max_seq_len * rotary_factor)

We will fix it in the internal repo, and the fix will be pushed to GitHub next week.
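
For anyone hitting this, a rough sketch of why the cast matters (the variable names follow the traceback and the line referenced above; the surrounding code in commands/build.py is only assumed here and may differ by version):

# The Llama 3.1 rotary_scaling factor is a float (8.0), so an in-place multiply
# turns the deduced max_seq_len into a float, which later ends up in the
# optimization-profile shapes and breaks profile.set_shape().
max_position_embeddings = 131072   # from the checkpoint config above
rotary_factor = 8.0                # rotary_scaling["factor"]

deduced_max_seq_len = max_position_embeddings
deduced_max_seq_len *= rotary_factor                                  # -> 1048576.0 (float)

deduced_max_seq_len = int(max_position_embeddings * rotary_factor)    # -> 1048576 (int)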

christian-ci (Author) commented

@QiJune Finally tested it. Please make sure that change is in before the image is built for everyone else. Thanks!

LanceB57 commented Jul 30, 2024

Hey all, I'm getting the same error when trying to build Llama 3.1 8B.

My command:
trtllm-build --checkpoint_dir llama_ckpt --output_dir ./llama_engine --workers 2 --multiple_profiles enable --use_fused_mlp --gpt_attention_plugin bfloat16 --reduce_fusion disable

The error:

concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 337, in build_and_save
    engine = build_model(build_config,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 330, in build_model
    return build(model, build_config)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/builder.py", line 971, in build
    engine = None if build_config.dry_run else builder.build_engine(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/_common.py", line 201, in decorated
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/builder.py", line 372, in build_engine
    self._add_optimization_profile(network, builder_config)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/builder.py", line 264, in _add_optimization_profile
    profile.set_shape(input_name, min_shape, opt_shape, max_shape)
TypeError: set_shape(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.IOptimizationProfile, input: str, min: tensorrt.tensorrt.Dims, opt: tensorrt.tensorrt.Dims, max: tensorrt.tensorrt.Dims) -> None

Invoked with: <tensorrt.tensorrt.IOptimizationProfile object at 0x7f1c8354c3f0>, 'cache_indirection', [1, 1, 1], [128, 1, 524288.0], [256, 1, 1048576.0]
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/commands/build.py", line 385, in parallel_build
    future.result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
TypeError: set_shape(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.IOptimizationProfile, input: str, min: tensorrt.tensorrt.Dims, opt: tensorrt.tensorrt.Dims, max: tensorrt.tensorrt.Dims) -> None

Invoked with: <tensorrt.tensorrt.IOptimizationProfile object at 0x7f1c8354c3f0>, 'cache_indirection', [1, 1, 1], [128, 1, 524288.0], [256, 1, 1048576.0]

It also seems like the file @QiJune referenced is different now, so I'm not sure where to make that fix.

christian-ci (Author) commented

@LanceB57 You need to change the line:

deduced_max_seq_len *= rotary_factor

to:

deduced_max_seq_len = int(deduced_max_seq_len * rotary_factor)

and rebuild TensorRT-LLM. It won't work if you only change the line without rebuilding.
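
A quick sanity check (just a diagnostic sketch, not part of the fix) is to print which build.py the installed package actually resolves to, so you can confirm your edit ended up in the copy that trtllm-build imports:

# Prints the path of the module that trtllm-build imports; make sure the
# edited (and rebuilt/reinstalled) file lives at this path.
import tensorrt_llm.commands.build as trtllm_build_module
print(trtllm_build_module.__file__)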

LanceB57 commented

Great, thank you!
