[misc] improve memory profiling (vllm-project#11809)
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
youkaichao and DarkLight1337 authored Jan 8, 2025
1 parent ef68eb2 commit 889e662
Showing 4 changed files with 94 additions and 8 deletions.
19 changes: 18 additions & 1 deletion tests/test_utils.py
@@ -5,6 +5,7 @@

import pytest
import torch
from vllm_test_utils import monitor

from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
@@ -289,16 +290,32 @@ def test_memory_profiling():

    weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB

    def measure_current_non_torch():
        free, total = torch.cuda.mem_get_info()
        current_used = total - free
        current_torch = torch.cuda.memory_reserved()
        current_non_torch = current_used - current_torch
        return current_non_torch

    with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
                          weights_memory_in_bytes=weights_memory_in_bytes) as result:
                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
            monitor(measure_current_non_torch) as monitored_values:
        # make a memory spike, 1 GiB
        spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
        del spike

        # Add some extra non-torch memory 256 MiB (simulate NCCL)
        handle2 = lib.cudaMalloc(256 * 1024 * 1024)

        # this is an analytic value, it is exact,
        # we only have 256 MiB non-torch memory increase
        measured_diff = monitored_values.values[-1] - monitored_values.values[0]
        assert measured_diff == 256 * 1024 * 1024

    # Check that the memory usage is within 5% of the expected values
    # 5% tolerance is caused by PyTorch caching allocator,
    # we cannot control PyTorch's behavior of its internal buffers,
    # which causes a small error (<10 MiB in practice)
    non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
    torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
    assert abs(non_torch_ratio - 1) <= 0.05
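Editor's note: the quantity that the new `measure_current_non_torch` helper tracks is simply the device memory in use minus whatever the PyTorch caching allocator has reserved. A minimal standalone sketch of that measurement, assuming a CUDA device and a CUDA build of PyTorch (not part of the commit):

```python
# Minimal sketch of the non-torch memory measurement used in the test above.
# Assumes a CUDA device is available; not part of the commit itself.
import torch


def current_non_torch_bytes() -> int:
    free, total = torch.cuda.mem_get_info()  # device-wide free/total bytes
    used = total - free                      # everything in use on the GPU
    reserved = torch.cuda.memory_reserved()  # bytes held by PyTorch's caching allocator
    return used - reserved                   # attributed to non-torch users (e.g. NCCL)


print(f"non-torch memory: {current_non_torch_bytes() / 1024**2:.1f} MiB")
```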
3 changes: 2 additions & 1 deletion tests/vllm_test_utils/vllm_test_utils/__init__.py
@@ -4,5 +4,6 @@
"""

from .blame import BlameResult, blame
from .monitor import MonitoredValues, monitor

__all__ = ["blame", "BlameResult"]
__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"]
68 changes: 68 additions & 0 deletions tests/vllm_test_utils/vllm_test_utils/monitor.py
@@ -0,0 +1,68 @@
import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator, Generic, TypeVar

_T = TypeVar("_T")


@dataclasses.dataclass
class MonitoredValues(Generic[_T]):
    values: list[_T] = dataclasses.field(default_factory=list)
    trace_stacks: list[str] = dataclasses.field(default_factory=list)


@contextlib.contextmanager
def monitor(
        measure_func: Callable[[],
                               _T]) -> Generator[MonitoredValues[_T], None, None]:
    """
    Trace the function calls to continuously monitor the change of
    a value.
    Usage:

    ```python
    def measure_func():
        ... # measure the current value
        return current_value

    with monitor(measure_func) as monitored_values:
        # do something
        monitored_values.values # all changes of the values
        monitored_values.trace_stacks # trace stacks of every change
    ```
    """
    monitored_values = MonitoredValues[_T]()

    def _trace_calls(frame, event, arg=None):
        nonlocal monitored_values
        if event in ['line']:
            # triggered by every line of Python code.
            # only Python functions will trigger it,
            # c/cpp functions will not trigger it.
            try:
                # Temporarily disable the trace function
                sys.settrace(None)
                # do a measurement
                current_value = measure_func()
                if len(monitored_values.values
                       ) == 0 or current_value != monitored_values.values[-1]:
                    monitored_values.values.append(current_value)
                    monitored_values.trace_stacks.append("".join(
                        traceback.format_stack()))
                # Re-enable the trace function
                sys.settrace(_trace_calls)
            except NameError:
                # modules are deleted during shutdown
                pass
        return _trace_calls

    try:
        sys.settrace(_trace_calls)
        yield monitored_values
    finally:
        sys.settrace(None)
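Editor's note: a small usage sketch of the new `monitor` helper. It requires the `vllm_test_utils` package from this repo, and the toy measure function is illustrative only; values are sampled on CPython `sys.settrace` 'line' events inside Python frames entered while the monitor is active.

```python
# Toy usage of the monitor() helper added above; illustrative only.
from vllm_test_utils import monitor

data: list[int] = []


def measure_len() -> int:
    # called on every traced line; returns the quantity to track
    return len(data)


def add(n: int) -> None:
    # a real Python call, so the installed trace function sees 'line' events
    data.append(n)


with monitor(measure_len) as monitored:
    add(1)
    add(2)

print(monitored.values)             # every distinct value seen while tracing
print(len(monitored.trace_stacks))  # one captured stack per recorded change
```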
12 changes: 6 additions & 6 deletions vllm/utils.py
@@ -1742,10 +1742,10 @@ class MemorySnapshot:
    timestamp: float = 0.0

    def measure(self):
        self.torch_peak_in_bytes = torch.cuda.memory_stats(
        )["allocated_bytes.all.peak"]
        self.torch_memory_in_bytes = torch.cuda.memory_stats(
        )["allocated_bytes.all.current"]
        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
        # torch.cuda.memory_reserved() is how many bytes
        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
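Editor's note: the switch to `torch.cuda.memory_reserved()` matters because the caching allocator can hold more memory from CUDA than is currently allocated to live tensors, and the reserved amount is what the rest of the GPU actually sees. A hedged illustration, assuming a CUDA device (not part of the commit):

```python
# Why the snapshot now reads memory_reserved(): freed tensors stay cached
# in the allocator, so reserved bytes can exceed allocated bytes.
import torch

x = torch.empty(300 * 1024 * 1024, dtype=torch.uint8, device="cuda")
del x  # the tensor is gone, but the allocator keeps the segment cached

print(torch.cuda.memory_allocated())  # bytes in live tensors (likely ~0 here)
print(torch.cuda.memory_reserved())   # bytes PyTorch still holds via cudaMalloc
```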
@@ -1822,10 +1822,10 @@ def memory_profiling(
    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
    """ # noqa
    torch.cuda.reset_peak_memory_stats()

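Editor's note: a hedged usage sketch of `memory_profiling` as exercised by the updated test, showing where the two quantities described in the docstring end up. It assumes a CUDA device; the keyword arguments and result field names are taken from the test above.

```python
# Sketch of how memory_profiling reports (b.) and (c.) after this change.
# Assumes a CUDA device; values are illustrative.
import torch
from vllm.utils import memory_profiling

free, total = torch.cuda.mem_get_info()
baseline = total - free  # device memory already in use before profiling

with memory_profiling(baseline_memory_in_bytes=baseline,
                      weights_memory_in_bytes=0) as result:
    x = torch.randn(64, 1024, 1024, device="cuda")  # ~256 MiB of torch memory
    del x

# (b.) peak torch memory during the block, now taken from
#      torch.cuda.max_memory_reserved()
print(result.torch_peak_increase_in_bytes)
# (c.) device memory that appeared but is not tracked by the caching
#      allocator (e.g. NCCL buffers)
print(result.non_torch_increase_in_bytes)
```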
