convert_model_into_blocks.py
import argparse
import collections
import torch
import os
import json
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input_model_path", type=str,
                    help="Input file path")
parser.add_argument("--output_model_path", type=str,
                    help="Output folder path")
parser.add_argument("--block_size", type=int, default=10,
                    help="Disk size (GB) of each block.")
args = parser.parse_args()
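
# Example invocation (a sketch; the paths below are placeholders, not files
# shipped with the script):
#   python convert_model_into_blocks.py --input_model_path models/model.bin \
#       --output_model_path models/blocks --block_size 10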
# Create the output folder and load the full model checkpoint.
os.makedirs(args.output_model_path, exist_ok=True)
input_model = torch.load(args.input_model_path)

# Each parameter is stored in 2 bytes (fp16), so a block of block_size GB
# holds block_size * 10^9 / 2 = block_size * 500,000,000 parameters.
byte_size = args.block_size * 500000000
param_count, file_count, filename_count = 0, 0, 0
index_dict = {"weight_map": {}}
state_dict = collections.OrderedDict()
filename = "tencentpretrain_model-0.bin"
# Walk through the checkpoint, accumulating tensors into the current shard
# and recording which shard each weight ends up in.
for k, v in input_model.items():
    state_dict[k] = v
    index_dict["weight_map"][k] = filename
    param_count += v.numel()
    file_count += v.numel()
    # Once the current shard exceeds the block size, flush it to disk
    # and start a new one.
    if file_count > byte_size:
        torch.save(state_dict, os.path.join(args.output_model_path, filename))
        state_dict = collections.OrderedDict()
        filename_count += 1
        filename = "tencentpretrain_model-" + str(filename_count) + ".bin"
        file_count = 0
# Save the last (possibly partial) shard.
if len(state_dict) > 0:
    torch.save(state_dict, os.path.join(args.output_model_path, filename))

# Write the index file mapping each weight to its shard, plus the total size in bytes.
index_dict["metadata"] = {"total_size": param_count * 2}
with open(os.path.join(args.output_model_path, "tencentpretrain_model.bin.index.json"), "w") as f:
    json.dump(index_dict, f)
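
The blocks written by this script can be reassembled by reading the index file and loading each shard in turn. The snippet below is a minimal sketch, not part of the original script; the folder name models/blocks is a placeholder for whatever --output_model_path was used.

import collections
import json
import os

import torch

output_model_path = "models/blocks"  # hypothetical path, adjust as needed
with open(os.path.join(output_model_path, "tencentpretrain_model.bin.index.json")) as f:
    index = json.load(f)

# Load each shard once and copy its tensors into a single merged state dict.
merged_state_dict = collections.OrderedDict()
for shard_name in sorted(set(index["weight_map"].values())):
    shard = torch.load(os.path.join(output_model_path, shard_name), map_location="cpu")
    merged_state_dict.update(shard)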