-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy patheval_policy.py
63 lines (42 loc) · 1.91 KB
/
eval_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Utilities for evaluating a trained policy by rolling it out in a
simulation environment and logging per-episode summaries.
"""
#--------------------------------- Rollout -----------------------------------#
# running the simulation to evaluate the policy
def rollout(policy, env, render):
    """
    Run evaluation episodes forever, yielding one summary per episode.

    Parameters
    ----------
    policy : callable
        Maps an observation to an action as a torch tensor; the result is
        detached and converted to numpy before being passed to ``env.step``.
    env : object
        Gymnasium-style environment exposing ``reset()``, ``step()`` and
        ``render()`` (5-tuple ``step`` API: obs, reward, terminated,
        truncated, info).
    render : bool
        When True, call ``env.render()`` every step.

    Yields
    ------
    tuple[int, float]
        ``(ep_len, ep_rew)`` — the length and total reward of each episode.
    """
    while True:
        # reset the environment for a new episode
        obs, _ = env.reset()
        done = False
        ep_len = 0   # length of the episode (steps taken)
        ep_rew = 0   # running sum of rewards in the episode
        while not done:
            ep_len += 1
            # render the environment if requested
            if render:
                env.render()
            # action generated by the policy (torch tensor -> numpy array)
            action = policy(obs).detach().numpy()
            # BUG FIX: the original discarded the `truncated` flag, so a
            # time-limit-truncated episode never set `done` and the inner
            # loop spun forever. An episode ends on termination OR truncation.
            obs, rew, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            ep_rew += rew
        # hand back this episode's length and total reward
        yield ep_len, ep_rew
#------------------------------ logger summary -------------------------------#
def _log_summary(ep_len, ep_ret, ep_num):
print(flush = True)
print(f"-------------------- episode #{ep_num+1} --------------------", flush=True)
print(f"Total cost: {str(ep_ret)}", flush=True)
print(f"----------------------------------------------------", flush=True)
print(flush=True)
#---------------------------- Policy evaluation ------------------------------#
def eval_policy(policy, env, render=False):
    """
    Evaluate *policy* in *env* indefinitely, logging a summary per episode.

    Parameters
    ----------
    policy : callable
        Policy under evaluation; passed through to ``rollout``.
    env : object
        Gymnasium-style environment; passed through to ``rollout``.
    render : bool, optional
        Forwarded to ``rollout`` to enable per-step rendering.
    """
    episode_index = 0
    for episode_length, episode_return in rollout(policy, env, render):
        _log_summary(ep_len=episode_length, ep_ret=episode_return,
                     ep_num=episode_index)
        episode_index += 1
#-----------------------------------------------------------------------------#