originally posted on giuliostarace.com/posts/dlml-tutorial/ (recommended for better math rendering)
UPDATE (2024-08-10): I have added a bit more detail in The How section regarding how we implement the edge cases. From "Let's go line by line" to "in lines 7 through 9".
In this post I will explain what the Discretized Logistic Mixture Likelihood (DLML)¹ is. This is a modeling method particularly relevant to my MSc thesis, where I use it to model the continuous action my imitation learning agent should take. While there are already some great posts explaining the concept, the information is scattered, which can make understanding it a bit painful. I will start by motivating why we need DLML. I will then present what DLML is. Finally, I will outline how we can implement DLML in PyTorch.
Suppose you wish to predict some variable that happens to be continuous, conditional on some other quantity. For example, you are interested in predicting the value of a given pixel in an image, given the values of neighbouring pixels.
With a bit of domain knowledge, this class of problem can typically be reformulated by discretizing the target variable and modeling the resulting (conditional) probability distribution. The prediction task can then be posed as a classification one over the discretized bins: apply a softmax and train using cross-entropy loss.
This is (in part) what the authors of PixelCNN did: for the task of conditional image generation, they model each (sub)pixel of an image with a softmax over a 256-dimensional vector, where each dimension represents an 8-bit intensity value that the pixel may take. There are more details, particularly around the conditioning, but that's all you need to know for now for the premise of this tutorial.
Immediately, we face a number of limitations:
1. Softmax can be computationally expensive and numerically unstable. This is particularly problematic for high-dimensional inputs (i.e. many classes), which are usually the case when dealing with (discretized) continuous variables. The cost compounds if you plan to repeat the computation on several output variables (e.g. several pixels of an output image, several dimensions of robotic arm rotation, etc.), which is usually the case in conditional generation of continuous variables.
2. Softmax can lead to sparse gradients, especially at the beginning of training, which can slow down learning. This is also especially the case with high-dimensional input.
3. Softmax does not model any sort of ordinality in the random variable being considered: every dimension of the input vector is treated independently. There is no notion that a value of 127 is close to 128. This ordinality is typically present in (discretized) continuous variables by their very nature. Rather than relying on some inductive bias, the model has to devote training time to learning this aspect of the data, leading to slower training.
4. Softmax fails to properly model values that are never observed, assigning probabilities of 0 to values that may otherwise be likely to occur.
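To make the ordinality limitation concrete, here is a minimal sketch in plain Python (standard library only). The helper names `cross_entropy` and `peaked_distribution` and the numbers are made up for illustration: a prediction that is off by one intensity level is penalized exactly as much as one that is off by 127 levels.

```python
import math

def cross_entropy(probs, target):
    """Negative log-probability assigned to the target class."""
    return -math.log(probs[target])

def peaked_distribution(peak, num_classes=256, peak_mass=0.9):
    """Put `peak_mass` on one class and spread the rest uniformly."""
    rest = (1.0 - peak_mass) / (num_classes - 1)
    return [peak_mass if i == peak else rest for i in range(num_classes)]

target = 127
near_miss = peaked_distribution(peak=128)  # off by one intensity level
far_miss = peaked_distribution(peak=0)     # off by 127 intensity levels

loss_near = cross_entropy(near_miss, target)
loss_far = cross_entropy(far_miss, target)
print(loss_near, loss_far)  # the two losses are identical
```

Nothing in the categorical loss tells the model that 128 was a better guess than 0; that structure has to be learned from data.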
These issues are at least some of the motivations for using DLML, which I will introduce in the next section.
In DLML, for a given output variable $y$:
1. We assume that there is a latent value $v$ with a continuous distribution.
2. We take $y$ to come from a discretization of this continuous distribution of $v$. We do this discretization in some arbitrary way, but usually by rounding to the nearest 8-bit representation. This means that if e.g. $v$ can be any value between 0 and 255, then $y$ will be any integer between those two numbers.
3. We model $v$ using a simple continuous distribution, e.g. the logistic distribution.
4. We then take a further step, choosing to model $v$ as a mixture of $K$ logistic distributions:
   $$ v \sim \sum_{i=1}^K \pi_i \, \text{logistic}(\mu_i, s_i), $$
   (equation 1) where $\pi_i$ is some coefficient weighing the likelihood of the $i$th distribution, while $\mu_i$ and $s_i$ are the mean and scale parametrizing it.
5. To compute the likelihood of $y$, we sum its (weighted) probability masses over the $K$ mixture components. We can obtain the probability masses by computing the difference between consecutive cumulative distribution function (CDF) values of equation (1). Note that the CDF of the logistic distribution is a sigmoid function. We therefore write:
   $$ p(y | \mathbf{\pi}, \mathbf{\mu}, \mathbf{s} ) = \sum_{i=1}^K \pi_i \left[\sigma\left(\frac{y + 0.5 - \mu_i}{s_i}\right) - \sigma\left(\frac{y - 0.5 - \mu_i}{s_i}\right)\right], $$
   (equation 2) where $\sigma$ is the logistic sigmoid. The 0.5 value comes from the fact that we have discretized $v$ into $y$ through rounding, and therefore the boundaries between successive values of our discrete random variable $y$ are found half a unit either side of it.
6. We can additionally model edge cases, replacing $y - 0.5$ with $-\infty$ when $y = 0$ and $y + 0.5$ with $+\infty$ when $y = 2^8 - 1 = 255$.
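Equation (2), together with the edge cases of step 6, can be sketched in plain Python (standard library only, to keep the example dependency-free). The function name `dlml_pmf` and the parameter values are made up for illustration; this is a sanity check of the maths, not the PyTorch implementation discussed later.

```python
import math

def sigmoid(x):
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def dlml_pmf(y, pis, mus, scales, y_min=0, y_max=255):
    """Probability mass of integer y under equation (2), with edge cases."""
    total = 0.0
    for pi, mu, s in zip(pis, mus, scales):
        # step 6: the outermost bins absorb the tails of the distribution
        upper = 1.0 if y == y_max else sigmoid((y + 0.5 - mu) / s)
        lower = 0.0 if y == y_min else sigmoid((y - 0.5 - mu) / s)
        total += pi * (upper - lower)
    return total

# hypothetical mixture with K = 2 components
pis = [0.3, 0.7]
mus = [40.0, 180.0]
scales = [5.0, 12.0]

pmf = [dlml_pmf(y, pis, mus, scales) for y in range(256)]
total_mass = sum(pmf)
print(total_mass)  # approximately 1
```

Because the edge bins absorb the tails, the masses over the 256 values form a proper distribution, and every value receives non-zero probability.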
This is nothing more than a likelihood, so we can use it in a maximum likelihood estimation (MLE) process to estimate our parameters. In the case of Deep Learning, we use the negative log likelihood as our loss function. This comment on GitHub provides a different perspective to what's going on.
This approach provides a number of advantages, many of which address the shortcomings of the softmax approach described in the previous section. In particular:
- It avoids assigning probability mass outside the valid range of [0, 255] by explicitly modeling the rounding and edge cases.
- Edge values are naturally assigned higher probability values, which tends to align with what is observed when dealing with this kind of data.
- We rely on the simple sigmoid function, which is less computationally expensive than its multi-class cousin the softmax. This addresses limitation 1 from above.
- Because we are now making use of the logistic distribution to model the (latent) value of $y$, we are implicitly also modelling ordinality when discretizing, since the logistic distribution is continuous. This addresses limitation 3 from above.
- Our reliance on a continuous distribution similarly addresses limitation 4 from above, as we no longer assign zero probability to values simply because they were never observed.
- Empirically it has been found that only a small number of mixture components, $\le 10$, is enough. This means that we can work with much lower-dimensional network outputs (3 parameters, $\mu$, $s$ and $\pi$, for each mixture component), leading to denser gradients. This addresses limitation 2 from above.
- Because we make use of a mixture, we can more easily model multi-modal data. This can be desirable when learning skills from imitation, where the same skill can be shown to be completed through different action sequences. It is exactly for this reason that Lynch et al. 2020 and Mees et al. 2022 make use of DLML in their action decoders.
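The multi-modality point can be illustrated with a small standard-library sketch. The two-component mixture below is entirely made up: its probability mass function has two separate peaks, something a single logistic could not represent, while needing only 3 numbers per component.

```python
import math

def logistic_cdf(x, mu, s):
    """CDF of the logistic distribution, written with tanh for stability."""
    return 0.5 * (1.0 + math.tanh((x - mu) / (2.0 * s)))

def mixture_pmf(y, pis, mus, scales, y_min=0, y_max=255):
    """Discretized mixture mass of integer y, edge bins absorb the tails."""
    total = 0.0
    for pi, mu, s in zip(pis, mus, scales):
        upper = 1.0 if y == y_max else logistic_cdf(y + 0.5, mu, s)
        lower = 0.0 if y == y_min else logistic_cdf(y - 0.5, mu, s)
        total += pi * (upper - lower)
    return total

# hypothetical bimodal mixture (K = 2)
pis, mus, scales = [0.5, 0.5], [50.0, 200.0], [4.0, 4.0]
pmf = [mixture_pmf(y, pis, mus, scales) for y in range(256)]

# interior values whose mass exceeds both neighbours
peaks = [y for y in range(1, 255) if pmf[y] > pmf[y - 1] and pmf[y] > pmf[y + 1]]
num_network_outputs = 3 * len(pis)  # mu, s and pi per component
print(peaks, num_network_outputs)
```

Here the network would only need to output 6 values for this variable, compared with 256 logits for a softmax over all intensities.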
So how do we actually go about implementing this? This is one of those techniques where we do slightly different things depending on whether we are training or whether we just want outputs from our model (sampling). For completeness, the full code that the snippets below are taken from is available on my GitHub.
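Earlier, we defined a likelihood for $y$. For a given output variable $y$, our model therefore predicts the parameters of each of the $K$ mixture components: the means $\mu_i$, the (log) scales $s_i$ and the logits from which the mixture weights $\pi_i$ are obtained.
import numpy as np
import torch
import torch.nn.functional as F
# each of these has shape (B x K)
means, log_scales, mixture_logits = model(**batch['inputs'])
inv_scales = torch.exp(-log_scales)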
We treat the predicted scales as `log_scales`, which keeps the actual scales positive, and exponentiate their negation to obtain `inv_scales`, i.e. $1/s_i$.
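As a tiny illustration of why predicting log-scales is convenient, here is a standard-library sketch. The values in `raw_outputs` are hypothetical unconstrained network outputs: whatever their sign, exponentiating yields a strictly positive scale, and negating before exponentiating yields its inverse without a division.

```python
import math

# hypothetical unconstrained network outputs, interpreted as log(s)
raw_outputs = [-3.0, 0.0, 2.5]

scales = [math.exp(v) for v in raw_outputs]       # s = exp(log s) > 0
inv_scales = [math.exp(-v) for v in raw_outputs]  # 1 / s = exp(-log s)

for s, inv in zip(scales, inv_scales):
    print(s, inv, s * inv)  # the product is 1 up to rounding
```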
We can then start computing the rest of the terms for our likelihood from equation (2).
y = batch['targets']
# explained in text
epsilon = (0.5*y_range) / (num_y_vals - 1)
# convenience variable
centered_y = y.unsqueeze(-1).repeat(1, 1, means.shape[-1]) - means
# inputs to our sigmoid functions
upper_bound_in = inv_scales * (centered_y + epsilon)
lower_bound_in = inv_scales * (centered_y - epsilon)
# remember: cdf of logistic distr is sigmoid of above input format
upper_cdf = torch.sigmoid(upper_bound_in)
lower_cdf = torch.sigmoid(lower_bound_in)
# finally, the probability mass and equivalent log prob
prob_mass = upper_cdf - lower_cdf
vanilla_log_prob = torch.log(torch.clamp(prob_mass, min=1e-12))
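Before I go on, you may be asking - "what is this epsilon? Weren't we adding/subtracting 0.5?" We were, but 0.5 is specific to integer targets in [0, 255]. More generally, epsilon is half the width of one discretization bin: `y_range` is simply the width of the interval the targets live in, and `num_y_vals` is the number of discrete values (the `num_classes`) for `y`. With 256 values spanning [0, 255], this gives back 0.5. We also have to adjust `y`'s shape a little bit when computing `centered_y`: remember, for each target variable, we have multiple means (one for each mixture component), while we have a single target value. We therefore need to repeat the target value for each mixture component to complete the subtraction. We now move on to the edge cases described in step 6 of the what.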
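As a quick aside before the edge cases, the epsilon term can be checked numerically with the formula from the snippet above. The helper name `half_bin_width` is made up for this sketch, and the $[-1, 1]$ range is just an assumed example of rescaled targets.

```python
def half_bin_width(y_min, y_max, num_y_vals):
    """Half the distance between two neighbouring discrete target values."""
    y_range = y_max - y_min
    return (0.5 * y_range) / (num_y_vals - 1)

# 8-bit integer targets: recovers the 0.5 from equation (2)
eps_pixels = half_bin_width(0, 255, 256)

# the same 256 values rescaled to [-1, 1] (an assumed example range)
eps_scaled = half_bin_width(-1.0, 1.0, 256)

print(eps_pixels, eps_scaled)
```

The second value is $1/255$, i.e. half of the $2/255$ spacing between neighbouring values once the targets are rescaled.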
# edges
# log probability for edge case of 0 (before scaling)
low_bound_log_prob = upper_bound_in - F.softplus(upper_bound_in)
# log probability for edge case of 255 (before scaling)
upp_bound_log_prob = - F.softplus(lower_bound_in)
# middle
mid_in = inv_scales * centered_y
log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
log_prob_mid = log_pdf_mid - np.log((num_classes - 1) / 2)
"Woah, what is going on here?" you may ask. Let's go line by line.
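In line 3, we are defining the log probability for the edge case where we have $y = 0$ (the lower edge). Here the lower bound is replaced with $-\infty$, whose sigmoid is 0, so the probability mass only depends on `upper_bound_in`.
Recall that the CDF of the logistic distribution is the sigmoid function:
$$ \sigma(x) = \frac{1}{1 + e^{-x}} = \frac{e^{x}}{1 + e^{x}}. $$
Recall that we are interested in log probabilities:
$$ \log \sigma(x) = x - \log(1 + e^{x}). $$
Finally, we can leverage the softplus function, $\text{softplus}(x) = \log(1 + e^{x})$, to write
$$ \log \sigma(x) = x - \text{softplus}(x), $$
which is the expression we use in the code, with `upper_bound_in` as $x$.
Let's now move on to line 5, where we compute the log probability for the edge case where we have $y = 255$ (the upper edge). Here the upper bound is replaced with $+\infty$, whose sigmoid is 1, so we only need `lower_bound_in` in this case.
In other words, we are interested in
$$ 1 - \sigma(x), $$
which we can take the logarithm of and then manipulate in terms of the softplus:
$$ \begin{aligned} \log(1 - \sigma(x)) &= \log \frac{1}{1 + e^{x}} \\ &= -\log(1 + e^{x}) \\ &= -\text{softplus}(x), \end{aligned} $$
where in the final line we used the identity $\text{softplus}(x) = \log(1 + e^{x})$, with `lower_bound_in` as $x$.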
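The two edge-case expressions can be checked numerically with a short standard-library sketch. The function names are made up for illustration; `softplus` here is a plain-Python stand-in for `F.softplus`.

```python
import math

def softplus(x):
    """Stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def log_sigmoid(x):
    """Lower edge: log sigma(x) = x - softplus(x)."""
    return x - softplus(x)

def log_one_minus_sigmoid(x):
    """Upper edge: log(1 - sigma(x)) = -softplus(x)."""
    return -softplus(x)

x = 3.0
naive_low = math.log(1.0 / (1.0 + math.exp(-x)))
naive_upp = math.log(1.0 - 1.0 / (1.0 + math.exp(-x)))
print(log_sigmoid(x), naive_low)
print(log_one_minus_sigmoid(x), naive_upp)

# For large inputs, 1 - sigma(x) rounds to exactly 0 in floating point,
# so the naive logarithm fails, while the softplus form stays finite.
stable_extreme = log_one_minus_sigmoid(50.0)
print(stable_extreme)
```

This is the practical reason for writing the edge cases with the softplus instead of taking the log of a sigmoid directly.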
Finally, in lines 7 through 9, we also approximate the log probability at the center of the bin, based on the assumption that the log-density is constant in the bin of the observed value: the log probability is then the log-density at the center plus the log of the bin width. Note that the `(num_classes - 1) / 2` term is the inverse bin width for targets scaled to $[-1, 1]$. This is used as a backup in cases where calculated probabilities are below 1e-5, which could happen due to numerical instability. This case is extremely rare and I would not dedicate too much thought to it; it is just there as a (rarely-used) backup.
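To see how good the bin-center approximation is, here is a minimal standard-library sketch comparing it to the exact CDF difference for one mixture component. All numbers are hypothetical, and the targets are assumed to be scaled to $[-1, 1]$ so that the bin width is $2 / (\text{num\_classes} - 1)$.

```python
import math

def softplus(x):
    """Stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def logistic_cdf(x, mu, s):
    """CDF of the logistic distribution, written with tanh for stability."""
    return 0.5 * (1.0 + math.tanh((x - mu) / (2.0 * s)))

num_classes = 256
y_range = 2.0  # assumption: targets scaled to [-1, 1]
epsilon = (0.5 * y_range) / (num_classes - 1)

# hypothetical target and component parameters
y, mean, log_scale = 0.3, 0.1, math.log(0.2)
scale = math.exp(log_scale)

# exact log probability mass of the bin around y
exact_log_prob = math.log(
    logistic_cdf(y + epsilon, mean, scale) - logistic_cdf(y - epsilon, mean, scale)
)

# approximation: log-density at the bin center plus log of the bin width
mid_in = (y - mean) / scale
log_pdf_mid = mid_in - log_scale - 2.0 * softplus(mid_in)
log_prob_mid = log_pdf_mid - math.log((num_classes - 1) / 2)

print(exact_log_prob, log_prob_mid)
```

When the scale is large relative to the bin width, as here, the two values agree closely, which is why the approximation is an acceptable fallback.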
We can now put all these terms together into a single log likelihood tensor:
# Repeat the targets for each mixture component so that every mask below
# has the same shape as the per-component log probability tensors
y_expanded = y.unsqueeze(-1).expand_as(means)
# conditions for choosing which term to use
is_near_min = y_expanded < output_min_bound + 1e-3
is_near_max = y_expanded > output_max_bound - 1e-3
is_prob_mass_sufficient = prob_mass > 1e-5
# vanilla case, falling back to the bin-center approximation in the
# extreme case where prob mass is too small
log_probs = torch.where(is_prob_mass_sufficient, vanilla_log_prob, log_prob_mid)
# lower edge
log_probs = torch.where(is_near_min, low_bound_log_prob, log_probs)
# upper edge
log_probs = torch.where(is_near_max, upp_bound_log_prob, log_probs)
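We are almost done, but there is one last piece. So far we have computed the terms to minimize for learning the distribution(s), i.e. learning $\mu_i$ and $s_i$. We also need to learn which distribution of the mixture to sample from, i.e. the weights $\pi_i$: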
# modeling which mixture to sample from
log_probs = log_probs + F.log_softmax(mixture_logits, dim=-1)
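We add a log of the softmax because we are after $\log(\pi_i \, p_i) = \log \pi_i + \log p_i$, where the weights $\pi_i$ are the softmax of the mixture logits and $p_i$ is the probability mass under the $i$th distribution.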
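All that's left to do now is summing over our mixtures. Since we are working with log probabilities, we do this with the Log-Sum-Exp trick for numerical stability:
# log-sum-exp over the K mixture components, then sum over the last remaining dimension
log_likelihood = torch.sum(torch.logsumexp(log_probs, dim=-1), dim=-1)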
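A minimal standard-library sketch of why the Log-Sum-Exp trick matters when combining the mixture components. The helpers `log_sum_exp` and `log_softmax` below are plain-Python stand-ins written for this example, and all numbers are hypothetical.

```python
import math

def log_sum_exp(values):
    """log(sum(exp(v))) computed by factoring out the largest term."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def log_softmax(logits):
    """Log of the softmax, i.e. the log mixture weights log(pi_i)."""
    lse = log_sum_exp(logits)
    return [logit - lse for logit in logits]

mixture_logits = [2.0, -1.0, 0.5]
# hypothetical per-component log probabilities of a very unlikely target
component_log_probs = [-800.0, -805.0, -790.0]

log_pis = log_softmax(mixture_logits)
joint = [lp + lpi for lp, lpi in zip(component_log_probs, log_pis)]

# naive approach: exponentiate, sum, then take the log
naive_sum = sum(math.exp(j) for j in joint)  # underflows to 0.0

# stable approach
log_likelihood = log_sum_exp(joint)
print(naive_sum, log_likelihood)
```

The naive sum underflows to zero, so its logarithm would be undefined, while the stable version returns a finite log likelihood close to the largest term.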
Our loss is the negative log likelihood, which we can choose to reduce across the batch or return unreduced:
loss = -log_likelihood
if reduction == 'mean':
    loss = torch.mean(loss)
elif reduction == 'sum':
    loss = torch.sum(loss)
And that's it for training. Once you have your loss, you can run loss.backward() and all the cool stuff torch provides with autodiff.
Sampling is fortunately a bit easier, and some people start their explanation from here. Here, we first sample a distribution from our mixture, and then sample a value from the sampled distribution. We have logits for each distribution in our mixture, so we can sample from a softmax over these logits. In practice, we make use of the Gumbel-Max trick: we add Gumbel noise to the logits and take the argmax, which draws a sample from that softmax without having to normalize it explicitly.
# each of these have shape (B x K)
means, log_scales, mixture_logits = model(**batch['inputs'])
# gumbel-max sampling
r1, r2 = 1e-5, 1.0 - 1e-5
temp = (r1 - r2) * torch.rand(means.shape, device=means.device) + r2
temp = mixture_logits - torch.log(-torch.log(temp))
argmax = torch.argmax(temp, -1)
`argmax` is the index of our sampled distribution. We can use it to get the distribution's mean and scale from our model's outputs:
# one-hot K dimensional vector, e.g. [0 0 0 1 0 0 0 0] for k=8, argmax=3
k = means.shape[-1]
dist_one_hot = torch.eye(k, device=means.device)[argmax]
# use it to select the sampled distribution's parameters
# (the sum runs over the mixture dimension, not the batch)
sampled_log_scale = (dist_one_hot * log_scales).sum(dim=-1)
sampled_mean = (dist_one_hot * means).sum(dim=-1)
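As a sanity check of the component-selection step, here is a standard-library sketch of Gumbel-Max sampling. The logits, the seed and the helper names are made up; the `r1`, `r2` clipping mirrors the snippet above. Over many draws, the sampled indices should follow the softmax of the logits.

```python
import math
import random

def gumbel_max_sample(logits, rng, r1=1e-5, r2=1.0 - 1e-5):
    """Return the index of the component chosen by the Gumbel-Max trick."""
    perturbed = []
    for logit in logits:
        # uniform sample kept away from 0 and 1 so both logs are defined
        u = (r1 - r2) * rng.random() + r2
        perturbed.append(logit - math.log(-math.log(u)))
    return max(range(len(logits)), key=perturbed.__getitem__)

def softmax(logits):
    m = max(logits)
    exps = [math.exp(logit - m) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
mixture_logits = [1.0, 0.0, -1.0]  # hypothetical logits for K = 3
num_samples = 20000

counts = [0] * len(mixture_logits)
for _ in range(num_samples):
    counts[gumbel_max_sample(mixture_logits, rng)] += 1

frequencies = [c / num_samples for c in counts]
probabilities = softmax(mixture_logits)
print(frequencies)
print(probabilities)
```

The empirical frequencies should land close to the softmax probabilities, which is all the trick is meant to guarantee.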
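We can then sample from our logistic distribution using inverse transform sampling. For a logistic distribution, this consists of computing
$$ x = \mu + s \left[\log(y) - \log(1 - y)\right], $$
(equation 4) where we select $y$ from a random uniform distribution (here $y$ is the uniform sample, not the target from before). In code: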
# scale the (0,1) uniform distribution and re-center it
y = (r1 - r2) * torch.rand(sampled_mean.shape, device=sampled_mean.device) + r2
sampled_output = sampled_mean + torch.exp(sampled_log_scale) * (
    torch.log(y) - torch.log(1 - y)
)
And just like that, we have a way of sampling from our model.
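The inverse-sampling step can also be checked in isolation with a standard-library sketch. The parameters and seed are hypothetical, and the final rounding and clipping to [0, 255] is an assumption that follows the discretization described earlier in the post, not something taken from the sampling code itself.

```python
import math
import random

def sample_logistic(mu, log_scale, rng, r1=1e-5, r2=1.0 - 1e-5):
    """Draw one sample from logistic(mu, exp(log_scale)) via its inverse CDF."""
    u = (r1 - r2) * rng.random() + r2  # uniform, kept away from 0 and 1
    return mu + math.exp(log_scale) * (math.log(u) - math.log(1.0 - u))

rng = random.Random(0)
mu, log_scale = 120.0, math.log(8.0)  # hypothetical sampled component

samples = [sample_logistic(mu, log_scale, rng) for _ in range(20000)]
sample_mean = sum(samples) / len(samples)
frac_below_mu = sum(v <= mu for v in samples) / len(samples)

# assumption: discretize by rounding and clipping to valid 8-bit values
pixels = [min(255, max(0, round(v))) for v in samples]

print(sample_mean, frac_below_mu)
```

The sample mean should sit near $\mu$ and about half of the samples should fall below it, as expected for a distribution that is symmetric around its mean.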
I hope this post was helpful. I came across DLML during my MSc thesis on language-enabled imitation learning and while there are several high quality posts online, I couldn't find a single one that summarized the process in its entirety, from motivation to implementation, so I decided to write it myself, also as a way to help me understand the concept. As a reminder, the complete code accompanying this post is available on my GitHub profile.
- Great comment on Tacotron GitHub
- Somewhat outdated Google Colab
- PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications, Salimans et al., 2017.
- Learning Latent Plans from Play, Lynch et al., 2020.
- What Matters in Language Conditioned Robotic Imitation Learning over Unstructured Data, Mees et al., 2022.
Footnotes
1. Also known as "Discretized Mixture of Logistics (DMoL)", "Discretized Logistic Mixture (DLM)" or "Mixture of Discretized Logistics (MDL)".