v0.2.0: Policy, Value Functions, and more Algorithms (#39)
* Improve support for vectorized environments.
* Remove some hacks.
* Add vectorized experience replay, many small improvements.
* Fix lint.
* Fix runner tests and Makefile.
* Bring fixes from Jumping (#40)
* Update parameters for NatureFeatures.
* Update TRPO for non-contiguous tensors.
* More fixes.
* Clean fix for change of memory layout in PT 1.6.0.
* Add mean and mode to TanhNormal distribution.
* Add device and device_env to Torch wrapper.
* Fix device when flattening a replay.
* Remove pdb from logger.
* ExpReplay: implement getstate and setstate.
* Add test for replay.flatten(), faster implementation.
* Fix replay test.
* Update reward normalizer.
* More fixes for RewardNormalizer.
* Reward normalizer handles vectorized envs.
* Accelerate creation of Transitions.
* Cosmetics.
* Add wandb to pybullet.
* Add n-step sampling to experience replay + tests.
* Add nn.Policy (needs docs) and Normal, Categorical distributions (need docs).
* Add line_search to trpo.
* Add algorithm arguments, update LinearValue lstsq.
* Add device to Transition fields.
* Do not assume default device in Torch wrapper.
* Add cherry.nn.MLP.
* Remove gym.wrappers.Monitor.
* Add Closer, make compatible with gym 0.23.0.
* Add DrQ and SAC updates.
* Fix unit tests.
* Fix Transition.device for vectorized envs.
* Update algorithms.
* tests: Fix MemorizeDigits registration.
* tests: Fix actor-critic integration indentation.
* Add DrQv2.
* Add initial ActionValue API.
* Fix robotics init, Twin action value.
* Add nn.Lambda.
* Remove std reduction in DrQv2.
* Update docs to MkDocs Material.
* docs: Add cherry.models.md.
* Add cherry.nn.md.
* Add some docs.
* Finish converting to new docs.
* Add docs for nn.init.
* Bump version to 0.2.0.
* Monkey-patch old wrappers.
* Update logo link.
* Update logo path.
* Fix linting.
* Slightly update README.md, and add StateValue abstract class.
* Add TD3 example on MuJoCo.
* Update offpolicy MuJoCo example.
* Update action-value API.
* Add PPO update.
* Update MuJoCo offpolicy training.
* Update algorithms docs.
* Algorithm arguments are mappings too.
* Reformat.
* Add initial DMC example.
* Minor clean up.
* DrQ on DMC now running.
* Add DMC sweeps.
* Remove TODO.
* Fix wrong error raised in Policy forward.
* Clean Makefile, add gym requirements for dev.
* Clean up DMC example.
* Use log_alpha provided to SAC.
* Update A2C, PPO, TRPO docs.
* Docs for DDPG.
* Write TD3 docs.
* Update docs for DrQ, SAC, TRPO, DDPG, and A2C.
* Write DrQv2 and PPO docs.
* Remove RandomShiftsAug from algorithms docs.
* Fix non-learnable log_alpha in DrQ.
* Detach log_alpha when non-learnable in DrQ.
* Update changelog.
* Fix linting.
* Fix CI.
* Change ci95 factor to 1.96 instead of 2.0.
* Add DMC sweep results.
* Update DMC plots.
* Update DMC examples README.md with results.
* Remove individual DMC figures.
* Update README.md: initial changes to the examples in README.
* Update README.md
* Update first part of README.md
* Add experience replay example.
* Update README.md
* Write final algorithm example.
* Edit algorithm readme example.
* Edit readme.
* Fix CI.
* Fix gym version in CI.
* Fix gym version in CI.
* Update readme and CI.
* CI: update setuptools version.
* CI: change gym install command.
* CI: try another gym version.
* Update python_unittest.yaml
* Fix more tests.
* Fix more tests.
* Fixes encoding in setup.py (#42)
* Fix wrong linear scaling in ActionSpaceScaler (#46)
* Add DrQ and SAC updates.

---------

Co-authored-by: 3DAlgoLab <[email protected]>
Co-authored-by: gwwo <[email protected]>
1 parent 18a1269 · commit f4164a5
Showing 133 changed files with 5,797 additions and 1,775 deletions.
@@ -0,0 +1,51 @@

name: Testing

on: [push, pull_request, create]

jobs:
  tests:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest]
        python: ['3.7', '3.8']
        pytorch: ['1.4.0', '1.5.0', '1.6.0', '1.7.0']
        include:
          - pytorch: '1.4.0'
            torchvision: '0.5.0'
          - pytorch: '1.5.0'
            torchvision: '0.6.0'
          - pytorch: '1.6.0'
            torchvision: '0.7.0'
          - pytorch: '1.7.0'
            torchvision: '0.8.0'
        exclude:
          - pytorch: '1.3.0'
            python: '3.8'

    steps:
      - name: Clone Repository
        uses: actions/checkout@v2
        with:
          ref: ${{ github.ref }}
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Install Dependencies
        run: |
          python3 --version
          python3 -m pip install -U pip setuptools
          # pip3 install --install-option="--no-cython-compile" cython
          pip3 install torch==${{ matrix.pytorch }}
          pip3 install torchvision==${{ matrix.torchvision }}
          pip3 install chardet==3.0.4  # can be removed when fixed upstream: https://github.com/aio-libs/aiohttp/issues/5366
          pip3 install requests numpy gsutil tqdm pygame
          pip3 install gym==0.23.1
          make dev
      - name: Lint Code
        run: make lint
      - name: Run Tests
        run: make tests
@@ -122,3 +122,6 @@ examples/bsuite/figs/**

# ignore .vscode
.vscode
.all_objects.cache
wandb/**
**/wandb/**
@@ -1,4 +1,4 @@
<p align="center"><img src="http://cherry-rl.net/assets/images/cherry_full.png" height="150px" /></p>
<p align="center"><img src="http://cherry-rl.net/assets/images/cherry_full.png" height="128px" /></p>

--------------------------------------------------------------------------------
@@ -13,63 +13,109 @@ So if you don't like a specific tool, you don’t need to use it.

**Features**

* Pythonic and low-level interface *à la* PyTorch.
* Support for tabular (!) and function approximation algorithms.
* Various OpenAI Gym environment wrappers.
* Helper functions for popular algorithms (e.g. A2C, DDPG, TRPO, PPO, SAC).
* Logging, visualization, and debugging tools.
* Painless and efficient distributed training on CPUs and GPUs.
* Unit, integration, and regression tested, continuously integrated.

Cherry extends PyTorch with only a handful of new core concepts.
* PyTorch modules for reinforcement learning:
    * [`cherry.nn.Policy`](http://cherry-rl.net/api/cherry.nn/#cherry.nn.policy.Policy): base class for $\pi(a \mid s)$ policies.
    * [`cherry.nn.ActionValue`](http://cherry-rl.net/api/cherry.nn/#cherry.nn.action_value.ActionValue): base class for $Q(s, a)$ action-value functions (a short sketch follows this list).
* Data structures for reinforcement learning compatible with PyTorch:
    * [`cherry.Transition`](http://cherry-rl.net/api/cherry/#cherry.experience_replay.Transition): namedtuple to store $(s_t, a_t, r_t, s_{t+1})$ transitions (and more).
    * [`cherry.ExperienceReplay`](http://cherry-rl.net/api/cherry/#cherry.experience_replay.ExperienceReplay): a list-like buffer to store and sample transitions.
* Low-level interface *à la* PyTorch to write and debug your algorithms.
    * [`cherry.td.*`](http://cherry-rl.net/api/cherry.td/) and [`cherry.pg.*`](http://cherry-rl.net/api/cherry.pg/): temporal difference and policy gradient utilities.
    * [`cherry.algorithms.*`](http://cherry-rl.net/api/cherry.algorithms/): helper functions for popular algorithms ([PPO](http://cherry-rl.net/api/cherry.algorithms/#cherry.algorithms.ppo.PPO), [TD3](http://cherry-rl.net/api/cherry.algorithms/#cherry.algorithms.td3.TD3), [DrQ](http://cherry-rl.net/api/cherry.algorithms/#cherry.algorithms.drq.DrQ), and [more](http://cherry-rl.net/api/cherry.algorithms/)).
    * [`cherry.debug.*`](http://cherry-rl.net/api/cherry.debug/) and [`cherry.plot.*`](http://cherry-rl.net/api/cherry.plot/): logging, visualization, and debugging tools.
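`ActionValue` is the one core module not demonstrated in the snippets below, so here is a minimal, purely illustrative sketch. It assumes that `ActionValue` subclasses behave like `torch.nn.Module`s and implement `forward(state, action)` returning one Q-value per state-action pair; the `QNetwork` name, hidden size, and dummy tensors are made up for this example and are not the library's documented API.

~~~python
import torch
import cherry


class QNetwork(cherry.nn.ActionValue):  # assumed: ActionValue subclasses are torch.nn.Modules

    def __init__(self, state_size, action_size, hidden_size=256):
        super(QNetwork, self).__init__()
        # a plain MLP over the concatenated state-action pair
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(state_size + action_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def forward(self, state, action):
        # assumed signature: map a batch of (state, action) pairs to one Q-value each
        return self.mlp(torch.cat([state, action], dim=-1))


qf = QNetwork(state_size=11, action_size=3)
states, actions = torch.randn(32, 11), torch.randn(32, 3)  # dummy batch for illustration
q_values = qf(states, actions)  # shape: (32, 1)
~~~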

To learn more about the tools and philosophy behind cherry, check out our [Getting Started tutorial](http://cherry-rl.net/tutorials/getting_started/).

## Example
## Overview and Examples

The following snippet showcases some of the tools offered by cherry.
The following snippet showcases a few of the tools offered by cherry.
Many more high-quality examples are available in the [examples/](./examples/) folder.

~~~python
import cherry as ch
#### Defining a [`cherry.nn.Policy`](http://cherry-rl.net/api/cherry.nn/#cherry.nn.policy.Policy)

# Wrap environments
env = gym.make('CartPole-v0')
env = ch.envs.Logger(env, interval=1000)
env = ch.envs.Torch(env)
~~~python
class VisionPolicy(cherry.nn.Policy):  # inherits from torch.nn.Module

    def __init__(self, feature_extractor, actor):
        super(VisionPolicy, self).__init__()
        self.feature_extractor = feature_extractor
        self.actor = actor

    def forward(self, obs):
        mean = self.actor(self.feature_extractor(obs))
        std = 0.1 * torch.ones_like(mean)
        return cherry.distributions.TanhNormal(mean, std)  # policies always return a distribution

policy = VisionPolicy(MyResnetExtractor(), MyMLPActor())
action = policy.act(obs)  # sampled from policy's distribution
deterministic_action = policy.act(obs, deterministic=True)  # distribution's mode
action_distribution = policy(obs)  # work with the policy's distribution
~~~

policy = PolicyNet()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
replay = ch.ExperienceReplay()  # Manage transitions
#### Building a [`cherry.ExperienceReplay`](http://cherry-rl.net/api/cherry/#cherry.experience_replay.ExperienceReplay) of [`cherry.Transition`](http://cherry-rl.net/api/cherry/#cherry.experience_replay.Transition)

for step in range(1000):
    state = env.reset()
    while True:
        mass = Categorical(policy(state))
        action = mass.sample()
        log_prob = mass.log_prob(action)
        next_state, reward, done, _ = env.step(action)

        # Build the ExperienceReplay
        replay.append(state, action, reward, next_state, done, log_prob=log_prob)
        if done:
            break
        else:
            state = next_state

    # Discounting and normalizing rewards
    rewards = ch.td.discount(0.99, replay.reward(), replay.done())
    rewards = ch.normalize(rewards)

    loss = -th.sum(replay.log_prob() * rewards)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    replay.empty()
~~~python
# building the replay
replay = cherry.ExperienceReplay()
state = env.reset()
for t in range(1000):
    action = policy.act(state)
    next_state, reward, done, info = env.step(action)
    replay.append(state, action, reward, next_state, done)
    state = next_state

# manipulating the replay
replay = replay[-256:]  # indexes like a list
batch = replay.sample(32, contiguous=True)  # sample transitions into a replay
batch = batch.to('cuda')  # move replay to device
for transition in reversed(batch):  # iterate over a replay
    transition.reward *= 0.99

# get all states, actions, and rewards as PyTorch tensors
reinforce_loss = -torch.sum(policy(batch.state()).log_prob(batch.action()) * batch.reward())
~~~

Many more high-quality examples are available in the [examples/](./examples/) folder.
#### Designing algorithms with [`cherry.td`](http://cherry-rl.net/api/cherry.td/), [`cherry.pg`](http://cherry-rl.net/api/cherry.pg/), and [`cherry.algorithms`](http://cherry-rl.net/api/cherry.algorithms/)

## Installation
~~~python
# defining a new algorithm
@dataclasses.dataclass
class MyA2C:
    discount: float = 0.99

    def update(self, replay, policy, state_value, optimizer):
        # compute state values and discounted rewards
        values = state_value(replay.state())
        discounted_rewards = cherry.td.discount(
            self.discount, replay.reward(), replay.done(), bootstrap=values[-1].detach()
        )

        # Compute losses
        policy_loss = cherry.algorithms.A2C.policy_loss(
            log_probs=policy(replay.state()).log_prob(replay.action()),
            advantages=discounted_rewards - values.detach(),
        )
        value_loss = cherry.algorithms.A2C.state_value_loss(values, discounted_rewards)

        # Optimization step
        optimizer.zero_grad()
        (policy_loss + value_loss).backward()
        optimizer.step()
        return {'a2c/policy_loss': policy_loss, 'a2c/value_loss': value_loss}

# using MyA2C
my_a2c = MyA2C(discount=0.95)
my_policy = MyPolicy()
linear_value = cherry.models.LinearValue(128)
adam = torch.optim.Adam(my_policy.parameters())
for step in range(1000):
    replay = collect_experience(my_policy)
    my_a2c.update(replay, my_policy, linear_value, adam)
~~~

**Note** Cherry is considered to be in early alpha release. Stuff might break.
## Install

```
pip install cherry-rl
@@ -85,7 +131,6 @@ Documentation and tutorials are available on cherry’s website: [http://cherry-

## Contributing

First, thanks for your consideration in contributing to cherry.
Here are a couple of guidelines we strive to follow.

* It's always a good idea to open an issue first, where we can discuss how to best proceed.

@@ -96,9 +141,6 @@ Here are a couple of guidelines we strive to follow.
* it shows users how to use your functionality, and
* it gives a concrete example when discussing the best way to merge your implementation.

We don't have forums, but are happy to discuss with you on Slack.
Make sure to send an email to [[email protected]](mailto:[email protected]) to get an invite.

## Acknowledgements

Cherry draws inspiration from many reinforcement learning implementations, including