Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] further polish memory profiling #12126

Merged
merged 6 commits into from
Jan 18, 2025

Conversation

youkaichao
Copy link
Member

@youkaichao youkaichao commented Jan 16, 2025

Improve upon #11809

current main branch, vllm serve meta-llama/Llama-3.2-11B-Vision --load-format dummy --max-model-len 65536 --max-num-seqs 128 --enforce-eager will fail.

after this PR, it can work on H100-80G now:

INFO 01-16 09:40:00 worker.py:238] the current vLLM instance can use total_gpu_memory (79.22GiB) x gpu_memory_utilization (0.90) = 71.29GiB
INFO 01-16 09:40:00 worker.py:238] model weights take 19.86GiB; non_torch_memory takes 0.16GiB; PyTorch activation peak memory takes 36.63GiB; the rest of the memory reserved for KV Cache is 14.65GiB.
INFO 01-16 09:40:00 executor_base.py:95] # CUDA blocks: 6001, # CPU blocks: 1638

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀


# load weights

weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)

weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactor in this PR: remote the _in_bytes in variable name to make the name shorter.

assert abs(non_torch_ratio - 1) <= 0.05
assert abs(torch_peak_ratio - 1) <= 0.05
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now this becomes accurate.

Comment on lines +1937 to +1943
# we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get(
"allocated_bytes.all.peak", 0)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the key change, reported by @gshtras

Comment on lines +141 to +142
torch.cuda.reset_peak_memory_stats()
self.baseline_snapshot = MemorySnapshot()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another key change: we also measure the non-torch memory before creating the vllm instance.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to do this in v1/worker/gpu_worker.py as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

v1 does not use this memory_profiling utility yet. welcome to port it to v1 code path!

@gshtras
Copy link
Contributor

gshtras commented Jan 16, 2025

Ran some tests and can confirm that this fixes the original issue we had with Llama3.2 90B model peak memory jumping from 98GB to 160GB

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for wrestling this into a better state!

@youkaichao youkaichao enabled auto-merge (squash) January 17, 2025 15:48
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 17, 2025
@youkaichao
Copy link
Member Author

failed tests look unrelated, merging.

