Skip to content

Commit

Permalink
Eval every & TransitionBoatRace-v0 fixes (#61)
Browse files Browse the repository at this point in the history
* fixing evaluation check (cf. #60)

* make LR required

* add boat transition env to parsing code, clean the file

* moving ppo crmdp to proper place

* fixing shapes

* assume pip > 1.18.1, rm dependency_links
  • Loading branch information
jvmncs authored Feb 23, 2019
1 parent 6350d1b commit 305f84d
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 232 deletions.
2 changes: 0 additions & 2 deletions safe_grid_agents/common/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from safe_grid_agents.common.agents.value import TabularQAgent, DeepQAgent
from safe_grid_agents.common.agents.policy_mlp import PPOMLPAgent
from safe_grid_agents.common.agents.policy_cnn import PPOCNNAgent
from safe_grid_agents.common.agents.policy_crmdp import PPOCRMDPAgent

__all__ = [
"RandomAgent",
Expand All @@ -12,5 +11,4 @@
"DeepQAgent",
"PPOMLPAgent",
"PPOCNNAgent",
"PPOCRMDPAgent",
]
19 changes: 12 additions & 7 deletions safe_grid_agents/common/agents/policy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, env, args) -> None:
self.action_n = env.action_space.n
self.discount = args.discount
self.board_shape = env.observation_space.shape
self.n_input = self.board_shape[0] * self.board_shape[1]
self.n_input = self.board_shape[0] * self.board_shape[1] * self.board_shape[2]
self.device = args.device
self.log_gradients = args.log_gradients

Expand Down Expand Up @@ -64,15 +64,20 @@ def policy(self, state) -> Categorical:

def learn(self, states, actions, rewards, returns, history, args) -> History:
states = torch.as_tensor(states, dtype=torch.float, device=self.device)
actions = torch.as_tensor(actions, dtype=torch.long, device=self.device)
returns = torch.as_tensor(returns, dtype=torch.float, device=self.device)
rlsz = self.rollouts * states.size(1)
states = states.reshape(rlsz, states.shape[2], states.shape[3], states.shape[4])
actions = torch.as_tensor(
actions, dtype=torch.long, device=self.device
).reshape(rlsz, -1)
returns = torch.as_tensor(
returns, dtype=torch.float, device=self.device
).reshape(rlsz, -1)

for epoch in range(self.epochs):
rlsz = self.rollouts * states.size(1)
ixs = torch.randint(rlsz, size=(self.batch_size,), dtype=torch.long)
s = states.reshape(rlsz, states.shape[2], states.shape[3])[ixs]
a = actions.reshape(rlsz, -1)[ixs].reshape(-1)
r = returns.reshape(rlsz, -1)[ixs].reshape(-1)
s = states[ixs]
a = actions[ixs].reshape(-1)
r = returns[ixs].reshape(-1)

prepolicy, state_values = self(s)
state_values = state_values.reshape(-1)
Expand Down
23 changes: 15 additions & 8 deletions safe_grid_agents/common/agents/policy_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ def __init__(self, env, args) -> None:

def build_ac(self) -> None:
"""Build the fused actor-critic architecture."""
in_channels = self.board_shape[0]
first = nn.Sequential(
torch.nn.Conv2d(1, self.n_channels, kernel_size=3, stride=1, padding=1),
torch.nn.Conv2d(
in_channels, self.n_channels, kernel_size=3, stride=1, padding=1
),
nn.ReLU(),
)
hidden = nn.Sequential(
Expand All @@ -36,6 +39,9 @@ def build_ac(self) -> None:
)
)
self.network = nn.Sequential(first, hidden)
self.bottleneck = nn.Conv2d(
in_channels, self.n_channels, kernel_size=1, stride=1
)

