[V1][Core][1/n] Logging and Metrics (#11962)
robertgshaw2-redhat authored Jan 12, 2025
1 parent 263a870 commit 9597a09
Showing 11 changed files with 129 additions and 84 deletions.
4 changes: 2 additions & 2 deletions tests/v1/engine/test_engine_core.py
@@ -80,7 +80,7 @@ def test_engine_core(monkeypatch):
     assert len(engine_core.scheduler.running) == 4
 
     # Loop through until they are all done.
-    while len(engine_core.step()) > 0:
+    while len(engine_core.step().outputs) > 0:
         pass
 
     assert len(engine_core.scheduler.waiting) == 0
@@ -170,7 +170,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
     # Loop through until they are all done.
-    while len(engine_core.step()) > 0:
+    while len(engine_core.step().outputs) > 0:
         pass
 
     assert len(engine_core.scheduler.waiting) == 0
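Both hunks reflect the commit's central API change: EngineCore.step() now returns an EngineCoreOutputs container instead of a bare List[EngineCoreOutput], so callers reach the per-request outputs through .outputs. A minimal sketch of the new calling convention (the engine setup is elided; the drain helper is illustrative, but the names it uses come from this diff):

from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.engine.core import EngineCore

def drain(engine_core: EngineCore) -> None:
    # Run the engine until no request produces further output,
    # mirroring the loop used in these tests.
    while True:
        outputs: EngineCoreOutputs = engine_core.step()
        if len(outputs.outputs) == 0:
            break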
4 changes: 2 additions & 2 deletions tests/v1/engine/test_engine_core_client.py
@@ -43,7 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
 def loop_until_done(client: EngineCoreClient, outputs: Dict):
 
     while True:
-        engine_core_outputs = client.get_output()
+        engine_core_outputs = client.get_output().outputs
 
         if len(engine_core_outputs) == 0:
             break
@@ -61,7 +61,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
 async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
 
     while True:
-        engine_core_outputs = await client.get_output_async()
+        engine_core_outputs = (await client.get_output_async()).outputs
 
         if len(engine_core_outputs) == 0:
             break
20 changes: 15 additions & 5 deletions vllm/v1/core/scheduler.py
@@ -8,7 +8,8 @@
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 from vllm.v1.core.kv_cache_manager import KVCacheManager
-from vllm.v1.engine import EngineCoreOutput
+from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -394,12 +395,12 @@ def update_from_output(
         self,
         scheduler_output: "SchedulerOutput",
         model_runner_output: "ModelRunnerOutput",
-    ) -> List[EngineCoreOutput]:
+    ) -> EngineCoreOutputs:
         # NOTE(woosuk): This method doesn't consider speculative decoding.
         sampled_token_ids = model_runner_output.sampled_token_ids
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         new_running: List[Request] = []
-        engine_core_outputs: List[EngineCoreOutput] = []
+        outputs: List[EngineCoreOutput] = []
         for request in self.running:
             req_id = request.request_id
             request.num_computed_tokens += num_scheduled_tokens[req_id]
@@ -438,15 +439,18 @@ def update_from_output(
                     finished=request.is_finished(),
                     finish_reason=request.get_finished_reason(),
                     stop_reason=request.stop_reason)
-                engine_core_outputs.append(output)
+                outputs.append(output)
 
                 # Breakout of the loop.
                 if stopped:
                     continue
 
             new_running.append(request)
         self.running = new_running
-        return engine_core_outputs
+        return EngineCoreOutputs(
+            outputs=outputs,
+            scheduler_stats=self.make_stats(),
+        )
 
     def _check_stop(self, request: Request) -> bool:
         if (request.num_tokens >= self.max_model_len
@@ -515,6 +519,12 @@ def get_num_unfinished_requests(self) -> int:
     def has_unfinished_requests(self) -> bool:
         return self.get_num_unfinished_requests() > 0
 
+    def make_stats(self) -> SchedulerStats:
+        return SchedulerStats(
+            num_running_reqs=len(self.running),
+            num_waiting_reqs=len(self.waiting),
+        )
+
 
 @dataclass
 class NewRequestData:
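SchedulerStats itself is defined in the new vllm/v1/metrics/stats.py, one of the 11 changed files not shown in this excerpt. A plausible minimal reconstruction, assuming a plain dataclass with exactly the two fields make_stats() populates:

from dataclasses import dataclass

@dataclass
class SchedulerStats:
    # Hypothetical reconstruction: snapshot of scheduler queue occupancy,
    # inferred from Scheduler.make_stats() above.
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0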
3 changes: 3 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -4,6 +4,8 @@
 
 import msgspec
 
+from vllm.v1.metrics.stats import SchedulerStats
+
 if TYPE_CHECKING:
     from vllm.lora.request import LoRARequest
     from vllm.multimodal import MultiModalKwargs
@@ -56,6 +58,7 @@ class EngineCoreOutputs(
 
     # [num_reqs]
     outputs: List[EngineCoreOutput]
+    scheduler_stats: SchedulerStats
 
 
 @dataclass
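Since EngineCoreOutputs is the msgspec message shipped over ZMQ (see process_output_socket in core.py below), the new scheduler_stats field becomes part of the wire format between the core process and its client. A round-trip sketch, assuming vLLM at this commit is importable and that msgspec can encode the nested stats object:

import msgspec

from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)

# Empty batch: even with no token outputs, stats ride along.
msg = EngineCoreOutputs(
    outputs=[],
    scheduler_stats=SchedulerStats(num_running_reqs=2, num_waiting_reqs=1),
)
decoded = decoder.decode(encoder.encode(msg))
assert decoded.scheduler_stats.num_running_reqs == 2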
26 changes: 19 additions & 7 deletions vllm/v1/engine/async_llm.py
@@ -4,7 +4,6 @@
 
 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.metrics_types import StatLoggerBase
 from vllm.engine.protocol import EngineClient
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
@@ -22,6 +21,8 @@
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase
+from vllm.v1.metrics.stats import SchedulerStats
 
 logger = init_logger(__name__)
 
@@ -34,7 +35,6 @@ def __init__(
         executor_class: Type[Executor],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
         input_registry: InputRegistry = INPUT_REGISTRY,
         use_cached_outputs: bool = False,
         log_requests: bool = True,
@@ -45,7 +45,10 @@
 
         self.log_requests = log_requests
         self.log_stats = log_stats
-        self.stat_loggers = stat_loggers
+        self.stat_loggers: List[StatLoggerBase] = [
+            LoggingStatLogger(),
+            # TODO(rob): PrometheusStatLogger(),
+        ]
         self.model_config = vllm_config.model_config
 
         # Tokenizer (+ ensure liveness if running in another process).
@@ -82,7 +85,6 @@ def __init__(
             asyncio_mode=True,
             vllm_config=vllm_config,
             executor_class=executor_class,
-            log_stats=self.log_stats,
         )
 
         self.output_handler: Optional[asyncio.Task] = None
@@ -94,7 +96,6 @@ def from_engine_args(
         engine_config: Optional[VllmConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLM":
         """Create an AsyncLLM from the EngineArgs."""
 
@@ -114,7 +115,6 @@
             log_stats=not engine_args.disable_log_stats,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
-            stat_loggers=stat_loggers,
         )
 
     def shutdown(self):
@@ -254,14 +254,18 @@ async def _run_output_handler(self):
                 outputs = await self.engine_core.get_output_async()
 
                 # 2) Detokenize based on the output.
-                request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
+                request_outputs, reqs_to_abort = self.detokenizer.step(
+                    outputs.outputs)
 
                 # 3) Put the RequestOutputs into the per-request queues.
                 self._process_request_outputs(request_outputs)
 
                 # 4) Abort any requests that finished due to stop strings.
                 await self.engine_core.abort_requests_async(reqs_to_abort)
 
+                # 5) Log any stats.
+                await self._log_stats(scheduler_stats=outputs.scheduler_stats)
+
             except Exception as e:
                 logger.exception("EngineCore output handler hit an error: %s", e)
                 kill_process_tree(os.getpid())
@@ -278,6 +282,14 @@ async def abort(self, request_id: str) -> None:
         if request_id in self.rid_to_queue:
             del self.rid_to_queue[request_id]
 
+    async def _log_stats(self, scheduler_stats: SchedulerStats):
+        """Log stats to the stat loggers."""
+        if not self.log_stats:
+            return
+
+        for logger in self.stat_loggers:
+            logger.log(scheduler_stats=scheduler_stats)
+
     def encode(
         self,
         prompt: PromptType,
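LoggingStatLogger and StatLoggerBase come from the new vllm/v1/metrics/loggers.py, also not shown in this excerpt. _log_stats() above only requires a log() hook taking a SchedulerStats, so a plausible minimal reconstruction looks like this (the exact log format is an assumption):

from abc import ABC, abstractmethod

from vllm.logger import init_logger
from vllm.v1.metrics.stats import SchedulerStats

logger = init_logger(__name__)


class StatLoggerBase(ABC):

    @abstractmethod
    def log(self, scheduler_stats: SchedulerStats):
        ...


class LoggingStatLogger(StatLoggerBase):

    def log(self, scheduler_stats: SchedulerStats):
        # Human-readable snapshot of scheduler occupancy.
        logger.info(
            "Running: %d reqs, Waiting: %d reqs",
            scheduler_stats.num_running_reqs,
            scheduler_stats.num_waiting_reqs,
        )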
53 changes: 16 additions & 37 deletions vllm/v1/engine/core.py
@@ -17,9 +17,9 @@
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.scheduler import Scheduler
-from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
-                            EngineCoreProfile, EngineCoreRequest,
-                            EngineCoreRequestType, EngineCoreRequestUnion)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
+                            EngineCoreRequest, EngineCoreRequestType,
+                            EngineCoreRequestUnion)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
@@ -28,9 +28,7 @@
 
 logger = init_logger(__name__)
 
-POLLING_TIMEOUT_MS = 5000
-POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
-LOGGING_TIME_S = 5
+POLLING_TIMEOUT_S = 2.5
 
 
 class EngineCore:
@@ -40,10 +38,8 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         executor_class: Type[Executor],
-        log_stats: bool = False,
     ):
         assert vllm_config.model_config.runner_type != "pooling"
-        self.log_stats = log_stats
 
         logger.info("Initializing an LLM engine (v%s) with config: %s",
                     VLLM_VERSION, vllm_config)
@@ -62,8 +58,6 @@ def __init__(
             vllm_config.cache_config,
             vllm_config.lora_config)
 
-        self._last_logging_time = time.time()
-
         self.mm_input_mapper_server = MMInputMapperServer(
             vllm_config.model_config)
 
@@ -114,11 +108,12 @@ def abort_requests(self, request_ids: List[str]):
         self.scheduler.finish_requests(request_ids,
                                        RequestStatus.FINISHED_ABORTED)
 
-    def step(self) -> List[EngineCoreOutput]:
+    def step(self) -> EngineCoreOutputs:
         """Schedule, execute, and make output."""
 
         if not self.scheduler.has_unfinished_requests():
-            return []
+            return EngineCoreOutputs(
+                outputs=[], scheduler_stats=self.scheduler.make_stats())
 
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
@@ -145,15 +140,17 @@ def __init__(
         executor_class: Type[Executor],
         log_stats: bool = False,
     ):
-        super().__init__(vllm_config, executor_class, log_stats)
+        super().__init__(vllm_config, executor_class)
+
+        self.log_stats = log_stats
 
         # Background Threads and Queues for IO. These enable us to
         # overlap ZMQ socket IO with GPU since they release the GIL,
         # and to overlap some serialization/deserialization with the
         # model forward pass.
         # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
         self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
-        self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
+        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
         threading.Thread(target=self.process_input_socket,
                          args=(input_path, ),
                          daemon=True).start()
@@ -217,8 +214,10 @@ def run_busy_loop(self):
                     self._handle_client_request(req)
                     break
                 except queue.Empty:
-                    self._log_stats()
                     logger.debug("EngineCore busy loop waiting.")
+                    # Break out the loop so we can log_stats in step().
+                    if self.log_stats:
+                        break
                 except BaseException:
                     raise
 
@@ -230,28 +229,9 @@ def run_busy_loop(self):
             # 3) Step the engine core.
             outputs = self.step()
 
-            # 4) Put EngineCoreOutputs into the output queue.
+            # 5) Put EngineCoreOutputs into the output queue.
             self.output_queue.put_nowait(outputs)
 
-            self._log_stats()
-
-    def _log_stats(self):
-        """Log basic stats every LOGGING_TIME_S"""
-
-        if not self.log_stats:
-            return
-
-        now = time.time()
-
-        if now - self._last_logging_time > LOGGING_TIME_S:
-            logger.info(
-                "RUNNING: %s | WAITING: %s",
-                len(self.scheduler.running),
-                len(self.scheduler.waiting),
-            )
-
-            self._last_logging_time = now
-
     def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
         """Handle EngineCoreRequest or EngineCoreABORT from Client."""
 
@@ -301,7 +281,6 @@ def process_output_socket(self, output_path: str):
 
         with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
             while True:
-                engine_core_outputs = self.output_queue.get()
-                outputs = EngineCoreOutputs(outputs=engine_core_outputs)
+                outputs = self.output_queue.get()
                 encoder.encode_into(outputs, buffer)
                 socket.send_multipart((buffer, ), copy=False)
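Net effect of the core.py changes: stats emission moves off a wall-clock timer (the removed _log_stats ran every LOGGING_TIME_S) and onto the output path. When log_stats is enabled, a poll timeout now breaks out of the wait so that step() runs and ships fresh SchedulerStats through the output queue even while the engine is idle. A paraphrased sketch of the resulting busy loop (not the exact source; the request-draining step is inferred from context not shown in this excerpt):

import queue
from typing import Any

POLLING_TIMEOUT_S = 2.5

def busy_loop(core: Any) -> None:
    """Paraphrase of EngineCoreProc.run_busy_loop after this commit."""
    while True:
        # 1) Poll the input queue until there is work to do,
        #    or until a poll timeout lets idle stats flow.
        while not core.scheduler.has_unfinished_requests():
            try:
                req = core.input_queue.get(timeout=POLLING_TIMEOUT_S)
                core._handle_client_request(req)
                break
            except queue.Empty:
                # Break out the loop so we can log_stats in step().
                if core.log_stats:
                    break

        # 2) Handle any new client requests (drain step assumed).
        while not core.input_queue.empty():
            core._handle_client_request(core.input_queue.get_nowait())

        # 3) Step the engine core; even with no unfinished requests
        #    this returns EngineCoreOutputs carrying scheduler_stats.
        outputs = core.step()

        # 5) Put EngineCoreOutputs into the output queue.
        core.output_queue.put_nowait(outputs)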