Add C-SWM paper
shagunsodhani committed Dec 9, 2019
1 parent 8a78469 commit 5c94a7f
Showing 2 changed files with 85 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -6,6 +6,7 @@ I am trying a new initiative - a-paper-a-week. This repository will hold all tho
## List of papers

* [Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://shagunsodhani.com/papers-I-read/Mastering-Atari,-Go,-Chess-and-Shogi-by-Planning-with-a-Learned-Model)
* [Contrastive Learning of Structured World Models](https://shagunsodhani.com/papers-I-read/Contrastive-Learning-of-Structured-World-Models)
* [Gossip based Actor-Learner Architectures for Deep RL](https://shagunsodhani.com/papers-I-read/Gossip-based-Actor-Learner-Architectures-for-Deep-RL)
* [How to train your MAML](https://shagunsodhani.com/papers-I-read/How-to-train-your-MAML)
* [PHYRE - A New Benchmark for Physical Reasoning](https://shagunsodhani.com/papers-I-read/PHYRE-A-New-Benchmark-for-Physical-Reasoning)
@@ -0,0 +1,84 @@
---
layout: post
title: Contrastive Learning of Structured World Models
comments: True
excerpt:
tags: ['2019', 'Graph Neural Network', 'Object-Oriented Learning', 'Relational Learning', AI, Graph, GNN]

---

## Introduction

* The paper introduces Contrastively-trained Structured World Models (C-SWMs).

* These models use a contrastive approach for learning representations in environments with compositional structure.

* [Link to the paper](https://arxiv.org/abs/1911.12247)

* [Link to the code](https://github.com/tkipf/c-swm).

## Approach

* The training data is in the form of an experience buffer $$B = \{(s_t, a_t, s_{t+1})\}_{t=1}^T$$ of state transition tuples.

* The goal is to learn:

  * an encoder $$E$$ that maps the observed states $$s_t$$ (pixel observations) to latent states $$z_t$$.

  * a transition model $$T$$ that predicts the dynamics in the latent space.

* The model defines the energy of a tuple $$(s_t, a_t, s_{t+1})$$ as $$H = d(z_t + T(z_t, a_t), z_{t+1})$$.

* The model has an inductive bias for modeling the effect of an action as a translation in the abstract state space.

* An extra hinge-loss term is added: $$\max(0, \gamma - d(\tilde{z}_t, z_{t+1}))$$, where $$\tilde{z}_t = E(\tilde{s}_t)$$ is a corrupted latent state corresponding to a randomly sampled state $$\tilde{s}_t$$ (a minimal sketch of the full objective follows).
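
A minimal PyTorch-style sketch of this objective, assuming a squared Euclidean distance with scale `sigma` and generic `encoder` / `transition` callables (illustrative names, not the paper's exact implementation):

```python
import torch.nn.functional as F

def sq_distance(a, b, sigma=0.5):
    # d(a, b) = ||a - b||^2 / (2 * sigma^2), summed over the latent dimensions
    return ((a - b) ** 2).sum(dim=-1) / (2 * sigma ** 2)

def contrastive_loss(encoder, transition, s, a, s_next, s_neg, gamma=1.0):
    """Energy of the observed transition plus a hinge on a randomly sampled (corrupted) state."""
    z, z_next = encoder(s), encoder(s_next)
    z_neg = encoder(s_neg)                                # corrupted latent state \tilde{z}_t
    pos = sq_distance(z + transition(z, a), z_next)       # H = d(z_t + T(z_t, a_t), z_{t+1})
    neg = F.relu(gamma - sq_distance(z_neg, z_next))      # max(0, gamma - d(\tilde{z}_t, z_{t+1}))
    return (pos + neg).mean()
```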

## Object-Oriented State Factorization

* The goal is to learn object-oriented representations where each state embedding is structured as a set of objects.

* Assuming the number of object slots to be $$K$$, the latent space and the action space can be factored into $$K$$ independent latent spaces ($$Z_1 \times ... \times Z_K$$) and action spaces ($$A_1 \times ... \times A_K$$) respectively.

* There are $$K$$ CNN-based object extractors and an MLP-based object encoder.

* The actions are represented as one-hot vectors.

* A fully connected graph is induced over the $$K$$ object representations, and the transition function is modeled as a Graph Neural Network (GNN) over this graph (a rough sketch of such a transition model follows this list).

* The transition function produces the change in the latent state representation of each object.

* The factorization can be taken into account in the loss function by summing over the loss corresponding to each object.
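
A rough sketch of such a factored transition GNN (the two-MLP message-passing structure follows the spirit of the paper, but the module names `edge_mlp` / `node_mlp`, hidden sizes, and other details are illustrative assumptions):

```python
import torch
import torch.nn as nn

class TransitionGNN(nn.Module):
    """Predicts the change in each object's latent state via message passing over a fully connected graph of K slots."""
    def __init__(self, latent_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim))
        self.node_mlp = nn.Sequential(
            nn.Linear(latent_dim + action_dim + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim))

    def forward(self, z, a):
        # z: (batch, K, latent_dim) per-object latents, a: (batch, K, action_dim) one-hot actions per slot
        B, K, D = z.shape
        z_send = z.unsqueeze(2).expand(B, K, K, D)    # sender representation for every (i, j) pair
        z_recv = z.unsqueeze(1).expand(B, K, K, D)    # receiver representation for every (i, j) pair
        edges = self.edge_mlp(torch.cat([z_send, z_recv], dim=-1))
        mask = 1.0 - torch.eye(K, device=z.device).view(1, K, K, 1)   # drop self-edges
        agg = (edges * mask).sum(dim=1)               # aggregate incoming messages per receiver slot
        # per-object update: delta z_k computed from [z_k, a_k, aggregated messages]
        return self.node_mlp(torch.cat([z, a, agg], dim=-1))
```

The returned change is added to the current latent state when computing the energy, so each object's state is updated as a translation conditioned on its action and on its interactions with the other slots.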

## Environments

* Grid World Environments - 2D shapes, 3D blocks

* Atari games - Pong and Space Invaders

* 3-body physics simulation

## Setup

* A random policy is used to collect the training data.

* Evaluation is performed in the latent space (no reconstruction in the pixel space) using ranking metrics: the predicted latent state is ranked against the encodings of observations randomly sampled from the buffer (a sketch of this evaluation follows this list).

* Baselines - auto-encoder based World Models and [Physics as Inverse Graphics model](https://arxiv.org/abs/1905.11169).
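
A hedged sketch of this ranking evaluation, assuming Hits@1 and Mean Reciprocal Rank computed from Euclidean distances between predicted latent states and the encodings of the sampled reference observations (factored latents would be flattened first; the names below are illustrative):

```python
import torch

def ranking_metrics(pred_next, ref_next):
    """Rank each predicted next latent state against all encoded reference next states."""
    # pred_next, ref_next: (N, D); row i of ref_next is the true match for row i of pred_next
    dists = torch.cdist(pred_next, ref_next)                     # pairwise Euclidean distances
    ranks = (dists < dists.diag().unsqueeze(1)).sum(dim=1) + 1   # rank of the true next state
    hits_at_1 = (ranks == 1).float().mean().item()
    mrr = (1.0 / ranks.float()).mean().item()
    return hits_at_1, mrr
```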

## Results

* In the grid-world environments, C-SWM models the latent dynamics almost perfectly.

* Removing either the state factorization or the GNN transition model hurts the performance.

* C-SWM also performs well on the Atari games, though the results tend to have high variance.

* The optimal value of $$K$$ should be obtained via hyperparameter tuning.

* For the 3-body physics tasks, both the baselines and proposed models work quite well.

* Interestingly, the paper has a section on limitations:

  * The object extractor module cannot disambiguate between multiple instances of the same object in a scene.

  * The current formulation of C-SWM can only be used with deterministic environments.
