-
Notifications
You must be signed in to change notification settings - Fork 6
/
train_eval.py
57 lines (47 loc) · 1.52 KB
/
train_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
'''
Description: Training and evaluation loops (train_op, eval_op) for PyTorch classification models.
Version: 1.0
Author: ZhangHongYu
Date: 2022-03-26 18:50:20
LastEditors: ZhangHongYu
LastEditTime: 2022-03-26 19:52:39
'''
import torch
# Target device for all tensors moved by train_op / eval_op: CUDA when
# PyTorch reports it as available, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def train_op(model, loader, optimizer, epochs=1):
    """Train ``model`` on ``loader`` for ``epochs`` passes with cross-entropy loss.

    Args:
        model: the network to train; it is switched to training mode here.
        loader: an iterable of ``(x, y)`` batches. Both tensors are moved to
            the module-level ``device`` before the forward pass.
        optimizer: an optimizer exposing ``zero_grad()`` and ``step()``,
            presumably built over ``model``'s parameters (verify at caller).
        epochs: number of full passes over ``loader``.

    Returns:
        The average loss of the final epoch: each batch's mean-reduced loss
        is weighted by its batch size ``y.shape[0]``, then divided by the
        total batch size seen. Returns -1 when the final epoch saw no
        samples (empty ``loader`` or ``epochs <= 0``). This is the same
        sentinel ``eval_op`` returns for clients that have no data.
    """
    model.train()
    # CrossEntropyLoss applies log-softmax itself and uses reduction='mean'
    # by default; build it once rather than once per batch.
    criterion = torch.nn.CrossEntropyLoss()
    # Initialised here too so that epochs <= 0 does not hit an unbound name.
    running_loss, samples = 0.0, 0
    for _ in range(epochs):
        # Statistics are reset each epoch, so only the last epoch is reported.
        running_loss, samples = 0.0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            # For the LSTM model: model(x) is (128, 100, 80) and y is
            # (128, 80), i.e. class scores sit along dim 1.
            loss = criterion(model(x), y)
            running_loss += loss.item() * y.shape[0]
            samples += y.shape[0]
            loss.backward()
            optimizer.step()
    # Some clients may have no samples; avoid ZeroDivisionError.
    if samples == 0:
        return -1
    return running_loss / samples
def eval_op(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Predictions are the argmax over dim 1 of the model output, compared
    element-wise with the labels. ``y`` may be 1-D (one label per sample)
    or 2-D (one label per position, as in the next-character prediction
    task). Every label element counts as one sample.

    The model is switched to evaluation mode for the pass, so layers such as
    dropout and batch norm use their inference behaviour and batch-norm
    running statistics are not updated. The caller's previous train/eval
    mode is restored afterwards, even if the pass raises.

    Args:
        model: the network to evaluate; its output is expected to hold class
            scores along dim 1.
        loader: an iterable of ``(x, y)`` batches. Both tensors are moved to
            the module-level ``device``.

    Returns:
        ``correct / total_labels`` as a float in ``[0, 1]``, or -1 when
        ``loader`` yields no labels at all.
    """
    was_training = model.training
    model.eval()
    samples, correct = 0, 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                # e.g. predicted: (128, 80) for the next-character task.
                predicted = model(x).argmax(dim=1)
                # numel() counts every label element: batch for 1-D labels,
                # batch * seq for 2-D labels, and any higher rank as well.
                samples += y.numel()
                correct += (predicted == y).sum().item()
    finally:
        model.train(was_training)
    # Some clients may have generated zero samples.
    if samples == 0:
        return -1
    return correct / samples