Latent Models and Variational Inference

Latent Variable Models

Latent variable models are a subclass of probabilistic models in machine learning. One simple motivating example, a dataset of points in three color clusters:

Say we want to model only the red data points; call them $x_{red}$. A Gaussian distribution is likely good enough: just find a mean and variance that best fit the red cluster. Formally, we want a probabilistic model $p(x_{red})$, with $p \sim N(\mu_{red}, \sigma_{red})$. However, if we instead want to model all three colors' data, the clusters are clearly individually Gaussian conditioned on color, but a single Gaussian can't fit the overall $p(x_{all})$ distribution very well. If we still want to use Gaussians to model the clusters, we can make it a "2-step process": let $z$ be a variable following some distribution over colors (it's "latent" because we can't directly obtain it from the data), then $x$ conditioned on $z$ can be a Gaussian whose parameters (mean and std) are decided by $z$. Now, if we want to know just $p(x)$, we can sum the conditional probabilities over all three colors: $p(x) = \sum_z p(x|z)\,p(z)$.

Thus the motivation for latent variable models: when it's difficult to fit data $x$ with one single distribution, we can instead model a simple latent variable $z$, then use $z$ to parametrize a simple conditional distribution over $x$. In deep learning we use neural nets so that, although $x|z$ and $z$ are simple in distribution, the mapping from $z$ to the parameters of $x|z$ is complex. Now we can write the probability of a single $x$ as an integral of the product $p(x|z)\,p(z)$ over all possible "source" $z$'s: $p(x) = \int p(x|z)\,p(z)\,dz$.
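
As a concrete sketch of the 2-step process on the three-color example (the means, stds, and prior weights below are made up for illustration):

```python
import numpy as np

# Hypothetical parameters for the three color clusters (made up for illustration).
priors = {"red": 0.3, "green": 0.4, "blue": 0.3}   # p(z)
means  = {"red": -4.0, "green": 0.0, "blue": 5.0}  # mean of p(x|z)
stds   = {"red": 1.0, "green": 0.8, "blue": 1.5}   # std of p(x|z)

def sample_x(rng):
    """The 2-step generative process: sample z ~ p(z), then x ~ N(mu_z, sigma_z)."""
    z = rng.choice(list(priors), p=list(priors.values()))
    return rng.normal(means[z], stds[z])

def gaussian_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def p_x(x):
    """Marginal p(x) = sum_z p(x|z) p(z) -- the discrete analog of the integral."""
    return sum(priors[z] * gaussian_pdf(x, means[z], stds[z]) for z in priors)

rng = np.random.default_rng(0)
print([sample_x(rng) for _ in range(5)], p_x(0.0))
```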

Now let's look at another, more RL-related example: in model-based RL, when we only get observations $o$ and don't have full information about the state (denoted below by $x$, not $s$), we can view it as modeling $p(o)$ with the latent variable $x$ and the observation model $p(o|x)$. And since our actions/controls (denoted by $u$ here) directly affect the states $x$, we'd also like to model transitions between $x$'s, thus requiring structure in the latent space.
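
In symbols (a sketch of this structure, with notation assumed rather than taken from the lecture slides), such a latent state-space model factorizes as

$$p(o_{1:T} \mid u_{1:T}) = \int p(x_1) \left[\prod_{t=1}^{T} p(o_t \mid x_t)\right] \left[\prod_{t=1}^{T-1} p(x_{t+1} \mid x_t, u_t)\right] dx_{1:T}$$

so both the observation model $p(o|x)$ and the transition model $p(x_{t+1}|x_t, u_t)$ are learned.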

Variational Approximation

Now let's discuss how exactly we can train latent variable models like those above, when we only have data for $x$. One big issue is that to obtain $p(x)$ we need an integration over $z$. The idea behind variational inference is this: ultimately we want to fit the model by maximizing $p(x_i)$ on each datapoint $x_i$; but since we use latent variables, we can start with $p(z|x)$, approximate it with a simpler distribution $q(z)$, use it to write a lower bound for $\log p(x)$, and maximize this lower bound instead.

With Jensen's inequality (which moves the log inside the expectation) and nice properties of log functions, we can take any distribution $q(z)$ and derive a lower bound for $\log p(x)$:

$$\log p(x_i) = \log \int p(x_i|z)\,p(z)\,dz = \log E_{z \sim q(z)}\left[\frac{p(x_i|z)\,p(z)}{q(z)}\right] \ge E_{z \sim q(z)}\left[\log p(x_i|z) + \log p(z)\right] - E_{z \sim q(z)}\left[\log q(z)\right]$$

As a quick review, entropy is a measure of how random a distribution is, i.e., the negative of its expected log probability under itself, defined as:

$$H(q) = -E_{z \sim q(z)}\left[\log q(z)\right] = -\int q(z)\log q(z)\,dz$$

And KL divergence measures the difference between two distributions, i.e., how small the log probability of one distribution is in expectation under the other, minus the entropy:

$$D_{KL}(q \,\|\, p) = E_{z \sim q(z)}\left[\log \frac{q(z)}{p(z)}\right] = -E_{z \sim q(z)}\left[\log p(z)\right] - H(q)$$

Now with the two definitions above, we can (1) rewrite the last term in the lower bound above as the entropy of $q$:

$$L_i(p, q_i) = E_{z \sim q_i(z)}\left[\log p(x_i|z) + \log p(z)\right] + H(q_i)$$

and (2) write $\log p(x_i)$ with an arbitrary $q_i$ and the KL divergence between $q_i$ and $p(z|x_i)$:

$$\log p(x_i) = D_{KL}\big(q_i(z)\,\|\,p(z|x_i)\big) + L_i(p, q_i)$$

Since the KL divergence is nonnegative, minimizing it brings the lower bound $L_i(p, q_i)$ closer to $\log p(x_i)$. So finally we have a clearer objective: maximize $L_i(p, q_i)$ w.r.t. $q_i$; this both increases (a lower bound on) $\log p(x_i)$ and minimizes the divergence between $q_i$ and $p(z|x_i)$. Note the subscript $i$: each $x_i$ is a datapoint we have, and fitting the model means fitting these points point-wise.
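
To make the identity above concrete, here's a tiny discrete-$z$ sanity check (all numbers made up): with a discrete latent the integral becomes a sum, and $\log p(x_i) = L_i(p, q_i) + D_{KL}(q_i \| p(z|x_i))$ holds for any $q_i$:

```python
import numpy as np

# Toy discrete model: z in {0,1,2}, one fixed datapoint x_i. All numbers are made up.
p_z = np.array([0.3, 0.4, 0.3])             # prior p(z)
p_x_given_z = np.array([0.05, 0.60, 0.10])  # likelihood p(x_i|z) for this x_i
q = np.array([0.2, 0.5, 0.3])               # an arbitrary approximate posterior q(z)

p_x = np.sum(p_x_given_z * p_z)             # marginal p(x_i) = sum_z p(x_i|z) p(z)
p_z_given_x = p_x_given_z * p_z / p_x       # exact posterior p(z|x_i)

elbo = np.sum(q * (np.log(p_x_given_z) + np.log(p_z))) - np.sum(q * np.log(q))
kl = np.sum(q * (np.log(q) - np.log(p_z_given_x)))

# log p(x_i) = L(p, q) + KL(q || p(z|x_i)) -- the two printed numbers match
print(np.log(p_x), elbo + kl)
```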

Variational Inference: Standard vs. Amortized

Now that we have a new objective, the standard variational inference method works as shown below. It learns the parameters $\theta$ for the conditional distribution $p(x|z)$ by doing gradient ascent with an approximate gradient. Because $p_\theta$ is a conditional probability, we first need to sample a $z_i$ from $q$; but because $q$ approximates the conditional probability $p(z|x)$, we need to use $x_i$ to sample $z_i$ from $q$. This is why the "sample z" line can be a little confusing: because $q$ is conditioned, we have a different distribution $q_i$ for each datapoint $x_i$. And this is why, to also learn the proper $q_i$, we need to gradient-update each $q_i$'s parameters. In practice, what we can do is assume each $q_i$ is Gaussian and parametrize it with a mean and a variance, so we can do gradient ascent on those two values instead.
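
Here is a minimal runnable sketch of this per-datapoint procedure, assuming a toy linear-Gaussian model in which the ELBO gradients happen to be closed-form (the lecture's version estimates them with sampled $z$'s instead):

```python
import numpy as np

# Toy model (an assumption for illustration, not the lecture's setup):
#   p(z) = N(0, 1),  p_theta(x|z) = N(theta * z, 1),  q_i(z) = N(mu_i, s_i^2).
rng = np.random.default_rng(0)
N, theta_true = 500, 2.0
data = theta_true * rng.normal(size=N) + rng.normal(size=N)

theta = 0.5
mu = np.zeros(N)  # one variational mean per datapoint ...
s = np.ones(N)    # ... and one std per datapoint -- the part that scales badly

for epoch in range(300):
    # gradient-ascend each L_i w.r.t. q_i's parameters (vectorized over i)
    for _ in range(5):
        mu += 0.1 * (theta * (data - theta * mu) - mu)
        s += 0.1 * (-(theta ** 2) * s - s + 1.0 / s)
    # gradient-ascend the average ELBO w.r.t. theta
    theta += 0.05 * np.mean((data - theta * mu) * mu - theta * s ** 2)

print(theta)  # drifts toward theta_true as the ELBO is maximized
```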

However, a big problem with the above method is that it requires a pair of mean and std for each datapoint $x_i$, which means the model will scale monstrously. Amortized variational inference tries to solve this problem by learning a single neural network for $q$ instead, so now we have a second set of parameters $\phi$ to be learned. Instead of keeping a separate $q_i$, we sample $z$'s by feeding $x_i$ into this $q$ network. Now, to actually train $p$ and $q$, there's one more catch called the reparametrization trick, meant to reduce the variance in the gradients for $\phi$. We can actually draw some equivalence between the variational lower bound and policy gradient: the naive gradient estimator for $\phi$, obtained with the log-derivative trick, has exactly the form of a policy gradient, and suffers the same high variance.

What reparametrization means is: we don't directly sample a value for $z$, but instead obtain a mean and std from the $q$ network, and use a standard normal variable $\epsilon$ to compute $z = \mu_\phi(x) + \sigma_\phi(x)\,\epsilon$. This makes $z$ a deterministic, differentiable function of $\phi$, and the resulting gradient estimate is much lower in variance.
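
A minimal PyTorch sketch (the names and the toy objective are illustrative): the sample is rewritten as a deterministic function of $(\mu, \sigma)$ plus parameter-free noise, so autograd can differentiate through it:

```python
import torch

# Instead of sampling z ~ N(mu, sigma^2) directly (which blocks gradients),
# sample eps ~ N(0, 1) and compute z deterministically from (mu, sigma, eps).
mu = torch.tensor(0.5, requires_grad=True)       # would come from the q network
log_sigma = torch.tensor(0.0, requires_grad=True)

eps = torch.randn(1000)                          # standard normal, no parameters
z = mu + torch.exp(log_sigma) * eps              # z ~ N(mu, sigma^2), differentiable

# Toy objective: E_z[z^2]. Gradients now flow through z back to mu and sigma.
loss = (z ** 2).mean()
loss.backward()
print(mu.grad, log_sigma.grad)  # ~ 2*mu = 1.0 and ~ 2*sigma^2 = 2.0
```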

Variational Auto-encoder (VAE)

First, let's look at more math that gives yet another view of our objective $L$. After reparametrization, we can write the expectation as over the standard normal variable $\epsilon$, splitting the bound into a reconstruction term and a KL term:

$$L_i = E_{\epsilon \sim N(0, I)}\left[\log p_\theta\big(x_i \mid \mu_\phi(x_i) + \epsilon\,\sigma_\phi(x_i)\big)\right] - D_{KL}\big(q_\phi(z|x_i)\,\|\,p(z)\big)$$

and approximate the expectation with $\log p_\theta(x_i \mid \text{sampled } z)$ for a single sampled $z$. So doing gradient ascent on this objective is really both maximizing the network's ability to recover the data $x_i$ (using the $p_\theta$ network) and pushing $q_\phi(z|x_i)$ toward the actual prior $p(z)$.

Naming the $q$ network discussed above the encoder, and the $p$ network the decoder, we basically have a VAE that's capable of generating new data. If the training data points are images denoted by $x$, then when training is done, we can generate a new image $x'$ very similar to our training data with a 2-step sampling: first sample a $z$ from the prior $p(z)$ (which the encoder's outputs were trained to stay close to), then use that $z$ to get an $x'$ from the learned decoder $p$.
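
Here's a compact PyTorch sketch of the whole pipeline, assuming a Gaussian encoder, a fixed-variance Gaussian decoder, and made-up layer sizes; it's an illustration of the idea, not the lecture's exact setup:

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=16, hidden=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 2 * z_dim))  # -> (mu, log_sigma)
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, x_dim))
        self.z_dim = z_dim

    def forward(self, x):
        mu, log_sigma = self.enc(x).chunk(2, dim=-1)
        z = mu + torch.exp(log_sigma) * torch.randn_like(mu)  # reparametrized sample
        x_recon = self.dec(z)
        # Reconstruction term: log p(x|z) for a fixed-variance Gaussian decoder
        recon = -0.5 * ((x - x_recon) ** 2).sum(-1)
        # KL(q(z|x) || N(0, I)), closed form for diagonal Gaussians
        kl = 0.5 * (torch.exp(2 * log_sigma) + mu ** 2 - 1 - 2 * log_sigma).sum(-1)
        return -(recon - kl).mean()  # negative ELBO, to be minimized

model = VAE()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(64, 784)  # stand-in batch; real image data would go here
opt.zero_grad()
loss = model(x)
loss.backward()
opt.step()

# Generation after training: sample z from the prior, then decode.
with torch.no_grad():
    z = torch.randn(1, model.z_dim)
    x_new = model.dec(z)
```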

All included screenshots are credited to Lecture 13 (fall18 recording).
