[Bug]: Multi-Node Tensor-Parallel broken by #11256 when TP > cuda_device_count per node #12132

Closed
drikster80 opened this issue Jan 16, 2025 · 4 comments · Fixed by #12166
Labels: bug (Something isn't working)
Assignee: youkaichao

Comments

@drikster80 (Contributor)

Your current environment

Running in a Docker container (Kubernetes) on 4 GH200 nodes, with 1 GPU per node.

Model Input Dumps

python3 -m vllm.entrypoints.openai.api_server --model /models/my-model \
    --tensor-parallel-size 4 \
    --gpu-memory-utilization 0.95 \
    --served-model-name my-model \
    --trust-remote-code \
    --api-key "NONE" \
    --rope-scaling '{"rope_type":"dynamic","factor":4.0}' \
    --enable-prefix-caching \
    --max-model-len 131072

🐛 Describe the bug

@youkaichao, it looks like #11256 forces --tensor-parallel-size to be no larger than the per-node GPU count, which breaks multi-node tensor parallelism.

https://github.com/vllm-project/vllm/blob/main/vllm/platforms/cuda.py#L156

        # Use confusing message for more common TP-only case.
        assert tensor_parallel_size <= cuda_device_count, (
            f"please set tensor_parallel_size ({tensor_parallel_size}) "
            f"to less than max local gpu count ({cuda_device_count})")

Testing current main with 4 nodes (1 GPU per node) fails with the traceback below; the same model/code/execution works perfectly in v0.6.4.post1:

Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 382, in run_mp_engine
    raise e
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 371, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 115, in from_engine_args
    engine_config = engine_args.create_engine_config(usage_context)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/arg_utils.py", line 1244, in create_engine_config
    config = VllmConfig(
             ^^^^^^^^^^^
  File "<string>", line 19, in __init__
  File "/usr/local/lib/python3.12/dist-packages/vllm/config.py", line 3204, in __post_init__
    current_platform.check_and_update_config(self)
  File "/usr/local/lib/python3.12/dist-packages/vllm/platforms/cuda.py", line 156, in check_and_update_config
    assert tensor_parallel_size <= cuda_device_count, (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: please set tensor_parallel_size (4) to less than max local gpu count (1)
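
For reference, a quick way to see the per-node value the assertion compares against, assuming it ultimately reflects the CUDA devices visible to the process on that node:

# Run on any one of the GH200 nodes; assumes the assertion's cuda_device_count
# reflects the locally visible CUDA devices (1 per node in this setup).
import torch

print(torch.cuda.device_count())  # 1, while --tensor-parallel-size is 4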

@drikster80 added the bug label on Jan 16, 2025
@youkaichao self-assigned this on Jan 17, 2025
@youkaichao (Member)

will fix soon

@youkaichao (Member)

@drikster80 can you please help check if #12166 fixes it?

@drikster80 (Contributor, Author)

@drikster80 can you please help check if #12166 fixes it?

Thanks. Rebuilding now and will test. Should take about 90 min to build/test.

@drikster80 (Contributor, Author)

@youkaichao, just tested and verified that it works as expected. Thanks for your help! Once merged, this issue can be closed.

@youkaichao linked pull request #12166 on Jan 18, 2025 that will close this issue