
Commit

Merge pull request chainer#7049 from toslunar/bp-6624-example-reinforce-dtype

[backport] Fix reinforcement_learning example to work with default dtype
takagi authored May 7, 2019
2 parents 7e07f4d + 360bf9d commit d3c55d9
Showing 2 changed files with 16 additions and 12 deletions.
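
The change replaces hard-coded np.float32 arrays with Chainer's configurable default dtype. As a minimal sketch of the mechanism (the FP16 setting below is an illustrative assumption, not part of this commit):

    import chainer
    import numpy as np

    # chainer.get_dtype() resolves the default dtype from chainer.config.dtype,
    # which can also be set via the CHAINER_DTYPE environment variable.
    chainer.config.dtype = np.float16   # assumption: run the examples in FP16 mode
    dtype = chainer.get_dtype()
    obs = np.zeros((1, 3), dtype=dtype)
    print(obs.dtype)                    # float16 instead of a hard-coded float32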
16 changes: 9 additions & 7 deletions examples/reinforcement_learning/ddpg_pendulum.py
@@ -69,7 +69,8 @@ def forward(self, x):

 def get_action(policy, obs):
     """Get an action by evaluating a given policy."""
-    obs = policy.xp.asarray(obs[None], dtype=np.float32)
+    dtype = chainer.get_dtype()
+    obs = policy.xp.asarray(obs[None], dtype=dtype)
     with chainer.no_backprop_mode():
         action = policy(obs).array[0]
     return chainer.backends.cuda.to_cpu(action)
@@ -78,12 +79,13 @@ def get_action(policy, obs):
 def update(Q, target_Q, policy, target_policy, opt_Q, opt_policy,
            samples, gamma=0.99):
     """Update a Q-function and a policy."""
+    dtype = chainer.get_dtype()
     xp = Q.xp
-    obs = xp.asarray([sample[0] for sample in samples], dtype=np.float32)
-    action = xp.asarray([sample[1] for sample in samples], dtype=np.float32)
-    reward = xp.asarray([sample[2] for sample in samples], dtype=np.float32)
-    done = xp.asarray([sample[3] for sample in samples], dtype=np.float32)
-    obs_next = xp.asarray([sample[4] for sample in samples], dtype=np.float32)
+    obs = xp.asarray([sample[0] for sample in samples], dtype=dtype)
+    action = xp.asarray([sample[1] for sample in samples], dtype=dtype)
+    reward = xp.asarray([sample[2] for sample in samples], dtype=dtype)
+    done = xp.asarray([sample[3] for sample in samples], dtype=dtype)
+    obs_next = xp.asarray([sample[4] for sample in samples], dtype=dtype)

     def update_Q():
         # Predicted values: Q(s,a)
@@ -194,7 +196,7 @@ def main():
     policy.to_device(device)
     target_Q = copy.deepcopy(Q)
     target_policy = copy.deepcopy(policy)
-    opt_Q = optimizers.Adam()
+    opt_Q = optimizers.Adam(eps=1e-5)  # Use larger eps in case of FP16 mode
     opt_Q.setup(Q)
     opt_policy = optimizers.Adam(alpha=1e-4)
     opt_policy.setup(policy)
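
The Adam change above guards against FP16 underflow: the default eps of 1e-8 is smaller than the smallest float16 subnormal (about 6e-8) and rounds to zero, leaving the division in Adam's update unprotected. A minimal NumPy check of that rounding (not part of the commit):

    import numpy as np

    # 1e-8 underflows to zero in float16; 1e-5 is still representable,
    # so it keeps the denominator of Adam's update away from zero.
    print(np.float16(1e-8))   # 0.0
    print(np.float16(1e-5))   # ~1e-05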
12 changes: 7 additions & 5 deletions examples/reinforcement_learning/dqn_cartpole.py
@@ -38,7 +38,8 @@ def forward(self, x):

 def get_greedy_action(Q, obs):
     """Get a greedy action wrt a given Q-function."""
-    obs = Q.xp.asarray(obs[None], dtype=np.float32)
+    dtype = chainer.get_dtype()
+    obs = Q.xp.asarray(obs[None], dtype=dtype)
     with chainer.no_backprop_mode():
         q = Q(obs).array[0]
     return int(q.argmax())
@@ -50,12 +51,13 @@ def mean_clipped_loss(y, t):

 def update(Q, target_Q, opt, samples, gamma=0.99, target_type='double_dqn'):
     """Update a Q-function with given samples and a target Q-function."""
+    dtype = chainer.get_dtype()
     xp = Q.xp
-    obs = xp.asarray([sample[0] for sample in samples], dtype=np.float32)
+    obs = xp.asarray([sample[0] for sample in samples], dtype=dtype)
     action = xp.asarray([sample[1] for sample in samples], dtype=np.int32)
-    reward = xp.asarray([sample[2] for sample in samples], dtype=np.float32)
-    done = xp.asarray([sample[3] for sample in samples], dtype=np.float32)
-    obs_next = xp.asarray([sample[4] for sample in samples], dtype=np.float32)
+    reward = xp.asarray([sample[2] for sample in samples], dtype=dtype)
+    done = xp.asarray([sample[3] for sample in samples], dtype=dtype)
+    obs_next = xp.asarray([sample[4] for sample in samples], dtype=dtype)
     # Predicted values: Q(s,a)
     y = F.select_item(Q(obs), action)
     # Target values: r + gamma * max_b Q(s',b)
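
Switching the sample arrays to the configured dtype keeps them consistent with the network's outputs when the target values are formed. A NumPy-only sketch of that target computation (variable names are illustrative, float16 is an assumed configuration):

    import numpy as np

    dtype = np.float16                                 # stand-in for chainer.get_dtype()
    reward = np.asarray([1.0, 0.0], dtype=dtype)
    done = np.asarray([0.0, 1.0], dtype=dtype)
    q_next_max = np.asarray([2.5, 3.0], dtype=dtype)   # max_b Q(s', b)
    gamma = 0.99
    # r + gamma * max_b Q(s', b), masked by the episode-done flag
    target = reward + gamma * (1.0 - done) * q_next_max
    print(target.dtype)                                # float16: everything stays in one dtype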
