v0.2.0: Policy, Value Functions, and more Algorithms (#39)
* Improve support for vectorized environments.

* Remove some hacks.

* Add vectorized experience replay, many small improvements.

* Fix lint.

* Fix runner tests and Makefile.

* Bring Fixes from Jumping (#40)

* Update parameters for NatureFeatures.

* Update TRPO for non-contiguous tensors.

* More fixes.

* Clean fix for change of memory layout in PT 1.6.0

* Add mean and mode to TanhNormal distribution.

* Add device and device_env to Torch wrapper.

* Fix device when flattening a replay.

* Remove pdb from logger.

* ExpReplay: implement getstate and setstate

* Add test for replay.flatten(), faster implementation.

* Fix replay test.

* Update reward normalizer.

* More fixes for RewardNormalizer.

* Reward normalizer handles vectorized envs.

* Accelerate creation of Transitions.

* Cosmetics

* Add wandb to pybullet.

* Add nsteps sampling to experience replay + tests.

* Add nn.Policy (needs docs) and Normal, Categorical distributions (need docs).

* Add line_search to trpo.

* Add algorithm arguments, update LinearValue lstsq.

* Add device to Transition fields.

* Do not assume default device in Torch wrapper.

* Add cherry.nn.MLP

* Remove gym.wrappers.Monitor.

* Add Closer, make compatible with gym 0.23.0

* Add DrQ and SAC updates.

* Fix unit tests.

* Fix Transition.device for vectorized envs.

* Update algorithms.

* tests: Fix MemorizeDigits registration.

* tests: Fix actor-critic integration indentation.

* Add DrQv2.

* Add initial ActionValue API.

* Fix robotics init, Twin action value.

* Add nn.Lambda.

* Remove std reduction in DrQv2.

* Update docs to MkDocs Material.

* docs: Add cherry.models.md

* Add cherry.nn.md

* Add some docs.

* Finish converting to new docs.

* Add docs for nn.init.

* Bump version to 0.2.0

* Monkey-patch old wrappers.

* Update logo link.

* Update logo path.

* Fix linting.

* Slightly update README.md, and add StateValue abstract class.

* Add td3 example on mujoco.

* Update offpolicy mujoco example.

* Update action-value API.

* Add PPO update.

* Update Mujoco offpolicy training.

* Update algorithms docs.

* Algorithm arguments are mappings too.

* Reformat.

* Add initial DMC example.

* Minor clean up.

* DrQ on DMC now running.

* Add DMC sweeps.

* Remove TODO.

* Fix wrong error raised in Policy forward.

* Clean Makefile, add gym requirements for dev.

* Clean up DMC example.

* Use log_alpha provided to SAC.

* Update a2c, ppo, trpo docs.

* Docs for DDPG.

* Write TD3 docs.

* Update docs for DrQ, SAC, TRPO, DDPG, and A2C.

* Write DrQv2 and PPO docs.

* Remove RandomShiftsAug from algorithms docs.

* Fix non-learnable log_alpha in DrQ.

* Detach log_alpha when non-learnable in DrQ.

* Update changelog.

* Fix linting.

* Fix CI.

* Change ci95 factor to 1.96 instead of 2.0.

* Add DMC sweep results.

* Update DMC plots.

* Update dmc examples README.md with results.

* Remove individual DMC figures.

* Update README.md

Initial changes to the examples in README.

* Update README.md

* Update first part of README.md

* Add experience replay example.

* Update README.md

* Write final algorithm example.

* Edit algorithm readme example.

* Edit readme.

* Fix CI.

* Fix gym version in CI.

* Fix gym version in CI.

* Update readme and CI.

* CI: update setuptools version

* CI: change gym install command.

* CI: try another gym version.

* Update python_unittest.yaml

* Fix more tests.

* Fix more tests.

* Fixes encoding in setup.py (#42)

* Fix wrong linear scaling in ActionSpaceScaler (#46)

* Add DrQ and SAC updates.

---------

Co-authored-by: 3DAlgoLab <[email protected]>
Co-authored-by: gwwo <[email protected]>
3 people authored Jun 26, 2023
1 parent 18a1269 commit f4164a5
Showing 133 changed files with 5,797 additions and 1,775 deletions.
51 changes: 51 additions & 0 deletions .github/workflows/python_unittest.yaml
@@ -0,0 +1,51 @@

name: Testing

on: [push, pull_request, create]

jobs:
  tests:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest]
        python: ['3.7', '3.8']
        pytorch: ['1.4.0', '1.5.0', '1.6.0', '1.7.0']
        include:
          - pytorch: '1.4.0'
            torchvision: '0.5.0'
          - pytorch: '1.5.0'
            torchvision: '0.6.0'
          - pytorch: '1.6.0'
            torchvision: '0.7.0'
          - pytorch: '1.7.0'
            torchvision: '0.8.0'
        exclude:
          - pytorch: '1.3.0'
            python: '3.8'

    steps:
      - name: Clone Repository
        uses: actions/checkout@v2
        with:
          ref: ${{ github.ref }}
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Install Dependencies
        run: |
          python3 --version
          python3 -m pip install -U pip setuptools
          # pip3 install --install-option="--no-cython-compile" cython
          pip3 install torch==${{ matrix.pytorch }}
          pip3 install torchvision==${{ matrix.torchvision }}
          pip3 install chardet==3.0.4  # can be removed when fixed in: https://github.com/aio-libs/aiohttp/issues/5366
          pip3 install requests numpy gsutil tqdm pygame
          pip3 install gym==0.23.1
          make dev
      - name: Lint Code
        run: make lint
      - name: Run Tests
        run: make tests
3 changes: 3 additions & 0 deletions .gitignore
@@ -122,3 +122,6 @@ examples/bsuite/figs/**

# ignore .vscode
.vscode
.all_objects.cache
wandb/**
**/wandb/**
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -15,7 +15,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed


## v0.2.0

### Added

* Introduce cherry.nn.Policy, cherry.nn.ActionValue, and cherry.nn.StateValue.
* Algorithm class utilities for: A2C, PPO, TRPO, DDPG, TD3, SAC, and DrQ/DrQv2.
* DMC examples for SAC, DrQ, and DrQv2.
* N-steps returns sampling in ExperienceReplay.

### Changed

* Discontinue most of cherry.wrappers.

### Fixed

* Fixes return value of StateNormalizer and RewardNormalizer wrappers.
* Requirements to generate docs.


## v0.1.4
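
To make the additions listed in the changelog above concrete, here is a minimal, illustrative sketch of how the new `cherry.nn` base classes are meant to be subclassed. It is not taken from this commit: the `forward` signatures assumed for `StateValue` and `ActionValue` (state-only vs. state-and-action inputs) follow their $V(s)$ and $Q(s, a)$ descriptions, and the layer sizes are arbitrary.

~~~python
# Illustrative sketch only; the forward() signatures for StateValue and
# ActionValue are assumptions based on their V(s) / Q(s, a) descriptions.
import torch
import cherry


class MyPolicy(cherry.nn.Policy):
    def __init__(self, state_size, action_size):
        super(MyPolicy, self).__init__()
        self.actor = torch.nn.Linear(state_size, action_size)

    def forward(self, state):
        mean = self.actor(state)
        std = 0.1 * torch.ones_like(mean)
        return cherry.distributions.TanhNormal(mean, std)  # policies return a distribution


class MyStateValue(cherry.nn.StateValue):  # V(s)
    def __init__(self, state_size):
        super(MyStateValue, self).__init__()
        self.value = torch.nn.Linear(state_size, 1)

    def forward(self, state):  # assumed signature
        return self.value(state)


class MyActionValue(cherry.nn.ActionValue):  # Q(s, a)
    def __init__(self, state_size, action_size):
        super(MyActionValue, self).__init__()
        self.value = torch.nn.Linear(state_size + action_size, 1)

    def forward(self, state, action):  # assumed signature
        return self.value(torch.cat([state, action], dim=-1))
~~~

A `Policy` returns a distribution, which is what `policy.act` samples from in the README examples further down in this diff.
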
93 changes: 11 additions & 82 deletions Makefile
@@ -1,88 +1,13 @@

.PHONY: all tests dist docs

all: sac

# Demo
reinforce:
	python examples/reinforce_cartpole.py

ac:
	OMP_NUM_THREADS=1 \
	MKL_NUM_THREADS=1 \
	python examples/actor_critic_cartpole.py

grid:
	python examples/actor_critic_gridworld.py

# Atari
dist-a2c:
	OMP_NUM_THREADS=1 \
	MKL_NUM_THREADS=1 \
	python -m torch.distributed.launch \
	--nproc_per_node=6 \
	examples/atari/dist_a2c_atari.py

a2c:
	OMP_NUM_THREADS=4 \
	MKL_NUM_THREADS=4 \
	python examples/atari/a2c_atari.py

ppoa:
	OMP_NUM_THREADS=4 \
	MKL_NUM_THREADS=4 \
	python examples/atari/ppo_atari.py

bug:
	OMP_NUM_THREADS=1 \
	MKL_NUM_THREADS=1 \
	python examples/atari/debug_atari.py

dqn:
	python examples/atari/dqn_atari.py

# PyBullet
dist-ppo:
	OMP_NUM_THREADS=1 \
	MKL_NUM_THREADS=1 \
	python -m torch.distributed.launch \
	--nproc_per_node=16 \
	examples/pybullet/dist_ppo_pybullet.py

ppo:
	OMP_NUM_THREADS=1 \
	MKL_NUM_THREADS=1 \
	python examples/pybullet/ppo_pybullet.py

sac:
	OMP_NUM_THREADS=1 \
	MKL_NUM_THREADS=1 \
	python examples/pybullet/sac_pybullet.py

tsac:
	OMP_NUM_THREADS=1 \
	MKL_NUM_THREADS=1 \
	python examples/pybullet/delayed_tsac_pybullet.py

# Tabular
tabular-s:
	python examples/tabular/sarsa.py

tabular-q:
	python examples/tabular/q_learning.py

# bsuite

bsuite:
	python examples/bsuite/trpo_v_random.py
.PHONY: *

# Admin
dev:
	pip install --progress-bar off torch gym pycodestyle >> log_install.txt
	python setup.py develop

lint:
	pycodestyle cherry/ --max-line-length=160
	pycodestyle --max-line-length=160 --ignore=W605 cherry/

lint-examples:
	pycodestyle examples/ --max-line-length=80
@@ -93,14 +18,18 @@ lint-tests:
tests:
	OMP_NUM_THREADS=1 \
	MKL_NUM_THREADS=1 \
	python -W ignore -m unittest discover -s 'tests' -p '*_tests.py' -v
	python -W ignore -m unittest discover -s 'tests/' -p '*_tests.py' -v
	make lint

docs:
	cd docs && pydocmd build && pydocmd serve
predocs:
	cp ./README.md docs/index.md
	cp ./CHANGELOG.md docs/changelog.md

docs: predocs
	mkdocs serve

docs-deploy:
	cd docs && pydocmd gh-deploy
docs-deploy: predocs
	mkdocs gh-deploy

# https://dev.to/neshaz/a-tutorial-for-tagging-releases-in-git-147e
release:
142 changes: 92 additions & 50 deletions README.md
@@ -1,4 +1,4 @@
<p align="center"><img src="http://cherry-rl.net/assets/images/cherry_full.png" height="150px" /></p>
<p align="center"><img src="http://cherry-rl.net/assets/images/cherry_full.png" height="128px" /></p>

--------------------------------------------------------------------------------

@@ -13,63 +13,109 @@ So if you don't like a specific tool, you don’t need to use it.

**Features**

* Pythonic and low-level interface *à la* Pytorch.
* Support for tabular (!) and function approximation algorithms.
* Various OpenAI Gym environment wrappers.
* Helper functions for popular algorithms. (e.g. A2C, DDPG, TRPO, PPO, SAC)
* Logging, visualization, and debugging tools.
* Painless and efficient distributed training on CPUs and GPUs.
* Unit, integration, and regression tested, continuously integrated.
Cherry extends PyTorch with only a handful of new core concepts.

* PyTorch modules for reinforcement learning:
* [`cherry.nn.Policy`](http://cherry-rl.net/api/cherry.nn/#cherry.nn.policy.Policy): base class for $\pi(a \mid s)$ policies.
* [`cherry.nn.ActionValue`](http://cherry-rl.net/api/cherry.nn/#cherry.nn.action_value.ActionValue): base class for $Q(s, a)$ action-value functions.
* Data structures for reinforcement learning compatible with PyTorch:
* [`cherry.Transition`](http://cherry-rl.net/api/cherry/#cherry.experience_replay.Transition): namedtuple to store $(s_t, a_t, r_t, s_{t+1})$ transitions (and more).
* [`cherry.ExperienceReplay`](http://cherry-rl.net/api/cherry/#cherry.experience_replay.ExperienceReplay): a list-like buffer to store and sample transitions.
* Low-level interface *à la* PyTorch to write and debug your algorithms.
* [`cherry.td.*`](http://cherry-rl.net/api/cherry.td/) and [`cherry.pg.*`](http://cherry-rl.net/api/cherry.pg/): temporal difference and policy gradient utilities.
* [`cherry.algorithms.*`](http://cherry-rl.net/api/cherry.algorithms/): helper functions for popular algorithms ([PPO](http://cherry-rl.net/api/cherry.algorithms/#cherry.algorithms.ppo.PPO), [TD3](http://cherry-rl.net/api/cherry.algorithms/#cherry.algorithms.td3.TD3), [DrQ](http://cherry-rl.net/api/cherry.algorithms/#cherry.algorithms.drq.DrQ), and [more](http://cherry-rl.net/api/cherry.algorithms/)).
* [`cherry.debug.*`](http://cherry-rl.net/api/cherry.debug/) and [`cherry.plot.*`](http://cherry-rl.net/api/cherry.plot/): logging, visualization, and debugging tools.

To learn more about the tools and philosophy behind cherry, check out our [Getting Started tutorial](http://cherry-rl.net/tutorials/getting_started/).

## Example
## Overview and Examples

The following snippet showcases some of the tools offered by cherry.
The following snippet showcases a few of the tools offered by cherry.
Many more high-quality examples are available in the [examples/](./examples/) folder.

~~~python
import cherry as ch
#### Defining a [`cherry.nn.Policy`](http://cherry-rl.net/api/cherry.nn/#cherry.nn.policy.Policy)

# Wrap environments
env = gym.make('CartPole-v0')
env = ch.envs.Logger(env, interval=1000)
env = ch.envs.Torch(env)
~~~python
class VisionPolicy(cherry.nn.Policy):  # inherits from torch.nn.Module

    def __init__(self, feature_extractor, actor):
        super(VisionPolicy, self).__init__()
        self.feature_extractor = feature_extractor
        self.actor = actor

    def forward(self, obs):
        mean = self.actor(self.feature_extractor(obs))
        std = 0.1 * torch.ones_like(mean)
        return cherry.distributions.TanhNormal(mean, std)  # policies always return a distribution

policy = VisionPolicy(MyResnetExtractor(), MyMLPActor())
action = policy.act(obs)  # sampled from policy's distribution
deterministic_action = policy.act(obs, deterministic=True)  # distribution's mode
action_distribution = policy(obs)  # work with the policy's distribution
~~~

policy = PolicyNet()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
replay = ch.ExperienceReplay() # Manage transitions
#### Building a [`cherry.ExperienceReplay`](http://cherry-rl.net/api/cherry/#cherry.experience_replay.ExperienceReplay) of [`cherry.Transition`](http://cherry-rl.net/api/cherry/#cherry.experience_replay.Transition)

for step in range(1000):
    state = env.reset()
    while True:
        mass = Categorical(policy(state))
        action = mass.sample()
        log_prob = mass.log_prob(action)
        next_state, reward, done, _ = env.step(action)

        # Build the ExperienceReplay
        replay.append(state, action, reward, next_state, done, log_prob=log_prob)
        if done:
            break
        else:
            state = next_state

    # Discounting and normalizing rewards
    rewards = ch.td.discount(0.99, replay.reward(), replay.done())
    rewards = ch.normalize(rewards)

    loss = -th.sum(replay.log_prob() * rewards)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    replay.empty()
~~~python
# building the replay
replay = cherry.ExperienceReplay()
state = env.reset()
for t in range(1000):
    action = policy.act(state)
    next_state, reward, done, info = env.step(action)
    replay.append(state, action, reward, next_state, done)
    state = next_state

# manipulating the replay
replay = replay[-256:]  # indexes like a list
batch = replay.sample(32, contiguous=True)  # sample transitions into a replay
batch = batch.to('cuda')  # move replay to device
for transition in reversed(batch):  # iterate over a replay
    transition.reward *= 0.99

# get all states, actions, and rewards as PyTorch tensors.
reinforce_loss = -torch.sum(policy(batch.state()).log_prob(batch.action()) * batch.reward())
~~~
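
The changelog also lists n-step returns sampling for `ExperienceReplay`. Continuing the snippet above, a rough sketch of how that could be called is below; the `nsteps` and `discount` keyword names are assumptions rather than API confirmed by this diff, so check the actual `ExperienceReplay.sample` signature.

~~~python
# Hypothetical n-step sampling call, continuing the replay built above.
# The keyword names (nsteps, discount) are assumptions and may differ from
# the actual ExperienceReplay.sample signature.
nstep_batch = replay.sample(32, nsteps=3, discount=0.99)
~~~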

Many more high-quality examples are available in the [examples/](./examples/) folder.
#### Designing algorithms with [`cherry.td`](http://cherry-rl.net/api/cherry.td/), [`cherry.pg`](http://cherry-rl.net/api/cherry.pg/), and [`cherry.algorithms`](http://cherry-rl.net/api/cherry.algorithms/)

## Installation
~~~python
# defining a new algorithm
@dataclasses.dataclass
class MyA2C:
    discount: float = 0.99

    def update(self, replay, policy, state_value, optimizer):
        # discount rewards
        values = state_value(replay.state())
        discounted_rewards = cherry.td.discount(
            self.discount, replay.reward(), replay.done(), bootstrap=values[-1].detach()
        )

        # Compute losses
        policy_loss = cherry.algorithms.A2C.policy_loss(
            log_probs=policy(replay.state()).log_prob(replay.action()),
            advantages=discounted_rewards - values.detach(),
        )
        value_loss = cherry.algorithms.A2C.state_value_loss(values, discounted_rewards)

        # Optimization step
        optimizer.zero_grad()
        (policy_loss + value_loss).backward()
        optimizer.step()
        return {'a2c/policy_loss': policy_loss, 'a2c/value_loss': value_loss}

# using MyA2C
my_a2c = MyA2C(discount=0.95)
my_policy = MyPolicy()
linear_value = cherry.models.LinearValue(128)
adam = torch.optim.Adam(my_policy.parameters())
for step in range(1000):
    replay = collect_experience(my_policy)
    my_a2c.update(replay, my_policy, linear_value, adam)
~~~
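
The `collect_experience` helper above is left undefined in the snippet. A minimal sketch of what it could look like, reusing only the `ExperienceReplay` and `Policy` calls shown earlier, is given below; the helper itself and the choice of environment are assumptions, not part of cherry.

~~~python
# Hypothetical rollout helper; not part of cherry. It reuses policy.act,
# env.step, and ExperienceReplay.append exactly as in the examples above.
import gym
import cherry

def collect_experience(policy, num_steps=1000):
    env = gym.make('CartPole-v0')  # arbitrary environment choice
    replay = cherry.ExperienceReplay()
    state = env.reset()
    for t in range(num_steps):
        action = policy.act(state)
        next_state, reward, done, info = env.step(action)
        replay.append(state, action, reward, next_state, done)
        state = env.reset() if done else next_state
    return replay
~~~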

**Note** Cherry is considered in early alpha release. Stuff might break.
## Install

```
pip install cherry-rl
@@ -85,7 +131,6 @@ Documentation and tutorials are available on cherry’s website: [http://cherry-

## Contributing

First, thanks for your consideration in contributing to cherry.
Here are a couple of guidelines we strive to follow.

* It's always a good idea to open an issue first, where we can discuss how to best proceed.
@@ -96,9 +141,6 @@ Here are a couple of guidelines we strive to follow.
* it shows users how to use your functionality, and
* it gives a concrete example when discussing the best way to merge your implementation.

We don't have forums, but are happy to discuss with you on slack.
Make sure to send an email to [[email protected]](mailto:[email protected]) to get an invite.

## Acknowledgements

Cherry draws inspiration from many reinforcement learning implementations, including
1 change: 1 addition & 0 deletions cherry/__init__.py
@@ -5,6 +5,7 @@
from . import td
from . import pg
from . import envs
from . import wrappers
from . import optim
from . import nn
from . import models