self.actor_cnn = nn.Sequential(
torch.nn.Conv2d(
Expand All @@ -44,7 +50,8 @@ def build_ac(self) -> None:
nn.ReLU(),
)
self.actor_linear = nn.Linear(
self.n_input * (self.n_channels), int(self.action_n)
self.n_channels * self.board_shape[1] * self.board_shape[2],
int(self.action_n),
)

self.critic_cnn = nn.Sequential(
Expand All @@ -53,15 +60,15 @@ def build_ac(self) -> None:
),
nn.ReLU(),
)
self.critic_linear = nn.Linear(self.n_input * (self.n_channels), 1)
self.critic_linear = nn.Linear(
self.n_channels * self.board_shape[1] * self.board_shape[2], 1
)

def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
if len(x.shape) == 2:
x = x.reshape(1, 1, x.shape[0], x.shape[1])
elif len(x.shape) == 3:
x = x.unsqueeze(1)
if len(x.shape) == 3:
x = x.unsqueeze(0)

convolutions = self.network(x) + x
convolutions = self.network(x) + self.bottleneck(x)

actor = self.actor_cnn(convolutions)
actor = actor.reshape(actor.shape[0], -1)
Expand Down
202 changes: 0 additions & 202 deletions safe_grid_agents/common/agents/policy_crmdp.py

This file was deleted.

2 changes: 1 addition & 1 deletion safe_grid_agents/common/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def ppo_learn(agent, env, env_state, history, args):
agent.sync()

# Check for evaluating next
if history["episode"] % args.eval_every == args.eval_every - 1:
if history["episode"] % args.eval_every == 0 and history["episode"] > 0:
eval_next = True

return env_state, history, eval_next
Expand Down
1 change: 1 addition & 0 deletions safe_grid_agents/parsing/agent_parser_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tabular-q:
lr: &learnrate
alias: l
type: float
required: true
help: "Learning rate (required)"
epsilon: &epsilon
alias: e
Expand Down
1 change: 1 addition & 0 deletions safe_grid_agents/parsing/env_parser_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ tomato-crmdp:
whisky:
corners:
way:
trans-boat:
11 changes: 3 additions & 8 deletions safe_grid_agents/parsing/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,17 @@

import yaml

from ai_safety_gridworlds.environments.boat_race import BoatRaceEnvironment
from ai_safety_gridworlds.environments.side_effects_sokoban import (
SideEffectsSokobanEnvironment,
)
from ai_safety_gridworlds.environments.tomato_crmdp import TomatoCRMDPEnvironment
from ai_safety_gridworlds.environments.tomato_watering import TomatoWateringEnvironment
from safe_grid_agents.common.agents import (
DeepQAgent,
PPOCNNAgent,
PPOCRMDPAgent,
PPOMLPAgent,
RandomAgent,
SingleActionAgent,
TabularQAgent,
)
from safe_grid_agents.parsing import agent_config, core_config, env_config
from safe_grid_agents.ssrl import TabularSSQAgent
from safe_grid_agents.spiky.agents import PPOCRMDPAgent
from safe_grid_agents.ssrl.agents import TabularSSQAgent


# Mapping of envs/agents to Python classes
Expand All @@ -39,6 +33,7 @@
"whisky": "WhiskyGold-v0",
"corners": "ToyGridworldCorners-v0",
"way": "ToyGridworldOnTheWay-v0",
"trans-boat": "TransitionBoatRace-v0",
}

AGENT_MAP = { # Dict[AgentName, Agent]
Expand Down
1 change: 0 additions & 1 deletion safe_grid_agents/ssrl/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from safe_grid_agents.ssrl.agents import TabularSSQAgent
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@
"rl "
"reinforcement learning "
),
install_requires=["safe-grid-gym", "pyyaml", "moviepy", "tensorboardX<=1.5", "ray"],
dependency_links=[
"https://github.com/david-lindner/safe-grid-gym/tarball/master#egg=safe-grid-gym-0.2"
install_requires=[
"safe-grid-gym @ git+https://github.com/david-lindner/safe-grid-gym.git",
"pyyaml",
"moviepy",
"tensorboardX<=1.5",
"ray",
],
packages=setuptools.find_packages(),
zip_safe=True,
Expand Down

0 comments on commit 305f84d

Please sign in to comment.