-
Notifications
You must be signed in to change notification settings - Fork 683
/
template_jobs.py
126 lines (107 loc) · 3.25 KB
/
template_jobs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from examples import *
def batch_atari():
    """Run the single Atari experiment selected by the ``--i`` command-line index.

    A flat list of ``[algo, kwargs]`` configurations is built, ordered
    game -> run -> algorithm. Each entry passes ``game``, ``run`` and
    ``remark=algo.__name__`` to the algorithm. Only the entry at position
    ``cf.i`` is executed, so each process runs exactly one configuration.
    ``--j`` is parsed for command-line compatibility but is not used here.

    Raises:
        IndexError: if ``--i`` is negative or not smaller than the number
            of generated configurations. Previously a negative index
            silently selected a configuration counted from the end.

    The process exits via ``sys.exit()`` after the selected experiment returns.
    """
    # The bare ``exit()`` builtin is injected by the ``site`` module for
    # interactive use only; ``sys.exit`` is the supported way to stop a script.
    import sys

    cf = Config()
    cf.add_argument('--i', type=int, default=0)
    cf.add_argument('--j', type=int, default=0)  # not used by this function
    cf.merge()  # presumably parses/merges the command-line arguments into cf

    games = [
        'BreakoutNoFrameskip-v4',
        # 'AlienNoFrameskip-v4',
        # 'DemonAttackNoFrameskip-v4',
        # 'MsPacmanNoFrameskip-v4'
    ]

    algos = [
        dqn_pixel,
        quantile_regression_dqn_pixel,
        categorical_dqn_pixel,
        rainbow_pixel,
        a2c_pixel,
        n_step_dqn_pixel,
        option_critic_pixel,
        ppo_pixel,
    ]

    # Order matters: the --i index maps onto this game -> run -> algo ordering.
    params = []
    for game in games:
        for r in range(1):
            for algo in algos:
                params.append([algo, dict(game=game, run=r, remark=algo.__name__)])
            # Alternative sweeps kept for reference:
            # for n_step in [1, 2, 3]:
            #     for double_q in [True, False]:
            #         params.extend([
            #             [dqn_pixel,
            #              dict(game=game, run=r, n_step=n_step, replay_cls=PrioritizedReplay, double_q=double_q,
            #                   remark=dqn_pixel.__name__)],
            #             [rainbow_pixel,
            #              dict(game=game, run=r, n_step=n_step, noisy_linear=False, remark=rainbow_pixel.__name__)]
            #         ])
            # params.append(
            #     [categorical_dqn_pixel, dict(game=game, run=r, remark=categorical_dqn_pixel.__name__)]),
            # params.append([dqn_pixel, dict(game=game, run=r, remark=dqn_pixel.__name__)])

    # Reject negative indices explicitly: Python would otherwise wrap them
    # around and run an unintended configuration.
    if not 0 <= cf.i < len(params):
        raise IndexError('--i=%d is out of range: expected 0 <= i < %d' % (cf.i, len(params)))
    algo, param = params[cf.i]
    algo(**param)
    sys.exit()
def batch_mujoco(use_dm_control=False):
    """Run the single continuous-control experiment selected by ``--i``.

    A flat list of ``[algo, kwargs]`` configurations is built, ordered
    game -> algorithm -> run, with 5 runs per pair. Each entry carries
    ``game`` and ``run``; the selected algorithm is also called with
    ``remark=algo.__name__``. Games whose name contains 'Humanoid' only get
    ``ppo_continuous``. Every other game gets ``ppo_continuous``,
    ``ddpg_continuous`` and ``td3_continuous``. ``--j`` is parsed for
    command-line compatibility but is not used here.

    Args:
        use_dm_control: if True, sweep the ``dm-*`` task list instead of the
            gym ``*-v2`` environments. The default (False) reproduces the
            previous behaviour, where the ``dm-*`` list was defined but
            immediately overwritten.

    Raises:
        IndexError: if ``--i`` is negative or not smaller than the number
            of generated configurations.

    The process exits via ``sys.exit()`` after the selected experiment returns.
    """
    # The bare ``exit()`` builtin is injected by the ``site`` module for
    # interactive use only; ``sys.exit`` is the supported way to stop a script.
    import sys

    cf = Config()
    cf.add_argument('--i', type=int, default=0)
    cf.add_argument('--j', type=int, default=0)  # not used by this function
    cf.merge()  # presumably parses/merges the command-line arguments into cf

    dm_control_games = [
        'dm-acrobot-swingup',
        'dm-acrobot-swingup_sparse',
        'dm-ball_in_cup-catch',
        'dm-cartpole-swingup',
        'dm-cartpole-swingup_sparse',
        'dm-cartpole-balance',
        'dm-cartpole-balance_sparse',
        'dm-cheetah-run',
        'dm-finger-turn_hard',
        'dm-finger-spin',
        'dm-finger-turn_easy',
        'dm-fish-upright',
        'dm-fish-swim',
        'dm-hopper-stand',
        'dm-hopper-hop',
        'dm-humanoid-stand',
        'dm-humanoid-walk',
        'dm-humanoid-run',
        'dm-manipulator-bring_ball',
        'dm-pendulum-swingup',
        'dm-point_mass-easy',
        'dm-reacher-easy',
        'dm-reacher-hard',
        'dm-swimmer-swimmer15',
        'dm-swimmer-swimmer6',
        'dm-walker-stand',
        'dm-walker-walk',
        'dm-walker-run',
    ]

    mujoco_games = [
        'HalfCheetah-v2',
        'Walker2d-v2',
        'Swimmer-v2',
        'Hopper-v2',
        'Reacher-v2',
        'Ant-v2',
        'Humanoid-v2',
        'HumanoidStandup-v2',
    ]

    games = dm_control_games if use_dm_control else mujoco_games

    # Order matters: the --i index maps onto this game -> algo -> run ordering.
    params = []
    for game in games:
        # NOTE: case-sensitive match, so the lowercase 'dm-humanoid-*' tasks
        # are NOT restricted to PPO.
        if 'Humanoid' in game:
            algos = [ppo_continuous]
        else:
            algos = [ppo_continuous, ddpg_continuous, td3_continuous]
        for algo in algos:
            for r in range(5):
                params.append([algo, dict(game=game, run=r)])

    # Reject negative indices explicitly: Python would otherwise wrap them
    # around and run an unintended configuration.
    if not 0 <= cf.i < len(params):
        raise IndexError('--i=%d is out of range: expected 0 <= i < %d' % (cf.i, len(params)))
    algo, param = params[cf.i]
    algo(**param, remark=algo.__name__)
    sys.exit()
if __name__ == '__main__':
    # Create the output directories used by the experiments before anything runs.
    for directory in ('log', 'data'):
        mkdir(directory)
    random_seed()

    # To run the Atari sweep instead, enable these two lines and disable the
    # MuJoCo pair below:
    # select_device(0)
    # batch_atari()

    select_device(-1)  # presumably -1 selects the CPU; confirm against select_device
    batch_mujoco()