Skip to content

Commit

Permalink
Attempt import fix
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertTLange committed Apr 21, 2024
1 parent 3283f26 commit b928046
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
### [v0.1.7] - [05/2024]

##### Added

- Implements fully `pmap`-compatible implementations of `OpenES`, `PGPE`, `Sep_CMA_ES` and `SNES`.

### [v0.1.6] - [03/2024]

##### Added
Expand Down
29 changes: 11 additions & 18 deletions evosax/problems/control_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import jax.numpy as jnp
from typing import Optional
import chex
import gymnax


class GymnaxFitness(object):
Expand All @@ -19,12 +20,12 @@ def __init__(
self.num_rollouts = num_rollouts
self.test = test

try:
import gymnax
except ImportError:
raise ImportError(
"You need to install `gymnax` to use its fitness rollouts."
)
# try:

# except ImportError:
# raise ImportError(
# "You need to install `gymnax` to use its fitness rollouts."
# )

# Define the RL environment & replace default parameters if desired
self.env, self.env_params = gymnax.make(env_name, **env_kwargs)
Expand Down Expand Up @@ -68,14 +69,10 @@ def set_apply_fn(self, network_apply, carry_init=None):
else:
self.rollout_map = self.rollout_pop

def rollout_pmap(
self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree
):
def rollout_pmap(self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree):
"""Parallelize rollout across devices. Split keys/reshape correctly."""
keys_pmap = jnp.tile(rng_input, (self.n_devices, 1, 1))
rew_dev, steps_dev = jax.pmap(self.rollout_pop)(
keys_pmap, policy_params
)
rew_dev, steps_dev = jax.pmap(self.rollout_pop)(keys_pmap, policy_params)
rew_re = rew_dev.reshape(-1, self.num_rollouts)
steps_re = steps_dev.reshape(-1, self.num_rollouts)
return rew_re, steps_re
Expand All @@ -88,9 +85,7 @@ def rollout(self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree):
self.total_env_steps += masks.sum()
return scores

def rollout_ffw(
self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree
):
def rollout_ffw(self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree):
"""Rollout an episode with lax.scan."""
# Reset the environment
rng_reset, rng_episode = jax.random.split(rng_input)
Expand Down Expand Up @@ -136,9 +131,7 @@ def policy_step(state_input, tmp):
cum_return = carry_out[-2].squeeze()
return cum_return, jnp.array(ep_mask)

def rollout_rnn(
self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree
):
def rollout_rnn(self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree):
"""Rollout a jitted episode with lax.scan."""
# Reset the environment
rng, rng_reset = jax.random.split(rng_input)
Expand Down

0 comments on commit b928046

Please sign in to comment.