# utils.py
import json
import os
import sys
from argparse import Namespace

import torch

from model.llama import NormalLinear


def load_hyperparam(default_args):
    """
    Load arguments from argparse and the JSON config file.
    Priority: default options < config file < command line args
    """
    with open(default_args.config_path, mode="r", encoding="utf-8") as f:
        config_args_dict = json.load(f)

    default_args_dict = vars(default_args)

    # Arguments explicitly passed on the command line (except local_rank)
    # override both the defaults and the config file.
    command_line_args_dict = {k: default_args_dict[k] for k in [
        a[2:] for a in sys.argv if (a[:2] == "--" and "local_rank" not in a)
    ]}
    default_args_dict.update(config_args_dict)
    default_args_dict.update(command_line_args_dict)
    args = Namespace(**default_args_dict)

    return args
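

# Usage sketch for load_hyperparam (illustrative only; the flag names below are
# assumptions, not taken from this repository):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--config_path", type=str, default="config.json")
#   parser.add_argument("--batch_size", type=int, default=32)
#   args = load_hyperparam(parser.parse_args())
#   # Values explicitly passed on the command line win over the config file,
#   # which in turn wins over the argparse defaults.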


def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
    # Load a PyTorch state_dict and convert old-format keys to the new format if needed.
    # Copy the state_dict so `_load_from_state_dict` can modify it.
    state_dict = torch.load(model_path, map_location="cpu")
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()

    # Rename keys from the checkpoint layout to the model's layout.
    state_dict["target.lm.weight"] = state_dict["target.lm.output_layer.weight"]
    del state_dict["target.lm.output_layer.weight"]
    state_dict["embedding.embedding.weight"] = state_dict["embedding.word.embedding.weight"]
    del state_dict["embedding.word.embedding.weight"]

    if metadata is not None:
        metadata["embedding.embedding"] = metadata["embedding.word.embedding"]
        metadata["target.lm"] = metadata["target.lm.output_layer"]
        if metadata.get("embedding.dropout", None) is not None:
            del metadata["embedding.dropout"]
        del metadata["embedding.word"]
        del metadata["embedding.word.embedding"]
        del metadata["target.lm.output_layer"]
        del metadata["target.lm.softmax"]
        del metadata["target.lm.criterion"]
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's
    # descendants, so we need to apply the function recursively.
    def load(module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of the module and its children start with `prefix`.
        # We can exit early if there are none in this state_dict.
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            import deepspeed
            # In sharded models, each shard has only part of the full state_dict,
            # so only gather parameters that are in the current state_dict.
            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
            if len(params_to_gather) > 0:
                # Because ZeRO-3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer,
                # loads them from the state_dict, and then re-partitions them.
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it can be collected by the GC earlier. Note that
    # `state_dict` is a copy of the argument, so it is safe to delete.
    del state_dict

    return model_to_load
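

# Usage sketch for _load_state_dict_into_model (illustrative only; the path is an
# assumption, and the model is assumed to be ZeRO-3 partitioned, e.g. built under
# deepspeed.zero.Init, with a checkpoint in the old key layout handled above):
#
#   model = _load_state_dict_into_model(model, "models/llama-7b.bin")
#   # Every rank must call this: GatheredParameters performs a collective, and
#   # only rank 0 copies the gathered weights from the state_dict.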


def convert_normal_parameter_to_int8(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
    import bitsandbytes as bnb
    modules_to_not_convert = ["lm"] if modules_to_not_convert is None else modules_to_not_convert
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)
        if len(list(module.children())) > 0:
            convert_normal_parameter_to_int8(module, threshold, modules_to_not_convert, current_key_name)
        if isinstance(module, bnb.nn.Linear8bitLt) and name not in modules_to_not_convert:
            # Check that the current key is not in `modules_to_not_convert`.
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                model._modules[name].weight = bnb.nn.Int8Params(
                    module.weight.data,
                    requires_grad=False,
                    has_fp16_weights=False,
                )
                # Force requires_grad to False to avoid unexpected errors.
                model._modules[name].requires_grad_(False)
        # Remove the last key for recursion.
        current_key_name.pop(-1)
    return model
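

# Usage sketch for convert_normal_parameter_to_int8 (illustrative only; assumes
# the model's linear layers are already bnb.nn.Linear8bitLt modules holding
# full-precision weight data, so only the weight parameters are re-wrapped here):
#
#   model = convert_normal_parameter_to_int8(model, modules_to_not_convert=["lm"])
#   model = model.to("cuda")  # with has_fp16_weights=False, Int8Params are
#                             # quantized when moved to the GPU.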


def load_model(model, model_path):
    if os.path.isdir(model_path):
        # Sharded checkpoint: the index file maps parameter names to shard files.
        index_filename = os.path.join(model_path, "pytorch_model.bin.index.json")
        with open(index_filename, "r") as f:
            index = json.loads(f.read())
        shard_filenames = sorted(set(index["weight_map"].values()))
        shard_filenames = [os.path.join(model_path, f) for f in shard_filenames]
        for shard_file in shard_filenames:
            shard_checkpoint = torch.load(shard_file, map_location="cpu")
            for name, parameter in model.named_parameters():
                if shard_checkpoint.get(name, None) is not None:
                    # Map the renamed output-layer and embedding weights back to
                    # their checkpoint keys; copy everything else by name.
                    if "target" in name:
                        parameter.data = shard_checkpoint["target.lm.output_layer.weight"]
                    elif "embedding" in name:
                        parameter.data = shard_checkpoint["embedding.word.embedding.weight"]
                    else:
                        parameter.data = shard_checkpoint[name]
                    parameter.requires_grad = False
            del shard_checkpoint
    else:
        # Single-file checkpoint.
        checkpoint = torch.load(model_path, map_location="cpu")
        for parameter_name, parameter in model.named_parameters():
            if "target" in parameter_name:
                parameter.data = checkpoint["target.lm.output_layer.weight"]
            elif "embedding" in parameter_name:
                parameter.data = checkpoint["embedding.word.embedding.weight"]
            else:
                parameter.data = checkpoint[parameter_name]
            parameter.requires_grad = False
        del checkpoint
    return model
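

# Usage sketch for load_model (illustrative only; the paths are assumptions):
#
#   model = load_model(model, "models/llama-7b/")     # sharded checkpoint directory
#   model = load_model(model, "models/llama-7b.bin")  # single-file checkpoint
#   # In both cases every loaded parameter is frozen (requires_grad = False).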