policy.py (forked from EnJiang/USV-Game)
import random

import numpy as np
from rl.policy import EpsGreedyQPolicy


class Policy(object):
    """Base class for policies that act on observations from a world."""

    def __init__(self, world):
        self.world = world

    def action(self, obs):
        raise NotImplementedError()


class TestPolicy(Policy):
    """Debug policy: returns a random entry from the observation."""

    def action(self, obs):
        return random.choice(obs)


class EpsGreedyQPolicyWithGuide(EpsGreedyQPolicy):
    """Epsilon-greedy policy whose exploration steps are partly guided by
    an A*-style action suggested by the world's first policy agent."""

    def __init__(self, world, eps=.1):
        super(EpsGreedyQPolicyWithGuide, self).__init__()
        self.eps = eps
        self.world = world

    def select_action(self, q_values):
        """Return the selected action.

        # Arguments
            q_values (np.ndarray): List of the estimations of Q for each action

        # Returns
            Selected action
        """
        assert q_values.ndim == 1
        nb_actions = q_values.shape[0]

        if np.random.uniform() < self.eps:
            # Exploration step: try to fetch the guide action from the world.
            try:
                a_star_action = self.world.policy_agents[0].finda()
                a_star_action_i = self.world.action_space.index(a_star_action)
            except Exception:
                a_star_action_i = None
            # 20% of exploration steps (or any step where no guide action is
            # available) fall back to a uniformly random action.
            if random.random() < 0.2 or a_star_action_i is None:
                action = np.random.randint(0, nb_actions)
            else:
                action = a_star_action_i
        else:
            # Exploitation step: greedy action with the highest Q-value.
            action = np.argmax(q_values)
        return action
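

# --- Minimal usage sketch (not part of the original file) -------------------
# The stub classes below are hypothetical stand-ins for the project's real
# world/agent objects; they exist only to show how select_action() mixes the
# A*-guided action into epsilon-greedy exploration.
if __name__ == "__main__":
    class _StubAgent(object):
        def finda(self):
            # Pretend the A* search suggests the action "up".
            return "up"

    class _StubWorld(object):
        action_space = ["up", "down", "left", "right"]
        policy_agents = [_StubAgent()]

    policy = EpsGreedyQPolicyWithGuide(_StubWorld(), eps=0.5)
    q_values = np.array([0.1, 0.9, 0.3, 0.2])
    # With eps=0.5 roughly half of the calls explore; most exploration steps
    # should return index 0 ("up", the guided action), the rest index 1
    # (the greedy argmax of q_values) or a uniformly random index.
    print([policy.select_action(q_values) for _ in range(10)])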