-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmodel.py
65 lines (49 loc) · 1.29 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import logging
from dataclasses import dataclass
@dataclass(kw_only=True)
class ModelArchitecture():
name: str
num_layers: int
@dataclass(kw_only=True)
class LLMArchitecture(ModelArchitecture):
    """Architecture of an LLM: extends the base with width information."""

    # Dimension of the hidden representation.
    hidden_size: int
    # Number of attention heads.
    num_heads: int
@dataclass(kw_only=True)
class ModelParallelism():
"""
Captures the different parallelisms of a Model.
"""
pipeline_parallelism: int
tensor_parallelism: int
@property
def num_processors(self):
"""
The number of GPUs required is the product of the parallelisms.
"""
return self.pipeline_parallelism * self.tensor_parallelism
@dataclass(kw_only=True)
class ModelSize():
"""
Captures the various sizes of a Model.
"""
weights: int
dtype_size: int
@property
def total_size(self):
return self.weights
@dataclass(kw_only=True)
class Model:
    """A model: an architecture together with its parallelism and size."""

    # Identifier of this model instance.
    name: str
    # Structural description of the model.
    architecture: ModelArchitecture
    # How the model is split across processors.
    parallelism: ModelParallelism
    # Memory-footprint description of the model.
    size: ModelSize

    @property
    def size_per_processor(self) -> float:
        """Share of the total model size held by each processor (even split)."""
        return self.size.total_size / self.parallelism.num_processors
@dataclass(kw_only=True)
class GenerativeLLM(Model):
    """
    Generative Large Language Model.
    NOTE: We currently don't capture embeddings, variable context lengths, etc.
    """

    # Context window length; defaults to 0 (presumably "unset" — verify
    # against callers).
    context_size: int = 0