-
Notifications
You must be signed in to change notification settings - Fork 98
/
Copy pathhpo.py
79 lines (66 loc) · 2.02 KB
/
hpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#TODO: Import your dependencies.
#For instance, below are some dependencies you might need if you are using Pytorch
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
def test(model, test_loader):
'''
TODO: Complete this function that can take a model and a
testing data loader and will get the test accuray/loss of the model
Remember to include any debugging/profiling hooks that you might need
'''
pass
def train(model, train_loader, criterion, optimizer, epochs=1, device=None):
    """Train *model* in place on *train_loader* and return it.

    Args:
        model: The ``torch.nn.Module`` to optimise.
        train_loader: Iterable of ``(inputs, labels)`` tensor batches.
        criterion: Loss function applied as ``criterion(model(inputs), labels)``.
        optimizer: A ``torch.optim`` optimizer over (some of) *model*'s parameters.
        epochs: Number of full passes over *train_loader* (default 1).
        device: Device to train on. Defaults to the device of the model's first
            parameter, or CPU if the model has no parameters.

    Returns:
        The same *model* object, so callers can write ``model = train(...)``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        seen = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the reported epoch loss is a per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            seen += inputs.size(0)
        if seen:
            print(f"Epoch {epoch}: Train loss: {running_loss / seen:.4f}")
    return model
def net(num_classes=133):
    """Build a transfer-learning classifier on an ImageNet-pretrained ResNet-50.

    The pretrained backbone is frozen and its final fully-connected layer is
    replaced by a fresh, trainable ``nn.Linear`` producing *num_classes* logits,
    so only the new head is updated during training.

    Args:
        num_classes: Number of output classes. NOTE(review): 133 is a
            placeholder default; set it to the number of classes in your dataset.

    Returns:
        The ``torch.nn.Module`` ready for training.
    """
    try:
        # torchvision >= 0.13 exposes the multi-weight API.
        weights = models.ResNet50_Weights.DEFAULT
    except AttributeError:
        # Older torchvision only supports the legacy ``pretrained`` flag.
        model = models.resnet50(pretrained=True)
    else:
        model = models.resnet50(weights=weights)

    for param in model.parameters():
        param.requires_grad = False

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
def create_data_loaders(data, batch_size):
    """Wrap *data* in a shuffling ``DataLoader``.

    Args:
        data: Either a map-style ``torch.utils.data.Dataset``, used as-is, or
            a path to a directory laid out for ``torchvision.datasets.ImageFolder``
            (one sub-directory per class). A local directory is how a SageMaker
            data channel copied from S3 is exposed to the script.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` over *data* with ``shuffle=True``.
    """
    if not isinstance(data, Dataset):
        # Resize to 224x224 and normalise with the ImageNet statistics expected
        # by torchvision's pretrained classification models.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        data = torchvision.datasets.ImageFolder(data, transform=transform)
    return DataLoader(data, batch_size=batch_size, shuffle=True)
def main(args):
    """Train, evaluate and save a model configured by the parsed CLI *args*.

    Expects the attributes produced by ``parse_args``: ``train_dir``,
    ``test_dir``, ``model_dir``, ``batch_size``, ``lr`` and ``epochs``.
    """
    model = net()

    loss_criterion = nn.CrossEntropyLoss()
    # Optimise only the parameters net() left trainable, so a frozen
    # pretrained backbone stays frozen.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr=args.lr)

    train_loader = create_data_loaders(args.train_dir, args.batch_size)
    test_loader = create_data_loaders(args.test_dir, args.batch_size)

    # One train() call per epoch keeps this on train()'s 4-argument signature.
    for _ in range(args.epochs):
        model = train(model, train_loader, loss_criterion, optimizer)

    test(model, test_loader)

    os.makedirs(args.model_dir, exist_ok=True)
    # Save the state_dict rather than pickling the whole module; load it back
    # with ``model.load_state_dict(torch.load(...))`` on a model built by net().
    torch.save(model.state_dict(), os.path.join(args.model_dir, "model.pth"))


def parse_args(argv=None):
    """Parse hyperparameters and data/model locations from *argv*.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``batch_size``, ``lr``, ``epochs``,
        ``train_dir``, ``test_dir`` and ``model_dir``.

    Directory options default to the SageMaker ``SM_CHANNEL_TRAIN``,
    ``SM_CHANNEL_TEST`` and ``SM_MODEL_DIR`` environment variables when set.
    NOTE(review): the channel variables assume ``estimator.fit()`` is given
    channels named "train" and "test"; adjust if yours differ.
    """
    parser = argparse.ArgumentParser(description="Train and evaluate an image classifier.")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="samples per training/test batch")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate for the Adam optimizer")
    parser.add_argument("--epochs", type=int, default=2,
                        help="number of passes over the training data")
    parser.add_argument("--train-dir", default=os.environ.get("SM_CHANNEL_TRAIN"),
                        required="SM_CHANNEL_TRAIN" not in os.environ,
                        help="ImageFolder-style directory of training images")
    parser.add_argument("--test-dir", default=os.environ.get("SM_CHANNEL_TEST"),
                        required="SM_CHANNEL_TEST" not in os.environ,
                        help="ImageFolder-style directory of test images")
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "model"),
                        help="directory the trained weights are written to")
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_args())