From 889e662eae19fe8f30469883c6854ee4df4315a9 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 8 Jan 2025 14:36:03 +0800
Subject: [PATCH] [misc] improve memory profiling (#11809)

Signed-off-by: youkaichao
Co-authored-by: Cyrus Leung
---
 tests/test_utils.py                              | 19 +++++-
 .../vllm_test_utils/__init__.py                  |  3 +-
 .../vllm_test_utils/monitor.py                   | 68 +++++++++++++++++++
 vllm/utils.py                                    | 12 ++--
 4 files changed, 94 insertions(+), 8 deletions(-)
 create mode 100644 tests/vllm_test_utils/vllm_test_utils/monitor.py

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 32a6b0aed66aa..0285b00d73be1 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,6 +5,7 @@
 import pytest
 import torch
+from vllm_test_utils import monitor
 
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
                         get_open_port, memory_profiling,
                         merge_async_iterators,
@@ -289,8 +290,16 @@ def test_memory_profiling():
     weights_memory_in_bytes = 128 * 1024 * 1024 * 4  # 512 MiB
 
+    def measure_current_non_torch():
+        free, total = torch.cuda.mem_get_info()
+        current_used = total - free
+        current_torch = torch.cuda.memory_reserved()
+        current_non_torch = current_used - current_torch
+        return current_non_torch
+
     with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
-                          weights_memory_in_bytes=weights_memory_in_bytes) as result:
+                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
+            monitor(measure_current_non_torch) as monitored_values:
         # make a memory spike, 1 GiB
         spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
         del spike
@@ -298,7 +307,15 @@ def test_memory_profiling():
     # Add some extra non-torch memory 256 MiB (simulate NCCL)
     handle2 = lib.cudaMalloc(256 * 1024 * 1024)
 
+    # This value is analytic, so the check is exact: we allocated
+    # exactly 256 MiB of non-torch memory above.
+    measured_diff = monitored_values.values[-1] - monitored_values.values[0]
+    assert measured_diff == 256 * 1024 * 1024
+
     # Check that the memory usage is within 5% of the expected values
+    # The 5% tolerance accounts for the PyTorch caching allocator:
+    # we cannot control how it manages its internal buffers,
+    # which introduces a small error (< 10 MiB in practice).
     non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024)  # noqa
     torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024)  # noqa
     assert abs(non_torch_ratio - 1) <= 0.05
diff --git a/tests/vllm_test_utils/vllm_test_utils/__init__.py b/tests/vllm_test_utils/vllm_test_utils/__init__.py
index bf0b62a5b75e3..6505c81546bb0 100644
--- a/tests/vllm_test_utils/vllm_test_utils/__init__.py
+++ b/tests/vllm_test_utils/vllm_test_utils/__init__.py
@@ -4,5 +4,6 @@
 """
 
 from .blame import BlameResult, blame
+from .monitor import MonitoredValues, monitor
 
-__all__ = ["blame", "BlameResult"]
+__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"]
diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py
new file mode 100644
index 0000000000000..a237f53a75d18
--- /dev/null
+++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py
@@ -0,0 +1,68 @@
+import contextlib
+import dataclasses
+import sys
+import traceback
+from typing import Callable, Generator, Generic, TypeVar
+
+_T = TypeVar("_T")
+
+
+@dataclasses.dataclass
+class MonitoredValues(Generic[_T]):
+    values: list[_T] = dataclasses.field(default_factory=list)
+    trace_stacks: list[str] = dataclasses.field(default_factory=list)
+
+
+@contextlib.contextmanager
+def monitor(
+    measure_func: Callable[[],
+                           _T]) -> Generator[MonitoredValues[_T], None, None]:
+    """
+    Trace Python line execution to continuously monitor the change of
+    a value.
+
+    Usage:
+
+    ```python
+
+    def measure_func():
+        ...  # measure the current value
+        return current_value
+
+    with monitor(measure_func) as monitored_values:
+        # do something
+
+    monitored_values.values  # all changes of the value
+    monitored_values.trace_stacks  # trace stack of every change
+    ```
+    """
+    monitored_values = MonitoredValues[_T]()
+
+    def _trace_calls(frame, event, arg=None):
+        nonlocal monitored_values
+        if event == 'line':
+            # Triggered by every line of Python code.
+            # Only Python code triggers it; C/C++ functions
+            # will not trigger it.
+            try:
+                # Temporarily disable the trace function
+                sys.settrace(None)
+                # do a measurement
+                current_value = measure_func()
+                if (not monitored_values.values
+                        or current_value != monitored_values.values[-1]):
+                    monitored_values.values.append(current_value)
+                    monitored_values.trace_stacks.append("".join(
+                        traceback.format_stack()))
+                # Re-enable the trace function
+                sys.settrace(_trace_calls)
+            except NameError:
+                # modules may already be deleted during interpreter shutdown
+                pass
+        return _trace_calls
+
+    try:
+        sys.settrace(_trace_calls)
+        yield monitored_values
+    finally:
+        sys.settrace(None)
diff --git a/vllm/utils.py b/vllm/utils.py
index 63057153f851d..2660b53d7bfb0 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1742,10 +1742,10 @@ class MemorySnapshot:
     timestamp: float = 0.0
 
     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-        self.torch_memory_in_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
+        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # torch.cuda.memory_reserved() reports how many bytes
+        # PyTorch has obtained from CUDA (by calling cudaMalloc, etc.)
+        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
         self.timestamp = time.time()
 
     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
@@ -1822,10 +1822,10 @@ def memory_profiling(
     The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
 
-    The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.max_memory_reserved()` after profiling gives (b.).
 
     (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
     """ # noqa
     torch.cuda.reset_peak_memory_stats()
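
For readers who want to reproduce the non-torch accounting outside the test suite, here is a minimal, self-contained sketch of the same identity used by `measure_current_non_torch` above. It assumes a CUDA device is available; the helper name `non_torch_bytes` is invented for illustration and is not part of the patch.

```python
import torch


def non_torch_bytes() -> int:
    # (total - free) is everything in use on the device, by any library;
    # memory_reserved() is the subset PyTorch obtained via cudaMalloc.
    # The difference is attributable to non-torch users (NCCL, cuBLAS, ...).
    free, total = torch.cuda.mem_get_info()
    return (total - free) - torch.cuda.memory_reserved()


before = non_torch_bytes()
# ... run code that may allocate GPU memory outside PyTorch ...
print(f"non-torch increase: {non_torch_bytes() - before} bytes")
```

In the test above, the only non-torch allocation between the first and last measurement is the 256 MiB `cudaMalloc`, which is why the measured difference can be asserted exactly, while the profiling results are only checked to within 5%.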
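The switch in `MemorySnapshot.measure` from `allocated_bytes.all.*` to `memory_reserved()`/`max_memory_reserved()` matters because the caching allocator holds more memory from CUDA than live tensors occupy. A small illustrative contrast (assumes a CUDA device; exact numbers depend on allocator state):

```python
import torch

x = torch.empty(1024, device="cuda")        # one small live tensor
allocated = torch.cuda.memory_allocated()   # bytes inside live tensors
reserved = torch.cuda.memory_reserved()     # bytes cached from cudaMalloc
# reserved >= allocated: the allocator requests whole blocks from CUDA,
# and mem_get_info() sees the reserved amount. Comparing mem_get_info()
# deltas against memory_reserved(), not memory_allocated(), is therefore
# what isolates non-torch usage without miscounting the allocator's cache.
print(f"allocated={allocated} reserved={reserved}")
```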