@youkaichao youkaichao disabled auto-merge January 18, 2025 04:25
@youkaichao youkaichao merged commit da02cb4 into vllm-project:main Jan 18, 2025
67 of 70 checks passed
@youkaichao youkaichao deleted the fix_profiling branch January 18, 2025 04:25
joennlae pushed a commit to 44ai-labs/vllm that referenced this pull request Jan 19, 2025
joennlae pushed a commit to 44ai-labs/vllm that referenced this pull request Jan 19, 2025
lckr pushed a commit to lckr/vllm that referenced this pull request Jan 19, 2025
abmfy pushed a commit to abmfy/vllm-flashinfer that referenced this pull request Jan 24, 2025
abmfy pushed a commit to abmfy/vllm-flashinfer that referenced this pull request Jan 24, 2025
kzawora-intel added a commit to HabanaAI/vllm-fork that referenced this pull request Jan 28, 2025
- **[Bugfix] Fix score api for missing max_model_len validation
(vllm-project#12119)**
- **[Bugfix] Mistral tokenizer encode accept list of str (vllm-project#12149)**
- **[AMD][FP8] Using MI300 FP8 format on ROCm for block_quant (vllm-project#12134)**
- **[torch.compile] disable logging when cache is disabled (vllm-project#12043)**
- **[misc] fix cross-node TP (vllm-project#12166)**
- **[AMD][CI/Build][Bugfix] use pytorch stale wheel (vllm-project#12172)**
- **[core] further polish memory profiling (vllm-project#12126)**
- **[Docs] Fix broken link in SECURITY.md (vllm-project#12175)**
- **[Model] Port deepseek-vl2 processor, remove dependency (vllm-project#12169)**
- **[core] clean up executor class hierarchy between v1 and v0
(vllm-project#12171)**
- **[Misc] Support register quantization method out-of-tree (vllm-project#11969)**
- **[V1] Collect env var for usage stats (vllm-project#12115)**
- **[BUGFIX] Move scores to float32 in case of running xgrammar on cpu
(vllm-project#12152)**
- **[Bugfix] Fix multi-modal processors for transformers 4.48 (vllm-project#12187)**
- **[torch.compile] store inductor compiled Python file (vllm-project#12182)**
- **benchmark_serving support --served-model-name param (vllm-project#12109)**
- **[Misc] Add BNB support to GLM4-V model (vllm-project#12184)**
- **[V1] Add V1 support of Qwen2-VL (vllm-project#12128)**
- **[Model] Support for fairseq2 Llama (vllm-project#11442)**
- **[Bugfix] Fix num_heads value for simple connector when tp enabled
(vllm-project#12074)**
- **[torch.compile] fix sym_tensor_indices (vllm-project#12191)**
- **Move linting to `pre-commit` (vllm-project#11975)**
- **[DOC] Fix typo in docstring and assert message (vllm-project#12194)**
- **[DOC] Add missing docstring in LLMEngine.add_request() (vllm-project#12195)**
- **[Bugfix] Fix incorrect types in LayerwiseProfileResults (vllm-project#12196)**
- **[Model] Add Qwen2 PRM model support (vllm-project#12202)**
- **[Core] Interface for accessing model from `VllmRunner` (vllm-project#10353)**
- **[misc] add placeholder format.sh (vllm-project#12206)**
- **[CI/Build] Remove dummy CI steps (vllm-project#12208)**
- **[CI/Build] Make pre-commit faster (vllm-project#12212)**
- **[Model] Upgrade Aria to transformers 4.48 (vllm-project#12203)**
- **[misc] print a message to suggest how to bypass commit hooks
(vllm-project#12217)**
- **[core][bugfix] configure env var during import vllm (vllm-project#12209)**
- **[V1] Remove `_get_cache_block_size` (vllm-project#12214)**
- **[Misc] Pass `attention` to impl backend (vllm-project#12218)**
- **[Bugfix] Fix `HfExampleModels.find_hf_info` (vllm-project#12223)**
- **[CI] Pass local python version explicitly to pre-commit mypy.sh
(vllm-project#12224)**
- **[Misc] Update CODEOWNERS (vllm-project#12229)**
- **fix: update platform detection for M-series arm based MacBook
processors (vllm-project#12227)**
- **[misc] add cuda runtime version to usage data (vllm-project#12190)**
- **[bugfix] catch xgrammar unsupported array constraints (vllm-project#12210)**
- **[Kernel] optimize moe_align_block_size for cuda graph and large
num_experts (e.g. DeepSeek-V3) (vllm-project#12222)**
- **Add quantization and guided decoding CODEOWNERS (vllm-project#12228)**
- **[AMD][Build] Porting dockerfiles from the ROCm/vllm fork (vllm-project#11777)**
- **[BugFix] Fix GGUF tp>1 when vocab_size is not divisible by 64
(vllm-project#12230)**
- **[ci/build] disable failed and flaky tests (vllm-project#12240)**
- **[Misc] Rename `MultiModalInputsV2 -> MultiModalInputs` (vllm-project#12244)**
- **[Misc]Add BNB quantization for PaliGemmaForConditionalGeneration
(vllm-project#12237)**
- **[Misc] Remove redundant TypeVar from base model (vllm-project#12248)**
- **[Bugfix] Fix mm_limits access for merged multi-modal processor
(vllm-project#12252)**

---------

Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: hongxyan <[email protected]>
Signed-off-by: Russell Bryant <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Michal Adamczyk <[email protected]>
Signed-off-by: zibai <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Martin Gleize <[email protected]>
Signed-off-by: Shangming Cai <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Yuan Tang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: isikhi <[email protected]>
Signed-off-by: Jason Cheng <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Konrad Zawora <[email protected]>
Co-authored-by: Wallas Henrique <[email protected]>
Co-authored-by: Kunshang Ji <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: yancong <[email protected]>
Co-authored-by: Simon Mo <[email protected]>
Co-authored-by: Michal Adamczyk <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: gujing <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: imkero <[email protected]>
Co-authored-by: Martin Gleize <[email protected]>
Co-authored-by: mgleize user <[email protected]>
Co-authored-by: shangmingc <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Yuan Tang <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
Co-authored-by: Işık <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Cheng Kuan Yong Jason <[email protected]>
Co-authored-by: Jinzhen Lin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants