learn.py

"""Agent-specific learning interactions."""
import functools
from typing import Callable
from safe_grid_agents.common import utils as ut


def whiler(f: Callable) -> Callable:
    """Evaluate the agent-specific learn function `f` inside of a generic while loop."""

    @functools.wraps(f)
    def stepbystep(agent, env, env_state, history, args):
        done = False
        eval_next = False
        while not done:
            env_state, history = f(agent, env, env_state, history, args)
            done = env_state[2]  # env_state is a (state, reward, done, info) tuple
            history["t"] += 1
        # Episode finished: record metrics and decide whether to evaluate next
        history = ut.track_metrics(history, env)
        if history["episode"] % args.eval_every == args.eval_every - 1:
            eval_next = True
        return env_state, history, eval_next

    return stepbystep
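
# Sketch of the decorated call pattern (illustrative assumption; the actual
# driver lives elsewhere in the package). Each learn function below advances
# one environment transition per call, and `whiler` repeats it until the
# episode ends, so a caller steps through one full episode per invocation:
#
#     env_state = (state, reward, done, info)  # as returned by env.step
#     env_state, history, eval_next = tabq_learn(agent, env, env_state, history, args)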


@whiler
def dqn_learn(agent, env, env_state, history, args):
    """Learning loop for DeepQAgent."""
    state, reward, done, info = env_state
    t = history["t"]

    # Act
    action = agent.act_explore(state)
    successor, reward, done, info = env.step(action)

    # Learn
    if args.cheat:
        reward = info["hidden_reward"]
    # In case the agent is drunk, use the actual action they took
    try:
        action = info["extra_observations"]["actual_actions"]
    except KeyError:
        pass
    history = agent.learn(state, action, reward, successor, done, history)

    # Modify exploration
    eps = agent.update_epsilon()
    history["writer"].add_scalar("Train/epsilon", eps, t)

    # Sync target and policy networks
    if t % args.sync_every == args.sync_every - 1:
        agent.sync_target_Q()

    return (successor, reward, done, info), history


@whiler
def tabq_learn(agent, env, env_state, history, args):
    """Learning loop for TabularQAgent."""
    state, reward, done, info = env_state
    t = history["t"]

    # Act
    action = agent.act_explore(state)
    successor, reward, done, info = env.step(action)

    # Learn
    if args.cheat:
        reward = info["hidden_reward"]
    # In case the agent is drunk, use the actual action they took
    try:
        action = info["extra_observations"]["actual_actions"]
    except KeyError:
        pass
    agent.learn(state, action, reward, successor)

    # Modify exploration
    eps = agent.update_epsilon()
    history["writer"].add_scalar("Train/epsilon", eps, t)

    return (successor, reward, done, info), history
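
# Note the asymmetry with dqn_learn above: the tabular update is applied in
# place (agent.learn returns nothing here), and there is no target network,
# so no `sync_every` bookkeeping is needed.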


def ppo_learn(agent, env, env_state, history, args):
    """Learning loop for PPOAgent."""
    eval_next = False

    # Act
    rollout = agent.gather_rollout(env, env_state, history, args)

    # Learn
    history = agent.learn(*rollout, history, args)

    # Sync old and current policy
    agent.sync()

    # Check for evaluating next
    if history["episode"] % args.eval_every == 0 and history["episode"] > 0:
        eval_next = True

    return env_state, history, eval_next


# Map from agent name to its learning loop.
LEARN_MAP = {
    "deep-q": dqn_learn,
    "tabular-q": tabq_learn,
    "ppo-mlp": ppo_learn,
    "ppo-cnn": ppo_learn,
    "ppo-crmdp": ppo_learn,
}
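
# Illustrative training driver (an assumption about the surrounding package,
# not code from this module): look up the loop by agent name and run it once
# per episode, evaluating whenever `eval_next` is set. `args.agent`,
# `args.episodes`, and `evaluate` are hypothetical names used only here.
#
#     learn = LEARN_MAP[args.agent]
#     for _ in range(args.episodes):
#         env_state, history, eval_next = learn(agent, env, env_state, history, args)
#         if eval_next:
#             evaluate(agent, env, history, args)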