models.py · 103 lines (76 loc) · 3.45 KB
"""A small transformer language model.

Despite the BigramModel name, the model below combines token and position
embeddings, a stack of transformer Blocks, and a linear output head, and is
trained on text from input.txt.
"""
from libs import *  # project-wide wildcard import; torch is also imported explicitly below for clarity
import torch
import torch.nn as nn
import torch.nn.functional as F

from mlp import MLP, Block
from loader import DataLoader
from attention import Scaled_DotProduct_Attention, MultiHeadAttention
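
# Interfaces assumed from the local modules above, inferred only from how they
# are used in this file (the modules themselves are not shown here):
#   DataLoader(context_length, batch_size, file)
#       .load_batch(split)  -> (x, y) index tensors of shape (batch, context_length)
#       .decode(indices)    -> str
#       .getDataLength()    -> int, used below as the vocabulary size
#   MLP(n_embd), Block(n_embd, context_length, num_heads) and
#   MultiHeadAttention(n_embd, context_length, num_heads, head_size) are
#   nn.Module subclasses operating on (batch, time, n_embd) tensors.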


class BigramModel(nn.Module):
    """Transformer language model; the BigramModel name is kept from the original file."""

    def __init__(self, vocab_size, n_embd):
        super().__init__()
        num_heads = 4
        batch_size = 32
        self.eval_iters = 200
        self.context_length = 10
        self.linear = nn.Linear(n_embd, vocab_size)  # projection from embeddings to vocabulary logits
        self.mlp = MLP(n_embd)  # kept from the original code; not used in forward()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(self.context_length, n_embd)
        self.process = DataLoader(self.context_length, batch_size, file='input.txt')
        # Kept from the original code; not used in forward(), since the Blocks below carry their own attention.
        self.attention = MultiHeadAttention(n_embd, self.context_length, num_heads,
                                            head_size=n_embd // num_heads)
        self.blocks = nn.Sequential(
            Block(n_embd, self.context_length, num_heads),
            Block(n_embd, self.context_length, num_heads),
            Block(n_embd, self.context_length, num_heads),
        )

    def generate(self, idx, max_new_tokens):
        """Autoregressively append max_new_tokens sampled tokens to idx (shape (B, T))."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.context_length:]  # crop the context to the last context_length tokens
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # keep only the logits for the last position
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # sample the next token index
            idx = torch.cat((idx, idx_next), dim=1)  # append it and continue
        return idx

    @torch.no_grad()  # evaluation only, so gradient tracking is disabled
    def estimate_loss(self, model):
        """Average the loss over eval_iters batches for the train and val splits."""
        out = {}
        model.eval()
        for split in ['train', 'val']:
            losses = torch.zeros(self.eval_iters)
            for k in range(self.eval_iters):
                x, y = self.process.load_batch(split)
                logits, loss = self(x, y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        model.train()
        return out

    def trainer(self, model, max_iters=5000, eval_interval=100):
        """Basic AdamW training loop; reports train/val loss every eval_interval steps."""
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        for iter in range(max_iters):
            if iter % eval_interval == 0:
                losses = self.estimate_loss(model)
                print(f'step {iter} : train loss {losses["train"]:.4f} val loss {losses["val"]:.4f}')
            x, y = self.process.load_batch('train')
            logits, loss = self(x, y)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
        # Sample from the trained model, starting from a single zero token.
        print(self.process.decode(self.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=1000)[0].tolist()))

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tokn_embds = self.token_embedding_table(idx)  # (B, T, n_embd)
        position_embds = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tokn_embds + position_embds  # token and position embeddings are summed, not concatenated
        x = self.blocks(x)  # let each token process the information it has gathered about the other tokens
        logits = self.linear(x)  # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
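

# Illustrative sketch only: what a Block from mlp.py is assumed to look like,
# given the Block(n_embd, context_length, num_heads) calls above. The real Block
# lives in mlp.py and may differ (e.g. dropout, different norm placement); this
# class is defined for reference and is not used anywhere in this file.
class ReferenceBlock(nn.Module):
    def __init__(self, n_embd, context_length, num_heads):
        super().__init__()
        # Communication step: tokens attend to the tokens before them.
        self.attn = MultiHeadAttention(n_embd, context_length, num_heads,
                                       head_size=n_embd // num_heads)
        # Computation step: a per-token feed-forward network.
        self.ffwd = MLP(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Pre-norm residual connections around attention and feed-forward.
        x = x + self.attn(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x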


# Standalone script: build a loader, run one untrained forward pass as a sanity
# check, then train. Note that this loader uses a context length of 8, while the
# model trains on batches from its own internal DataLoader (context_length = 10).
process = DataLoader(8, 32, 'input.txt')
inputs, targets = process.load_batch('train')
vocab_size = process.getDataLength()  # used as the vocabulary size for the embedding and output layers
bigram = BigramModel(vocab_size=vocab_size, n_embd=32)
logits, loss = bigram(inputs, targets)  # sanity-check forward pass before training
bigram.trainer(bigram)
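
# Optional follow-up, sketched as an example rather than taken from the original
# script: persist the trained weights with the standard torch.save / torch.load
# API and sample from the reloaded model. The checkpoint file name is illustrative.
torch.save(bigram.state_dict(), 'bigram_checkpoint.pt')

restored = BigramModel(vocab_size=vocab_size, n_embd=32)
restored.load_state_dict(torch.load('bigram_checkpoint.pt'))
restored.eval()
seed = torch.zeros((1, 1), dtype=torch.long)  # start from a single zero token, as trainer() does
print(process.decode(restored.generate(seed, max_new_tokens=200)[0].tolist()))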