-
Notifications
You must be signed in to change notification settings - Fork 14
/
main.py
57 lines (44 loc) · 1.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# © Copyright IBM Corp. 2019
import pickle
from deltaencoder import DeltaEncoder
########### Load Data ################
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load data/mIN.pkl if it comes from a trusted source.
# The `with` block guarantees the file handle is closed once loading finishes.
with open('data/mIN.pkl', 'rb') as data_file:
    (features_train, labels_train, features_test, labels_test,
     episodes_1shot, episodes_5shot) = pickle.load(data_file)
# features_train/features_test are features extracted from some backbone (resnet18); they are np arrays of size (N, D), where N is the number of samples and D is the feature dimension.
# labels_train/labels_test are one-hot GT labels of size (N, C), where C is the number of classes (it can differ between the train and test sets).
# episodes_*shot are supplied to reproduce the paper's results; size = (num_episodes, num_classes, num_shots, D).
def _make_args(num_shots, num_epoch):
    """Return the DeltaEncoder config for a `num_shots`-shot mIN experiment.

    The 1-shot and 5-shot runs share every hyper-parameter except the shot
    count and the number of training epochs, so those two are the only
    parameters. A fresh dict (with fresh list values) is built on each call,
    so one experiment's config can never alias the other's.

    Args:
        num_shots: number of labelled examples per class in each episode.
        num_epoch: number of training epochs.

    Returns:
        dict: the `args` mapping passed to DeltaEncoder.
    """
    return {'data_set': 'mIN',
            'num_shots': num_shots,
            'num_epoch': num_epoch,
            'nb_val_loop': 10,
            'learning_rate': 1e-5,
            'drop_out_rate': 0.5,
            'drop_out_rate_input': 0.0,
            'batch_size': 128,
            'noise_size': 16,
            'nb_img': 1024,
            'num_ways': 5,
            'encoder_size': [8192],
            'decoder_size': [8192],
            'opt_type': 'adam',
            }


######### 1-shot Experiment #########
args = _make_args(num_shots=1, num_epoch=6)
model = DeltaEncoder(args, features_train, labels_train, features_test, labels_test, episodes_1shot)
model.train(verbose=False)

######### 5-shot Experiment #########
# The 5-shot run trains for twice as many epochs as the 1-shot run.
args = _make_args(num_shots=5, num_epoch=12)
model = DeltaEncoder(args, features_train, labels_train, features_test, labels_test, episodes_5shot)
model.train(verbose=False)