Skip to content

Commit

Permalink
Merge pull request vllm-project#8 from Snowflake-Labs/remove-dummy
Browse files Browse the repository at this point in the history
remove dummy path in arctic
  • Loading branch information
sfc-gh-aqiao authored May 1, 2024
2 parents ad00b54 + 6aa6de3 commit 15de0c2
Showing 1 changed file with 30 additions and 34 deletions.
64 changes: 30 additions & 34 deletions vllm/model_executor/models/arctic.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,53 +481,49 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

params_dict = dict(self.named_parameters())

if use_dummy:
logger.info("Using dummy weights. Skip loading weights.")
else:
logger.info(
"It takes ~10 minutes to load the weights. Please be patient.")
for name, loaded_weight in weights:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
logger.info(
"It takes ~10 minutes to load the weights. Please be patient.")
for name, loaded_weight in weights:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, shard_id in mlp_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, shard_id in mlp_params_mapping:
for param_name, weight_name, shard_id \
in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
weight_loader(param,
loaded_weight,
weight_name,
expert_id=shard_id)
break
else:
for param_name, weight_name, shard_id \
in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=shard_id)
break
else:
if name.endswith(
".bias") and name not in params_dict:
continue
param = params_dict[name]

weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if name.endswith(
".bias") and name not in params_dict:
continue
param = params_dict[name]

weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

0 comments on commit 15de0c2

Please sign in to comment.