-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path citation.py
37 lines (26 loc) · 1.1 KB
/
citation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
from fold_data import c_dataset_loader
from config import opts
from train import train
from test import test
def _build_parser():
    """Return the command-line parser for the citation-dataset experiment.

    Options (flag, type, default):
      --dataset str   'cora'
      --lr      float 5e-3
      --dropout float 0.6
      --epoch   int   800
      --seed    int   100
    """
    cli = argparse.ArgumentParser()
    # (flag, value type, default, help text) -- registered in this order.
    option_specs = (
        ("--dataset", str, 'cora', "training dataset name ['cora','citeseer','pubmed']"),
        ("--lr", float, 5e-3, "learning rate"),
        ("--dropout", float, 0.6, "dropout, default=0.6"),
        ("--epoch", int, 800, "training epoch"),
        ("--seed", int, 100, "random initializing seed"),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli


# Module-level parser, kept as a public name for the entry point and any importers.
parser = _build_parser()
def main(argv=None):
    """Run one train-then-test cycle on a citation dataset.

    Parses the command line with the module-level ``parser``. The parsed
    values are copied onto a fresh ``opts()`` config. The dataset is loaded
    with ``c_dataset_loader``. ``train`` is then called with the training and
    validation indices, and ``test`` with the test indices.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching a plain
            command-line run.
    """
    args = parser.parse_args(argv)

    # Overlay the CLI values on the default experiment config.
    opt = opts()
    opt.dataset = args.dataset
    opt.lr = args.lr
    opt.drop_out = args.dropout  # the config field is spelled ``drop_out``
    opt.epoch = args.epoch
    opt.np_random_seed = args.seed  # --seed is stored under this config name

    # Load data. ``data_path`` comes from the opts() defaults; the parser has
    # no flag for it.
    data_loader = c_dataset_loader(opt.dataset, opt.data_path)
    adj, feat, label, idx_train, idx_val, idx_test = data_loader.process_data()

    # Train model
    train(adj, feat, label, idx_train, idx_val, opt)
    # Test
    test(adj, feat, label, idx_test, opt)


if __name__ == '__main__':
    main()