
Commit 7d84bb0

Merge pull request #6 from sfc-gh-aqiao/arctic-prr
separate out sharded state loader
aurickq authored May 8, 2024
2 parents 462df46 + 319104a commit 7d84bb0
Showing 6 changed files with 0 additions and 195 deletions.
67 changes: 0 additions & 67 deletions examples/save_sharded_state.py

This file was deleted.
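The deleted example's contents are not reproduced on this page. For orientation, the ShardedStateLoader.save_model helper removed from vllm/model_executor/model_loader/loader.py below defines the on-disk layout: each tensor-parallel rank writes one or more safetensors files, capped at max_size bytes apiece and named by the pattern model-rank-{rank}-part-{part}.safetensors. The sketch that follows reproduces that size-capped sharding with plain safetensors for a single rank; the save_sharded helper, toy model, and output directory are illustrative assumptions, not the deleted script.

# Illustrative sketch only: mirrors the size-capped sharding of the removed
# ShardedStateLoader.save_model, but is not the deleted examples/save_sharded_state.py.
import os

import torch
from safetensors.torch import save_file


def save_sharded(state_dict, path, rank=0, max_size=1 << 20,
                 pattern="model-rank-{rank}-part-{part}.safetensors"):
    """Write state_dict into safetensors shards of at most max_size bytes each."""
    os.makedirs(path, exist_ok=True)
    part, total_size, shard = 0, 0, {}
    for name, tensor in state_dict.items():
        param_size = tensor.nelement() * tensor.element_size()
        if shard and total_size + param_size > max_size:
            # Current shard is full: flush it and start the next part.
            save_file(shard, os.path.join(path, pattern.format(rank=rank, part=part)))
            part, total_size, shard = part + 1, 0, {}
        shard[name] = tensor
        total_size += param_size
    if shard:
        save_file(shard, os.path.join(path, pattern.format(rank=rank, part=part)))


if __name__ == "__main__":
    model = torch.nn.Linear(256, 256)
    save_sharded(model.state_dict(), "/tmp/sharded_ckpt", max_size=64 * 1024)
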

1 change: 0 additions & 1 deletion vllm/config.py
@@ -455,7 +455,6 @@ class LoadFormat(str, enum.Enum):
    NPCACHE = "npcache"
    DUMMY = "dummy"
    TENSORIZER = "tensorizer"
    SHARDED_STATE = "sharded_state"


@dataclass
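Before this commit, the SHARDED_STATE value removed above was what routed a LoadConfig to the ShardedStateLoader, via the get_model_loader dispatch removed further down in vllm/model_executor/model_loader/loader.py. A rough sketch of that selection path follows; it would only run against a checkout that predates this commit. The load_format and model_loader_extra_config fields and the "pattern" key appear in this diff, while the keyword-style LoadConfig construction is an assumption.

# Hypothetical usage of the removed load path (requires a vLLM checkout that
# predates this commit); field names follow the diff, constructor style is assumed.
from vllm.config import LoadConfig, LoadFormat
from vllm.model_executor.model_loader.loader import get_model_loader

load_config = LoadConfig(
    load_format=LoadFormat.SHARDED_STATE,  # enum value removed in this hunk
    model_loader_extra_config={
        # Optional override of ShardedStateLoader.DEFAULT_PATTERN.
        "pattern": "model-rank-{rank}-part-{part}.safetensors",
    },
)
loader = get_model_loader(load_config)  # previously returned ShardedStateLoader
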
10 changes: 0 additions & 10 deletions vllm/executor/distributed_gpu_executor.py
@@ -77,16 +77,6 @@ def remove_lora(self, lora_id: int) -> bool:
    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None,
                           ) -> None:
        self._run_workers("save_sharded_state",
                          path=path,
                          pattern=pattern,
                          max_size=max_size)

    @abstractmethod
    def _run_workers(
        self,
91 changes: 0 additions & 91 deletions vllm/model_executor/model_loader/loader.py
@@ -347,94 +347,6 @@ def load_model(self, *, model_config: ModelConfig,
vision_language_config)


class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/save_sharded_states.py` for creating a sharded checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{load_config.model_loader_extra_config.keys()}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        from safetensors.torch import load_file

        from vllm.distributed import get_tensor_model_parallel_rank
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                model_config.model,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!"
                )
            state_dict = dict(model.state_dict())
            for path in filepaths:
                for key, val in load_file(path).items():
                    state_dict[key].copy_(val)
                    state_dict.pop(key)
            assert not state_dict
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank
        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part = 0
        total_size = 0
        state_dict: Dict[str, torch.Tensor] = {}
        for name, tensor in model.state_dict().items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                save_file(
                    state_dict,
                    os.path.join(path, pattern.format(rank=rank, part=part)),
                )
                part += 1
                total_size = 0
                state_dict = {}
            state_dict[name] = tensor
            total_size += param_size
        if len(state_dict) > 0:
            save_file(
                state_dict,
                os.path.join(path, pattern.format(rank=rank, part=part)),
            )


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

@@ -447,7 +359,4 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    return DefaultModelLoader(load_config)
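As a counterpart to the save-side sketch above, the snippet below mirrors the core of the removed ShardedStateLoader.load_model: glob the current rank's shard files and copy each tensor in place into an already-constructed module. The load_sharded helper, toy model, and directory are illustrative assumptions; the real loader also handles device placement and dtype defaults as shown in the hunk above.

# Illustrative counterpart to the removed ShardedStateLoader.load_model: read this
# rank's shards and copy them into an existing module's tensors in place.
import glob
import os

import torch
from safetensors.torch import load_file


def load_sharded(model, path, rank=0,
                 pattern="model-rank-{rank}-part-{part}.safetensors"):
    filepaths = glob.glob(os.path.join(path, pattern.format(rank=rank, part="*")))
    if not filepaths:
        raise ValueError(f"no shard files found for rank {rank} under {path}")
    remaining = dict(model.state_dict())  # tensors share storage with the model
    for filepath in filepaths:
        for key, value in load_file(filepath).items():
            remaining.pop(key).copy_(value)  # in-place copy updates the model
    assert not remaining, f"keys missing from checkpoint: {sorted(remaining)}"
    return model.eval()


if __name__ == "__main__":
    model = torch.nn.Linear(256, 256)
    load_sharded(model, "/tmp/sharded_ckpt")  # directory written by the save sketch
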
14 changes: 0 additions & 14 deletions vllm/worker/model_runner.py
@@ -212,20 +212,6 @@ def load_model(self) -> None:
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def set_block_size(self, block_size: int) -> None:
        self.block_size = block_size

12 changes: 0 additions & 12 deletions vllm/worker/worker.py
@@ -117,18 +117,6 @@ def init_device(self) -> None:
    def load_model(self):
        self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
