From 319104aa358b91277808dc6f2f96c37ee80ece0c Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Wed, 8 May 2024 15:01:11 -0400
Subject: [PATCH] separate out sharded state loader

---
 examples/save_sharded_state.py             | 67 ----------------
 vllm/config.py                             |  1 -
 vllm/executor/distributed_gpu_executor.py  | 10 ---
 vllm/model_executor/model_loader/loader.py | 91 ----------------------
 vllm/worker/model_runner.py                | 14 ----
 vllm/worker/worker.py                      | 12 ---
 6 files changed, 195 deletions(-)
 delete mode 100644 examples/save_sharded_state.py

diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py
deleted file mode 100644
index a3a19d8c6226d..0000000000000
--- a/examples/save_sharded_state.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import argparse
-import dataclasses
-import os
-import shutil
-from pathlib import Path
-
-from vllm import LLM, EngineArgs
-
-"""
-Saves each worker's model state dict directly to a checkpoint, which enables a
-fast load path for large tensor-parallel models where each worker only needs to
-read its own shard rather than the entire checkpoint.
-
-Example usage:
-
-python save_sharded_state.py \
-    --model /path/to/load \
-    --quantization deepspeedfp \
-    --tensor-parallel-size 8 \
-    --output /path/to/save
-
-Then, the model can be loaded with
-
-llm = LLM(
-    model="/path/to/save",
-    load_format="sharded_state",
-    quantization="deepspeedfp",
-    tensor_parallel_size=8,
-)
-"""
-
-parser = argparse.ArgumentParser()
-EngineArgs.add_cli_args(parser)
-parser.add_argument("--output",
-                    "-o",
-                    required=True,
-                    type=str,
-                    help="path to output checkpoint")
-parser.add_argument("--pattern",
-                    type=str,
-                    help="string pattern of saved filenames")
-
-
-def main(args):
-    engine_args = EngineArgs.from_cli_args(args)
-    model_path = engine_args.model
-    if not Path(model_path).is_dir():
-        raise ValueError("model path must be a local directory")
-    # Create LLM instance from arguments
-    llm = LLM(**dataclasses.asdict(engine_args))
-    # Prepare output directory
-    Path(args.output).mkdir(exist_ok=True)
-    # Dump worker states to output directory
-    model_executor = llm.llm_engine.model_executor
-    model_executor.save_sharded_state(path=args.output,
-                                      pattern=args.pattern,
-                                      max_size=5 * 1024**3)
-    # Copy metadata files to output directory
-    for file in os.listdir(model_path):
-        if not any(
-                file.endswith(ext) for ext in (".bin", ".pt", ".safetensors")):
-            shutil.copy(f"{model_path}/{file}", args.output)
-
-
-if __name__ == "__main__":
-    args = parser.parse_args()
-    main(args)
diff --git a/vllm/config.py b/vllm/config.py
index 82165dedd43bf..6c65bbe247f84 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -455,7 +455,6 @@ class LoadFormat(str, enum.Enum):
     NPCACHE = "npcache"
     DUMMY = "dummy"
     TENSORIZER = "tensorizer"
-    SHARDED_STATE = "sharded_state"
 
 
 @dataclass
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
index 86f783cf6aff0..4c922ef63ee04 100644
--- a/vllm/executor/distributed_gpu_executor.py
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -77,16 +77,6 @@ def remove_lora(self, lora_id: int) -> bool:
     def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")
 
