main.py
import gym
import numpy as np

from agent import Agent
from utils import plot_learning_curve

ENV_NAME = "CartPole-v0"
ENV_PARAMS = {}                    # extra kwargs forwarded to gym.make
EPISODES = 500
LEARNING_RATE = 3e-3
GAMMA = 0.99                       # reward discount factor
EPSILON_START = 1.0                # initial exploration rate
EPSILON_DECAY = 9e-5               # exploration decay rate
EPSILON_END = 1e-2                 # exploration floor
BATCH_SIZE = 64
MAX_MEMORY_SIZE = 100000           # replay buffer capacity
RENDER_TRAINING = False            # render episodes while training (slower)
RENDER_FPS = 60
ONLINE_TO_TARGET_FREQUENCY = 1000  # steps between online-to-target network syncs
SAVE_FREQUENCY = 10000             # steps between automatic checkpoints
# The network architecture is defined in dqn.py
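
# agent.py is not shown on this page; from the calls below, the Agent class is
# assumed to expose roughly this interface (names inferred, not confirmed):
#   Agent(n_actions, input_dims, lr, gamma, epsilon, eps_end, eps_dec,
#         batch_size, max_mem_size, online_to_target_frequency, save_frequency)
#   Agent.load() -> Agent                      restore a checkpoint; raises if none exists
#   agent.choose_action(obs)                   pick an action (epsilon-greedy, judging by agent.epsilon)
#   agent.store_transition(s, a, r, s_, done)  push one step into the replay buffer
#   agent.learn()                              one DQN update from a sampled minibatch
#   agent.save()                               write a checkpoint to disk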
if __name__ == "__main__":
    if RENDER_TRAINING:
        env = gym.make(ENV_NAME, render_mode="human", **ENV_PARAMS)
        env.metadata["render_fps"] = RENDER_FPS
    else:
        env = gym.make(ENV_NAME, **ENV_PARAMS)

    # Resume from a saved checkpoint if one exists; otherwise start fresh
    try:
        agent = Agent.load()
        print("[INFO] Resuming previous training")
    except Exception:
        agent = Agent(
            n_actions=env.action_space.n,
            input_dims=env.observation_space.shape,
            lr=LEARNING_RATE,
            gamma=GAMMA,
            epsilon=EPSILON_START,
            eps_end=EPSILON_END,
            eps_dec=EPSILON_DECAY,
            batch_size=BATCH_SIZE,
            max_mem_size=MAX_MEMORY_SIZE,
            online_to_target_frequency=ONLINE_TO_TARGET_FREQUENCY,
            save_frequency=SAVE_FREQUENCY,
        )
    scores, eps_history = [], []
    for i in range(EPISODES):
        score = 0
        done = False
        observation, _ = env.reset()
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            score += reward
            agent.store_transition(observation, action, reward, observation_, done)
            agent.learn()
            observation = observation_
        scores.append(score)
        eps_history.append(agent.epsilon)
        avg_score = np.mean(scores[-100:])  # moving average over the last 100 episodes
        print(
            f"Episode [{i}], Score [{score:.2f}], Avg Score [{avg_score:.2f}], Epsilon [{agent.epsilon:.3f}]",
        )
    agent.save()
    plot_learning_curve(scores, eps_history)
    # Watch the trained agent with rendering enabled; runs until interrupted
    env = gym.make(ENV_NAME, render_mode="human", **ENV_PARAMS)
    while True:
        score = 0
        done = False
        observation, _ = env.reset()
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            score += reward
            observation = observation_
        print(f"Test score [{score}]")