-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
76733f7
commit 7605295
Showing
2 changed files
with
65 additions
and
1 deletion.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
...03-12-Competitive Training of Mixtures of Independent Deep Generative Models.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
--- | ||
layout: post | ||
title: Competitive Training of Mixtures of Independent Deep Generative Models | ||
comments: True | ||
excerpt: | ||
tags: ['2018', 'Causal Learning', 'Clustering', 'Generative Models', AI, Causality] | ||
|
||
--- | ||
|
||
## Introduction | ||
|
||
* The paper proposes a Competitive training mechanism to train a mixture of independent generative models. | ||
|
||
* The idea is that this mixture of different models would divide the data distribution amongst themselves and specialize to their respective splits. | ||
|
||
* The training procedure is related to clustering-based methods. | ||
|
||
* [Link to the paper](https://arxiv.org/abs/1804.11130) | ||
|
||
## Motivation | ||
|
||
* In causal modeling, a common assumption is that the data is generated by a set of independent mechanisms. | ||
|
||
* It is not known which mechanism generates which datapoint and recovering the underlying mechanisms can be modeled as learning a structural causal generative model. | ||
|
||
## Setup | ||
|
||
* The paper assumes that the support of the different generators do not overlap, i.e., the underlying data distribution is factorized into non-overlapping regions. | ||
|
||
* This data factorization is learned using a set of discriminators. | ||
|
||
* If there are $k$ generators, $k$ binary partition functions $c_i, ... c_k$ are used. | ||
|
||
* For a given datapoint $x$, if $c_i(x) = 1$ then $c_j(x) = 0$ for all other $j$ and $x$ is assigned to $i^{th}$ generator. | ||
|
||
* For a fixed partition function $c_j^t$ ($t$ denotes the partition function at time $t$), minimize the sum of f-divergence between the model and the data distribution (that is assigned to it). The loss formulation is an upper bound on the f-divergence of the mixture model. | ||
|
||
* In the next step, the data points are re-assigned to the generative models, based on the likelihood of each data point for each model. | ||
|
||
* The likelihood is estimated by training a discriminator that can distinguish the generated samples from the real samples. | ||
|
||
### Independence as an inductive bias | ||
|
||
* The independence assumption may be too restrictive because the low-level features will be common across the distribution splits. | ||
|
||
* This "violation" can be avoided by pretraining the model using a uniform random split of the dataset. In that case, the independence assumption will hold approximately after pretraining. | ||
|
||
* Another approach could be to share some parameters across the models. | ||
|
||
* A "load balancing" approach is also used where each model always keeps training on the data points assigned to it if not enough data points are assigned to it. | ||
|
||
### Comparison to VAEs and GANs | ||
|
||
* VAEs tend to be "overly inclusive" of the training distribution, i.e., they try to cover the entire support of the distribution. | ||
|
||
* GANs are prone to mode collapse where the model focuses only on one part of the distribution. | ||
|
||
* The proposed method provides a middle ground where the different generative models can focus on different parts of the distribution. | ||
|
||
## Experiments | ||
|
||
* The experiments seem to be limited. The paper shows that their proposed setup improves over the VAE and GAN baselines. | ||
|
||
* For datasets, the paper uses two-dimensional synthetic data, MNIST and CelebA |
Submodule _site
updated
from 6d7b32 to df8644