-    def save_sharded_state(self,
-                           path: str,
-                           pattern: Optional[str] = None,
-                           max_size: Optional[int] = None,
-                           ) -> None:
-        self._run_workers("save_sharded_state",
-                          path=path,
-                          pattern=pattern,
-                          max_size=max_size)
-
     @abstractmethod
     def _run_workers(
         self,
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 614204ce2bc28..bafa2de62e5df 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -347,94 +347,6 @@ def load_model(self, *, model_config: ModelConfig,
                                              vision_language_config)
 
 
-class ShardedStateLoader(BaseModelLoader):
-    """
-    Model loader that directly loads each worker's model state dict, which
-    enables a fast load path for large tensor-parallel models where each worker
-    only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_states.py` for creating a sharded checkpoint.
-    """
-
-    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
-
-    def __init__(self, load_config: LoadConfig):
-        super().__init__(load_config)
-        extra_config = ({} if load_config.model_loader_extra_config is None
-                        else load_config.model_loader_extra_config.copy())
-        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
-        if extra_config:
-            raise ValueError(f"Unexpected extra config keys for load format "
-                             f"{load_config.load_format}: "
-                             f"{load_config.model_loader_extra_config.keys()}")
-
-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   vision_language_config: Optional[VisionLanguageConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig) -> nn.Module:
-        from safetensors.torch import load_file
-
-        from vllm.distributed import get_tensor_model_parallel_rank
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, vision_language_config)
-            rank = get_tensor_model_parallel_rank()
-            pattern = os.path.join(
-                model_config.model,
-                self.pattern.format(rank=rank, part="*"),
-            )
-            filepaths = glob.glob(pattern)
-            if not filepaths:
-                # TODO: support un-sharded checkpoints too
-                raise ValueError(
-                    f"Could not find checkpoint files '{pattern}', only "
-                    f"pre-sharded checkpoints are currently supported!"
-                )
-            state_dict = dict(model.state_dict())
-            for path in filepaths:
-                for key, val in load_file(path).items():
-                    state_dict[key].copy_(val)
-                    state_dict.pop(key)
-            assert not state_dict
-        return model.eval()
-
-    @staticmethod
-    def save_model(
-        model: torch.nn.Module,
-        path: str,
-        pattern: Optional[str] = None,
-        max_size: Optional[int] = None,
-    ) -> None:
-        from safetensors.torch import save_file
-
-        from vllm.distributed import get_tensor_model_parallel_rank
-        if pattern is None:
-            pattern = ShardedStateLoader.DEFAULT_PATTERN
-        rank = get_tensor_model_parallel_rank()
-        part = 0
-        total_size = 0
-        state_dict: Dict[str, torch.Tensor] = {}
-        for name, tensor in model.state_dict().items():
-            param_size = tensor.nelement() * tensor.element_size()
-            if max_size is not None and total_size + param_size > max_size:
-                save_file(
-                    state_dict,
-                    os.path.join(path, pattern.format(rank=rank, part=part)),
-                )
-                part += 1
-                total_size = 0
-                state_dict = {}
-            state_dict[name] = tensor
-            total_size += param_size
-        if len(state_dict) > 0:
-            save_file(
-                state_dict,
-                os.path.join(path, pattern.format(rank=rank, part=part)),
-            )
-
-
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
@@ -447,7 +359,4 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.TENSORIZER:
         return TensorizerLoader(load_config)
 
-    if load_config.load_format == LoadFormat.SHARDED_STATE:
-        return ShardedStateLoader(load_config)
-
     return DefaultModelLoader(load_config)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 142b37cf3df9d..ab248596490f6 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -212,20 +212,6 @@ def load_model(self) -> None:
                            "but the KV cache data type is not FP8. "
                            "KV cache scaling factors will not be used.")
 
-    def save_sharded_state(
-        self,
-        path: str,
-        pattern: Optional[str] = None,
-        max_size: Optional[int] = None,
-    ) -> None:
-        from vllm.model_executor.model_loader.loader import ShardedStateLoader
-        ShardedStateLoader.save_model(
-            self.model,
-            path,
-            pattern=pattern,
-            max_size=max_size,
-        )
-
     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size
 
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 75a7c7ad7b3c0..4add36e94f723 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -117,18 +117,6 @@ def init_device(self) -> None:
     def load_model(self):
         self.model_runner.load_model()
 
-    def save_sharded_state(
-        self,
-        path: str,
-        pattern: Optional[str] = None,
-        max_size: Optional[int] = None,
-    ) -> None:
-        self.model_runner.save_sharded_state(
-            path,
-            pattern=pattern,
-            max_size=max_size,
-        )
-
     @torch.inference_mode()
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model to determine how many