[misc] improve memory profiling #11809

Merged (9 commits) on Jan 8, 2025
19 changes: 18 additions & 1 deletion tests/test_utils.py
@@ -5,6 +5,7 @@
 
 import pytest
 import torch
+from vllm_test_utils import monitor
 
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
                         get_open_port, memory_profiling, merge_async_iterators,
@@ -289,16 +290,32 @@ def test_memory_profiling():
 
     weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
 
+    def measure_current_non_torch():
+        free, total = torch.cuda.mem_get_info()
+        current_used = total - free
+        current_torch = torch.cuda.memory_reserved()
+        current_non_torch = current_used - current_torch
+        return current_non_torch
+
     with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
-                          weights_memory_in_bytes=weights_memory_in_bytes) as result:
+                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
+            monitor(measure_current_non_torch) as monitored_values:
         # make a memory spike, 1 GiB
         spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
         del spike
 
         # Add some extra non-torch memory 256 MiB (simulate NCCL)
         handle2 = lib.cudaMalloc(256 * 1024 * 1024)
 
+    # this is an analytic value, it is exact,
+    # we only have 256 MiB non-torch memory increase
+    measured_diff = monitored_values.values[-1] - monitored_values.values[0]
+    assert measured_diff == 256 * 1024 * 1024
Collaborator:
This is really cool!

I'm a bit confused about how it's useful, though, since we're testing a test utility that isn't used in the actual memory profiling. Did we want to enable monitor(measure_current_non_torch) during the actual profile run to try to get an accurate measure of the peak non-torch memory usage?

Member Author:
> I'm a bit confused about how it's useful, though

For the current memory profiling, we mainly use `torch.cuda.memory_reserved()` to replace `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.

The utility function is more about future-proofing: it gives us the ground-truth non-torch memory, which will help us profile which parts of memory can be offloaded in RLHF workloads.
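
To illustrate the kind of future use described above, a hypothetical sketch (not part of this PR; `run_profiling_step` is a made-up placeholder) might look like:

```python
import torch
from vllm_test_utils import monitor

def measure_current_non_torch():
    # total bytes in use on the GPU, minus what PyTorch has reserved
    free, total = torch.cuda.mem_get_info()
    return (total - free) - torch.cuda.memory_reserved()

# hypothetical: wrap a real profiling run to locate non-torch allocations
with monitor(measure_current_non_torch) as monitored_values:
    run_profiling_step()  # placeholder for the actual workload

peak_non_torch = max(monitored_values.values, default=0)
# monitored_values.trace_stacks[i] is the Python stack at the i-th recorded
# change, which shows which code path triggered the non-torch allocation
```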


     # Check that the memory usage is within 5% of the expected values
+    # 5% tolerance is caused by PyTorch caching allocator,
+    # we cannot control PyTorch's behavior of its internal buffers,
+    # which causes a small error (<10 MiB in practice)
     non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
     torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
     assert abs(non_torch_ratio - 1) <= 0.05
3 changes: 2 additions & 1 deletion tests/vllm_test_utils/vllm_test_utils/__init__.py
@@ -4,5 +4,6 @@
"""

from .blame import BlameResult, blame
from .monitor import MonitoredValues, monitor

__all__ = ["blame", "BlameResult"]
__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"]
68 changes: 68 additions & 0 deletions tests/vllm_test_utils/vllm_test_utils/monitor.py
@@ -0,0 +1,68 @@
import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator, Generic, TypeVar

_T = TypeVar("_T")


@dataclasses.dataclass
class MonitoredValues(Generic[_T]):
    values: list[_T] = dataclasses.field(default_factory=list)
    trace_stacks: list[str] = dataclasses.field(default_factory=list)


@contextlib.contextmanager
def monitor(
    measure_func: Callable[[],
                           _T]) -> Generator[MonitoredValues[_T], None, None]:
    """
    Trace the function calls to continuously monitor the change of
    a value.

    Usage:

    ```python

    def measure_func():
        ... # measure the current value
        return current_value

    with monitor(measure_func) as monitored_values:
        # do something

    monitored_values.values # all changes of the values
    monitored_values.trace_stacks # trace stacks of every change
    ```
    """
    monitored_values = MonitoredValues[_T]()

    def _trace_calls(frame, event, arg=None):
        nonlocal monitored_values
        if event in ['line']:
            # triggered by every line of Python code.
            # only Python functions will trigger it,
            # c/cpp functions will not trigger it.
            try:
                # Temporarily disable the trace function
                sys.settrace(None)
                # do a measurement
                current_value = measure_func()
                if len(monitored_values.values
                       ) == 0 or current_value != monitored_values.values[-1]:
                    monitored_values.values.append(current_value)
                    monitored_values.trace_stacks.append("".join(
                        traceback.format_stack()))
                # Re-enable the trace function
                sys.settrace(_trace_calls)
            except NameError:
                # modules are deleted during shutdown
                pass
        return _trace_calls

    try:
        sys.settrace(_trace_calls)
        yield monitored_values
    finally:
        sys.settrace(None)
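
As an aside, a minimal standalone usage of this utility might look like the following (my own illustrative example, not part of the PR; note that measurements fire only on Python-level line events, so the monitored region must call into Python functions):

```python
from vllm_test_utils import monitor

data = []

def measure_len():
    # the quantity being monitored
    return len(data)

def do_append(x):
    # a Python-level call, so the tracer gets a chance to measure
    data.append(x)

with monitor(measure_len) as monitored:
    for i in range(3):
        do_append(i)

print(monitored.values)             # the sequence of distinct measured values
print(len(monitored.trace_stacks))  # one captured stack per recorded change
```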
12 changes: 6 additions & 6 deletions vllm/utils.py
@@ -1742,10 +1742,10 @@ class MemorySnapshot:
     timestamp: float = 0.0
 
     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-        self.torch_memory_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
+        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
Collaborator:
ohhh nice, I didn't catch that there was a peak measurement for the total reserved memory as well
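
For reference, the two peak counters being swapped here can be compared directly; a small illustrative snippet (mine, not from the PR):

```python
import torch

x = torch.randn(1024, 1024, device="cuda")
del x

# peak bytes handed out to tensors by the caching allocator
peak_allocated = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
# peak bytes the caching allocator itself reserved from CUDA,
# which is what this PR switches to (always >= peak_allocated)
peak_reserved = torch.cuda.max_memory_reserved()
```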

+        # torch.cuda.memory_reserved() is how many bytes
+        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
+        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
         self.timestamp = time.time()
 
     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
@@ -1822,10 +1822,10 @@ def memory_profiling(

     The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
 
-    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
 
     (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
     """ # noqa
     torch.cuda.reset_peak_memory_stats()

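To make the accounting described in the docstring concrete, the computation of (c.) can be sketched roughly as follows (a simplified illustration under my reading of the docstring, not the exact implementation in vllm/utils.py; parameter names are illustrative):

```python
import torch

def non_torch_increase_in_bytes(baseline_memory_in_bytes: int,
                                weights_memory_in_bytes: int,
                                reserved_before_in_bytes: int) -> int:
    """Rough sketch of (c.): memory that grew outside PyTorch's allocator."""
    # total memory currently in use on this GPU, as reported by the driver
    free, total = torch.cuda.mem_get_info()
    total_used = total - free

    # how much more PyTorch has reserved from CUDA since profiling started
    torch_reserved_diff = torch.cuda.memory_reserved() - reserved_before_in_bytes

    # remove the baseline, the model weights, and PyTorch's own growth;
    # the remainder is attributed to non-torch sources (e.g. NCCL buffers)
    return (total_used - baseline_memory_in_bytes
            - weights_memory_in_bytes - torch_reserved_diff)
```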