---
layout: post
title: CURL - Contrastive Unsupervised Representations for Reinforcement Learning
comments: True
excerpt:
tags: ['2020', 'Contrastive Learning', 'Deep Reinforcement Learning', 'Reinforcement Learning', 'Self Supervised', 'Sample Efficient', AI, Contrastive, DRL, RL, Unsupervised]
---
## Introduction

* The paper proposes a contrastive learning approach, called CURL, for performing off-policy control from raw pixel observations (by transforming them into high-dimensional features).

* The idea is motivated by the application of contrastive losses in computer vision. But there are additional challenges:

    * The learning agent has to perform both unsupervised learning and reinforcement learning.

    * The "dataset" for unsupervised learning is not fixed and keeps changing with the policy of the agent.

* Unlike prior work, CURL introduces fewer changes in the underlying RL pipeline and provides more significant sample-efficiency gains. For example, CURL (trained on pixels) nearly matches the performance of a SAC policy trained on state-based features.

* [Link to the code](https://github.com/MishaLaskin/curl)
## Implementation

* CURL uses instance discrimination. Deep RL algorithms commonly use a stack of temporally consecutive frames as input to the policy. In such cases, instance discrimination is applied to all the images in the stack.

* For generating the positive and negative samples, random-crop data augmentation is used.
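The random-crop step above can be sketched as follows. This is a minimal numpy illustration (frame sizes and function names are my own, not from the paper's code): the same crop window is applied to every frame in the stack, and two independent crops of one stack give the query and its positive key.

```python
import numpy as np

def random_crop(frames, out_size=84, rng=np.random):
    """Randomly crop a (C, H, W) frame stack to (C, out_size, out_size).

    One crop window is shared across the stack so the temporally
    consecutive frames stay spatially aligned.
    """
    c, h, w = frames.shape
    top = rng.randint(0, h - out_size + 1)
    left = rng.randint(0, w - out_size + 1)
    return frames[:, top:top + out_size, left:left + out_size]

# Two independent crops of the same stack form the anchor (query) and
# the positive (key); crops of other stacks serve as negatives.
stack = np.random.rand(3, 100, 100)  # e.g. 3 stacked grayscale frames
query = random_crop(stack)
key = random_crop(stack)
```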
* A bilinear inner product is used as the similarity metric, as it outperforms the commonly used normalized dot product.

* For encoding the anchors and the samples, InfoNCE is used. It learns two encoders $f_q$ and $f_k$ that transform the query (base input) and the key (positive/negative samples) into latent representations. The similarity loss is applied to these latents.
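Putting the two bullets together, the InfoNCE objective with a bilinear similarity $q^T W k$ can be sketched in numpy as below. This is an illustrative sketch, not the paper's implementation; in-batch keys act as negatives and the positive key sits on the diagonal.

```python
import numpy as np

def info_nce_loss(q, k, W):
    """InfoNCE with bilinear similarity sim(q, k) = q^T W k.

    q: (B, D) query latents from f_q (the anchors).
    k: (B, D) key latents from f_k; k[i] is the positive for q[i],
       and the remaining rows act as negatives.
    """
    logits = q @ W @ k.T                          # (B, B) similarity matrix
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy that pulls each anchor toward its own (diagonal) key.
    return -np.mean(np.diag(log_probs))
```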
* Momentum contrast is used to update the parameters ($\theta_k$) of the $f_k$ network, i.e., $\theta_k = m \theta_k + (1-m) \theta_q$. $\theta_q$ are the parameters of the $f_q$ network and are updated in the usual way, using both the contrastive loss and the RL loss.
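The momentum update above is just an exponential moving average of the query encoder's parameters; a minimal sketch (the momentum value here is illustrative, not necessarily the paper's setting):

```python
import numpy as np

def momentum_update(theta_k, theta_q, m=0.999):
    """EMA update of the key encoder: theta_k <- m*theta_k + (1-m)*theta_q.

    theta_k receives no gradients; it slowly tracks the query encoder,
    which keeps the keys consistent across training steps.
    """
    return [m * pk + (1.0 - m) * pq for pk, pq in zip(theta_k, theta_q)]
```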
## Experiment

* DMControl100K and Atari100K refer to the setups where the agent is trained for 100K steps on DMControl and Atari, respectively.

* Metrics:

    * Sample efficiency - how many steps the baseline needs to match CURL's performance after 100K steps.

    * Performance - the ratio of episodic returns of CURL vs. the baseline after 100K steps.
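The two metrics are simple ratios; the helper names and numbers below are illustrative, not values from the paper:

```python
def performance_ratio(curl_return, baseline_return):
    """Performance at 100K steps: episodic return of CURL / baseline."""
    return curl_return / baseline_return

def sample_efficiency(baseline_steps_to_match, curl_steps=100_000):
    """How many times more steps the baseline needs to match CURL at 100K."""
    return baseline_steps_to_match / curl_steps

# Illustrative numbers (not from the paper):
ratio = performance_ratio(800.0, 400.0)   # CURL reaches 2x the return
speedup = sample_efficiency(500_000)      # baseline needs 5x the steps
```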
* Baselines:

    * DMControl

        * [SAC-AE](https://arxiv.org/abs/1910.01741)
        * [SLAC](https://arxiv.org/abs/1907.00953)
        * [PlaNet](https://planetrl.github.io/)
        * [Dreamer](https://openreview.net/forum?id=S1lOTC4tDS)
        * [Pixel SAC](https://arxiv.org/abs/1812.05905)
        * SAC trained on state-space observations

    * Atari

        * [SimPLe](https://arxiv.org/abs/1903.00374)
        * [RainbowDQN](https://arxiv.org/abs/1710.02298)
        * [OTRainbow (Over Trained Rainbow)](https://openreview.net/forum?id=Bke9u1HFwB)
        * [Efficient Rainbow](https://arxiv.org/abs/1906.05243)
        * Random Agent
        * Human Performance
* Results:

    * CURL outperforms all pixel-based RL algorithms by a significant margin on all environments in DMControl and most environments in Atari.

    * On DMControl, it closely matches the performance of the SAC agent trained on state-space observations.

    * On Atari, it achieves a better median human-normalized score (HNS) than the other baselines and comes close to human sample efficiency in three environments.
---
layout: post
title: Supervised Contrastive Learning
comments: True
excerpt:
tags: ['2020', 'Contrastive Learning', AI, Contrastive, ImageNet]
---
## Introduction

* The paper builds on prior work on self-supervised contrastive learning and extends it to the supervised case, where many positive examples are available for each anchor.

* [Link to the paper](https://arxiv.org/abs/2004.11362)
## Approach

* The representation learning framework has the following components:

### Data Augmentation Module

* This module transforms the input example. The paper considers the following strategies:

    * Random crop, followed by resizing
    * [AutoAugment](https://arxiv.org/abs/1805.09501) - a method to search for data augmentation strategies
    * [RandAugment](https://arxiv.org/abs/1909.13719) - randomly sampling a sequence of data augmentations, with repetition
    * SimAugment - sequentially applying random color distortion and Gaussian blurring, followed by a probabilistic sparse image warp
### Encoder Network

* This module maps the input to a latent representation.

* The same network is used to encode both the anchor and the sample.

* The representation vector is normalized to lie on the unit hypersphere.
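Normalizing onto the unit hypersphere is a one-line operation; a minimal numpy sketch (the epsilon guard is my own addition for numerical safety):

```python
import numpy as np

def normalize(z, eps=1e-12):
    """Project representations onto the unit hypersphere (L2 norm = 1)."""
    return z / (np.linalg.norm(z, axis=-1, keepdims=True) + eps)

z = normalize(np.random.randn(4, 128))
# Each row now has unit length, so dot products equal cosine similarities.
```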
### Projection Network

* This module maps the normalized representation to another representation, on which the contrastive loss is computed.

* This network is only used for training with the supervised contrastive loss.
### Loss Function

* The paper extends the standard contrastive loss formulation to handle multiple positive examples.

* The main effect is that the modified loss accounts for all the same-class pairs (from within the sampled batch as well as the augmented batch).

* The paper shows that the gradient (corresponding to the modified loss) causes the learning to focus more on hard examples. "Hard" cases are the ones where contrasting the anchor benefits the encoder more.

* The proposed loss can also be seen as a generalization of the triplet loss.
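The multiple-positives loss above can be sketched in numpy as follows. This is an illustrative sketch of the variant that averages the log-probabilities over the positive set of each anchor, assuming the inputs are already L2-normalized; it is not the authors' implementation.

```python
import numpy as np

def supcon_loss(z, labels, temperature=0.07):
    """Supervised contrastive loss with multiple positives per anchor.

    z: (N, D) L2-normalized representations (batch plus augmented views).
    labels: (N,) class labels; every same-class pair counts as a positive.

    For each anchor i with positive set P(i):
      -1/|P(i)| * sum_{p in P(i)} log( exp(z_i.z_p/t) / sum_{a != i} exp(z_i.z_a/t) )
    """
    sim = z @ z.T / temperature
    np.fill_diagonal(sim, -np.inf)           # exclude self-contrast
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    pos_mask = (labels[:, None] == labels[None, :]).astype(float)
    np.fill_diagonal(pos_mask, 0.0)
    pos_counts = pos_mask.sum(axis=1)
    # Zero out the -inf diagonal before masking to avoid 0 * -inf = nan.
    safe_log_prob = np.where(np.isfinite(log_prob), log_prob, 0.0)
    loss_i = -(pos_mask * safe_log_prob).sum(axis=1)
    valid = pos_counts > 0                   # anchors that have positives
    return (loss_i[valid] / pos_counts[valid]).mean()
```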
## Experiments

* Dataset - ImageNet

* Models - ResNet-50, ResNet-200

* The network is "pretrained" using the supervised contrastive loss.

* After pre-training, the projection network is removed, and a linear classifier is added.

* This classifier is trained with the cross-entropy (CE) loss while the rest of the network is kept fixed.
## Results

* Using the supervised contrastive loss improves over all the baseline models and data augmentation approaches.

* The resulting classifier is more robust to image corruptions, as shown by the mean Corruption Error (mCE) metric on the ImageNet-C dataset.

* The model is more stable to the choice of hyperparameter values (like optimizers, data augmentation, and learning rates).
## Training Details

* The supervised contrastive loss is trained for 700 epochs during pre-training.

* Each training step is about 50% more expensive than with the CE loss.

* The dense (linear) classifier layer can be trained in as few as ten epochs.

* The temperature value is set to 0.07. Using a lower temperature is better than using a higher temperature.
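The effect of the temperature can be seen directly: dividing similarities by a small temperature sharpens the softmax over candidates, while a large temperature flattens it. A small numpy illustration (the similarity values are made up):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())   # subtract max for numerical stability
    return e / e.sum()

sims = np.array([0.9, 0.5, 0.1])   # cosine similarities to candidates
sharp = softmax(sims / 0.07)       # low temperature: near one-hot
smooth = softmax(sims / 1.0)       # high temperature: closer to uniform
```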