[V1][Core][1/n] Logging and Metrics #11962

Merged (10 commits) on Jan 12, 2025
vllm/v1/core/scheduler.py (15 additions, 5 deletions)

@@ -8,7 +8,8 @@
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

@@ -394,12 +395,12 @@ def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> List[EngineCoreOutput]:
) -> EngineCoreOutputs:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
engine_core_outputs: List[EngineCoreOutput] = []
outputs: List[EngineCoreOutput] = []
for request in self.running:
req_id = request.request_id
request.num_computed_tokens += num_scheduled_tokens[req_id]
@@ -438,15 +439,18 @@ def update_from_output(
finished=request.is_finished(),
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason)
engine_core_outputs.append(output)
outputs.append(output)

# Breakout of the loop.
if stopped:
continue

new_running.append(request)
self.running = new_running
return engine_core_outputs
return EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(),
)

def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
@@ -515,6 +519,12 @@ def get_num_unfinished_requests(self) -> int:
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0

def make_stats(self) -> SchedulerStats:
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
)


@dataclass
class NewRequestData:
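
The new make_stats() hook above packages the scheduler's queue depths into a SchedulerStats object from vllm/v1/metrics/stats.py, a file this diff does not show. A minimal sketch of what that container could look like, assuming it is a plain dataclass holding only the two fields used here:

from dataclasses import dataclass


@dataclass
class SchedulerStats:
    """Snapshot of the scheduler's queues, taken once per engine step."""

    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

Keeping the snapshot this small matters because it is attached to every EngineCoreOutputs message that crosses the process boundary (see the next file).
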
vllm/v1/engine/__init__.py (3 additions, 0 deletions)

@@ -4,6 +4,8 @@

import msgspec

from vllm.v1.metrics.stats import SchedulerStats

if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
@@ -56,6 +58,7 @@ class EngineCoreOutputs(

# [num_reqs]
outputs: List[EngineCoreOutput]
scheduler_stats: SchedulerStats


@dataclass
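
EngineCoreOutputs is the msgspec message pushed from the engine-core process to the frontend, so adding scheduler_stats here means every batch of outputs now carries a stats snapshot. The toy round-trip below (the Toy* types are stand-ins, not vLLM's actual definitions) shows that msgspec serializes a Struct with a nested dataclass field without extra plumbing:

from dataclasses import dataclass
from typing import List

import msgspec


@dataclass
class ToySchedulerStats:
    # Same two fields as the SchedulerStats sketch above.
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0


class ToyEngineCoreOutput(msgspec.Struct):
    request_id: str
    new_token_ids: List[int]


class ToyEngineCoreOutputs(msgspec.Struct):
    outputs: List[ToyEngineCoreOutput]
    scheduler_stats: ToySchedulerStats


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(ToyEngineCoreOutputs)

msg = ToyEngineCoreOutputs(
    outputs=[ToyEngineCoreOutput(request_id="req-0", new_token_ids=[1, 2, 3])],
    scheduler_stats=ToySchedulerStats(num_running_reqs=1, num_waiting_reqs=4),
)
assert decoder.decode(encoder.encode(msg)) == msg
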
vllm/v1/engine/async_llm.py (19 additions, 7 deletions)

@@ -4,7 +4,6 @@

from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor
@@ -22,6 +21,8 @@
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase
from vllm.v1.metrics.stats import SchedulerStats

logger = init_logger(__name__)

@@ -34,7 +35,6 @@ def __init__(
executor_class: Type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
log_requests: bool = True,
@@ -45,7 +45,10 @@

self.log_requests = log_requests
self.log_stats = log_stats
self.stat_loggers = stat_loggers
self.stat_loggers: List[StatLoggerBase] = [
LoggingStatLogger(),
# PrometheusStatLogger(),
Review comment (Collaborator): Add TODO?

]
self.model_config = vllm_config.model_config

# Tokenizer (+ ensure liveness if running in another process).
@@ -82,7 +85,6 @@ def __init__(
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)

self.output_handler: Optional[asyncio.Task] = None
@@ -94,7 +96,6 @@ def from_engine_args(
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""

@@ -114,7 +115,6 @@
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)

def shutdown(self):
@@ -254,14 +254,18 @@ async def _run_output_handler(self):
outputs = await self.engine_core.get_output_async()

# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
request_outputs, reqs_to_abort = self.detokenizer.step(
outputs.outputs)

# 3) Put the RequestOutputs into the per-request queues.
self._process_request_outputs(request_outputs)

# 4) Abort any requests that finished due to stop strings.
await self.engine_core.abort_requests_async(reqs_to_abort)

# 5) Log any stats.
await self._log_stats(scheduler_stats=outputs.scheduler_stats)
Review comment (Collaborator): Are we going to improve the metric system later to remove it from the critical path? Or is its overhead acceptable?

Reply (Collaborator, author): This is not in the EngineCore process, so it is overlapped with GPU execution (and therefore not in the critical path). If this becomes a bottleneck for latency, we can offload to a third process, but I don't expect this to be needed.

except Exception as e:
logger.exception("EngineCore output handler hit an error: %s", e)
kill_process_tree(os.getpid())
@@ -278,6 +282,14 @@ async def abort(self, request_id: str) -> None:
if request_id in self.rid_to_queue:
del self.rid_to_queue[request_id]

async def _log_stats(self, scheduler_stats: SchedulerStats):
"""Log stats to the stat loggers."""
if not self.log_stats:
return

for logger in self.stat_loggers:
logger.log(scheduler_stats=scheduler_stats)

def encode(
self,
prompt: PromptType,
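
With this change the frontend process owns the stat loggers: _log_stats() fans each SchedulerStats snapshot out to everything in self.stat_loggers. vllm/v1/metrics/loggers.py is not part of this excerpt; the sketch below shows the interface that _log_stats() implies, with a console logger that keeps the 5-second interval and running/waiting summary this PR removes from core.py. The interval constant and log format are assumptions, not the file's actual contents:

import time
from abc import ABC, abstractmethod

from vllm.logger import init_logger
from vllm.v1.metrics.stats import SchedulerStats

logger = init_logger(__name__)

# Assumed interval; mirrors the LOGGING_TIME_S = 5 removed from core.py.
_LOCAL_LOGGING_INTERVAL_SEC = 5.0


class StatLoggerBase(ABC):

    @abstractmethod
    def log(self, scheduler_stats: SchedulerStats) -> None:
        ...


class LoggingStatLogger(StatLoggerBase):

    def __init__(self) -> None:
        self._last_log_time = time.monotonic()

    def log(self, scheduler_stats: SchedulerStats) -> None:
        # Called once per engine step; rate-limit the console output.
        now = time.monotonic()
        if now - self._last_log_time < _LOCAL_LOGGING_INTERVAL_SEC:
            return
        self._last_log_time = now
        logger.info("Running: %d reqs, Waiting: %d reqs",
                    scheduler_stats.num_running_reqs,
                    scheduler_stats.num_waiting_reqs)

The commented-out PrometheusStatLogger would slot into the same list, which is why the base class only needs the single log() hook.
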
vllm/v1/engine/core.py (16 additions, 37 deletions)

@@ -17,9 +17,9 @@
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
@@ -28,9 +28,7 @@

logger = init_logger(__name__)

POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = 5
POLLING_TIMEOUT_S = 2.5


class EngineCore:
@@ -40,10 +38,8 @@ def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
):
assert vllm_config.model_config.runner_type != "pooling"
self.log_stats = log_stats

logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
@@ -62,8 +58,6 @@ def __init__(
vllm_config.cache_config,
vllm_config.lora_config)

self._last_logging_time = time.time()

self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)

@@ -114,11 +108,12 @@ def abort_requests(self, request_ids: List[str]):
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)

def step(self) -> List[EngineCoreOutput]:
def step(self) -> EngineCoreOutputs:
"""Schedule, execute, and make output."""

if not self.scheduler.has_unfinished_requests():
return []
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())

scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
@@ -145,15 +140,17 @@ def __init__(
executor_class: Type[Executor],
log_stats: bool = False,
):
super().__init__(vllm_config, executor_class, log_stats)
super().__init__(vllm_config, executor_class)

self.log_stats = log_stats

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
daemon=True).start()
@@ -217,7 +214,9 @@ def run_busy_loop(self):
self._handle_client_request(req)
break
except queue.Empty:
self._log_stats()
# Break out the loops so we can log_stats via step().
if self.log_stats:
break
Review comment (Collaborator): Move after the logger.debug?

logger.debug("EngineCore busy loop waiting.")
except BaseException:
raise
Expand All @@ -230,28 +229,9 @@ def run_busy_loop(self):
# 3) Step the engine core.
outputs = self.step()

# 4) Put EngineCoreOutputs into the output queue.
# 5) Put EngineCoreOutputs into the output queue.
self.output_queue.put_nowait(outputs)

self._log_stats()

def _log_stats(self):
"""Log basic stats every LOGGING_TIME_S"""

if not self.log_stats:
return

now = time.time()

if now - self._last_logging_time > LOGGING_TIME_S:
logger.info(
"RUNNING: %s | WAITING: %s",
len(self.scheduler.running),
len(self.scheduler.waiting),
)

self._last_logging_time = now

def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""

@@ -301,7 +281,6 @@ def process_output_socket(self, output_path: str):

with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
outputs = self.output_queue.get()
encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False)
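
The busy loop above now waits on the input queue with a 2.5-second timeout and, when log_stats is enabled, breaks out on queue.Empty so that step() still runs and a stats-only EngineCoreOutputs reaches the frontend while the engine is idle. A standalone toy of that polling pattern (every name below is a stand-in, not vLLM code):

import queue
from dataclasses import dataclass
from typing import List

POLLING_TIMEOUT_S = 2.5  # same constant as in this diff


@dataclass
class ToyOutputs:
    processed: List[str]
    num_waiting_reqs: int  # stands in for the attached SchedulerStats


def busy_loop_iteration(input_queue: "queue.Queue[str]",
                        output_queue: "queue.Queue[ToyOutputs]",
                        pending: List[str],
                        log_stats: bool) -> None:
    # 1) Poll for new work while idle; on timeout, optionally fall through
    #    so the consumer still receives a (stats-only) message.
    while not pending:
        try:
            pending.append(input_queue.get(timeout=POLLING_TIMEOUT_S))
            break
        except queue.Empty:
            if log_stats:
                break

    # 2) "Step": drain one request (or none, if we only woke up for stats).
    processed = [pending.pop()] if pending else []

    # 3) Publish outputs plus the queue-depth snapshot, even when idle.
    output_queue.put(ToyOutputs(processed=processed,
                                num_waiting_reqs=len(pending)))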