Latent Models and Variational Inference
Latent variable models are a subclass of probabilistic models in machine learning. Here is one simple motivating example:
As a quick review, entropy is a measure of how random a distribution is, i.e. the negative of its log probability in expectation under itself, defined as:

$\mathcal{H}(p) = -E_{x \sim p(x)}[\log p(x)]$
And KL-divergence measures the difference between two distributions, i.e. how small the log probability of one distribution is in expectation under the other, minus the latter's entropy:

$D_{KL}(q \,\|\, p) = E_{z \sim q(z)}\left[\log \frac{q(z)}{p(z)}\right] = -E_{z \sim q(z)}[\log p(z)] - \mathcal{H}(q)$
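As a quick sanity check, here is a minimal numerical sketch of these two definitions for discrete distributions (the array values below are made up purely for illustration):

```python
import numpy as np

def entropy(p):
    # H(p) = -E_{x~p}[log p(x)] for a discrete distribution
    return -np.sum(p * np.log(p))

def kl_divergence(q, p):
    # D_KL(q || p) = E_{z~q}[log q(z) - log p(z)]
    #             = -E_{z~q}[log p(z)] - H(q)
    return np.sum(q * np.log(q / p))

q = np.array([0.7, 0.2, 0.1])
p = np.array([1/3, 1/3, 1/3])
print(entropy(q))           # lower than entropy(p): q is less random
print(kl_divergence(q, p))  # nonnegative, zero only when q == p
```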
All included screenshots are credited to Lecture 13 (Fall 2018 recording).
Say we want to model only the red data points; call that distribution $p(x)$. A Gaussian is likely good enough: just find a mean and variance that best fit the red cluster. Formally, we want a probabilistic model $p(x)$, and $p(x) = \mathcal{N}(\mu, \sigma)$. However, if we instead want to model *all three colors'* data, the points are clearly Gaussian under a color-dependent condition, but a single Gaussian can't fit the overall distribution very well. If we still want to use Gaussians to model the clusters, we can make it a "2-step process": let $z$ be a variable following some distribution over colors (it's "latent" because we can't directly obtain it from the data), then $x$ conditioned on $z$ can be a Gaussian whose parameters (mean and std) are decided by $z$. Now, if we want to know just $p(x)$, we can sum over the conditional probabilities under all three colors.
Thus the motivation for latent variable models: when it's difficult to fit data $x$ with one single distribution, we can opt to model a simple latent variable $z$, then use $z$ to fit a simple conditional distribution over $x$. In deep learning we use neural nets so that, although $x|z$ and $z$ are simple in distribution, the way $z$ parametrizes $x$'s distribution can be complex. Now we can write the probability of a single $x$ as an integral of the product $p(x|z)p(z)$ over all possible "source" $z$'s:

$p(x) = \int p(x|z)\, p(z)\, dz$
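For intuition, here is a minimal sketch of this 2-step sampling process for the three-cluster example (the mixture weights, means, and stds below are placeholder values, not from the lecture):

```python
import numpy as np

rng = np.random.default_rng(0)

# p(z): a categorical "color" distribution over the three clusters
weights = np.array([0.3, 0.4, 0.3])   # placeholder mixture weights
means   = np.array([-4.0, 0.0, 4.0])  # placeholder per-cluster means
stds    = np.array([1.0, 0.5, 1.0])   # placeholder per-cluster stds

def sample_x():
    # step 1: sample the latent color z ~ p(z)
    z = rng.choice(3, p=weights)
    # step 2: sample x ~ p(x|z) = N(mu_z, sigma_z)
    return rng.normal(means[z], stds[z])

def p_x(x):
    # marginal p(x) = sum_z p(x|z) p(z), the "sum over all three colors"
    def normal_pdf(v, mu, s):
        return np.exp(-0.5 * ((v - mu) / s) ** 2) / (s * np.sqrt(2 * np.pi))
    return sum(w * normal_pdf(x, m, s) for w, m, s in zip(weights, means, stds))

samples = np.array([sample_x() for _ in range(5)])
print(samples, p_x(0.0))
```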
Now we can look at another, more RL-related example: in model-based RL, when we only get observations $o$ and don't have full information about the state (below denoted by $x$, not $s$), we can view it as modeling $p(o)$ with the latent state $x$ and an observation model $p(o|x)$. And since our action/control (denoted by $u$ here) has a direct impact on the states $x$, we'd also like to model transitions between $x$'s, thus requiring structure in the latent space.
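Concretely, one common way to write this structured latent model is the standard state-space factorization sketched below (timesteps made explicit; this is a reconstruction, not copied from the slide):

```latex
% Latent state-space model: states x_t are latent, observations o_t are seen,
% and controls u_t influence the transitions between latent states.
p(o_{1:T} \mid u_{1:T-1})
  = \int p(x_1) \prod_{t=1}^{T-1} p(x_{t+1} \mid x_t, u_t)
        \prod_{t=1}^{T} p(o_t \mid x_t) \, dx_{1:T}
```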
Now let's discuss how exactly we can train latent variable models like those above, when we only have data drawn from $p(x)$. One big issue is that to obtain $p(x)$ we need an integration over $z$. The idea behind variational inference is this: ultimately we want to model $p(x)$ by maximizing $\log p(x_i)$ on each datapoint $x_i$, but because the latent variable makes this intractable, we can start with the posterior $p(z|x_i)$, approximate it with a simpler distribution $q_i(z)$, use it to write a lower bound for $\log p(x_i)$, and maximize this lower bound instead.
With Jensen's inequality (which moves the log inside the expectation) and nice properties of the log function, we can take any distribution $q(z)$ and derive a lower bound for $\log p(x_i)$:
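Written out (a reconstruction of the standard derivation the screenshot showed):

```latex
\log p(x_i)
  = \log \int p(x_i \mid z)\, p(z)\, dz
  = \log \int p(x_i \mid z)\, p(z)\, \frac{q(z)}{q(z)}\, dz
  = \log E_{z \sim q(z)}\!\left[\frac{p(x_i \mid z)\, p(z)}{q(z)}\right]
  % Jensen's inequality: log E[...] >= E[log ...]
  \ge E_{z \sim q(z)}\!\left[\log \frac{p(x_i \mid z)\, p(z)}{q(z)}\right]
  = E_{z \sim q(z)}\big[\log p(x_i \mid z) + \log p(z)\big] + \mathcal{H}(q)
```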
Now with the two definitions above, we can 1. rewrite the last term in the lower bound above as the entropy of $q$ (the first equation below), and 2. write $\log p(x_i)$ in terms of an arbitrary $q$ plus the KL-divergence between $q$ and the true posterior $p(z|x_i)$ (the second equation below). Since KL-divergence is nonnegative, minimizing it brings $\log p(x_i)$ closer to its lower bound $\mathcal{L}_i(p, q_i)$. So finally we have a clearer objective: maximize $\mathcal{L}_i(p, q_i)$, which both pushes up $\log p(x_i)$ (through $p$) and minimizes the divergence between $q_i(z)$ and $p(z|x_i)$ (through $q_i$). Note the subscript $i$: each $x_i$ is a datapoint we have, and fitting the model means fitting these points point-wise.
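The two equations referenced above, reconstructed from the entropy and KL definitions given earlier:

```latex
% 1) the lower bound with the entropy term made explicit:
\mathcal{L}_i(p, q_i) = E_{z \sim q_i(z)}\big[\log p(x_i \mid z) + \log p(z)\big] + \mathcal{H}(q_i)

% 2) the exact decomposition of the log-likelihood:
\log p(x_i) = D_{KL}\big(q_i(z) \,\|\, p(z \mid x_i)\big) + \mathcal{L}_i(p, q_i)
```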
Now that we have a new objective, the standard variational inference method works as shown below. It learns the parameters $\theta$ of the conditional distribution $p_\theta(x|z)$ by doing gradient ascent on an approximate gradient. Because $\log p_\theta(x_i|z)$ is a conditional probability, we first need to sample a $z$; and because $q_i(z)$ approximates the conditional $p(z|x_i)$, we need to sample that $z$ from the $q_i$ belonging to datapoint $x_i$. This is why the "sample z" line can be a little confusing: because $q$ is conditioned on the datapoint, we have a different distribution $q_i$ for each $x_i$. And this is why, to also learn the proper $q_i$'s, we need the gradient $\nabla_{q_i} \mathcal{L}_i$ to update each $q_i$'s parameters. In practice, what we can do is assume each $q_i$ is Gaussian and parametrize it with a mean $\mu_i$ and a variance $\sigma_i$, so we can do gradient ascent on those two values instead.
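Below is a toy sketch of this loop. The lecture's algorithm estimates the inner expectation by sampling $z$; in this deliberately tiny linear-Gaussian model (prior $p(z) = \mathcal{N}(0,1)$, model $p_\theta(x|z) = \mathcal{N}(\theta z, 1)$, an assumption chosen only so the ELBO gradients have closed forms), the gradients are written out exactly. All specifics here are illustrative, not from the lecture:

```python
import numpy as np

# Non-amortized VI: EVERY datapoint x_i gets its own q_i(z) = N(mu[i], sigma[i]^2).
rng = np.random.default_rng(0)
x = rng.normal(0.0, np.sqrt(2.0), size=50)  # synthetic data (true theta = 1)

theta = 0.5
mu = np.zeros(len(x))     # per-datapoint variational means
sigma = np.ones(len(x))   # per-datapoint variational stds
lr = 0.01

for step in range(2000):
    i = rng.integers(len(x))  # pick a datapoint x_i
    # Closed-form gradients of L_i = E_q[log p(x_i|z) + log p(z)] + H(q_i):
    theta += lr * (x[i] * mu[i] - theta * (mu[i] ** 2 + sigma[i] ** 2))
    mu[i] += lr * (theta * x[i] - (theta ** 2 + 1) * mu[i])
    sigma[i] += lr * (1.0 / sigma[i] - (theta ** 2 + 1) * sigma[i])

print(theta)  # should drift toward the data-fitting value (about 1 here)
```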
However, a big problem with the above method is that it requires a separate (mean, std) pair for every datapoint $x_i$, which means the model will scale monstrously with the dataset. Amortized variational inference tries to solve this problem by learning a single neural network $q_\phi(z|x)$ instead, so now we have a second set of parameters $\phi$ to be learned. Instead of maintaining a separate $q_i$, we now sample $z$'s by feeding $x_i$ into this network. Now, to actually train $\theta$ and $\phi$, there's one more catch called the reparametrization trick, and it's meant to reduce the variance in the gradients w.r.t. $\phi$. We can actually draw some equivalence between the variational lower bound and policy gradient:
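The analogy is that the ELBO's expectation term has the same form as a policy-gradient objective; a sketch of the score-function estimator, where $r(x_i, z) = \log p_\theta(x_i \mid z) + \log p(z)$ plays the role of the "reward" (notation reconstructed, not copied from the slide):

```latex
\nabla_\phi \, E_{z \sim q_\phi(z \mid x_i)}\big[r(x_i, z)\big]
  \approx \frac{1}{M} \sum_{j=1}^{M} \nabla_\phi \log q_\phi(z_j \mid x_i)\, r(x_i, z_j),
\qquad z_j \sim q_\phi(z \mid x_i)
```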
What reparametrization means is that we don't directly sample a value for $z$; instead we obtain a mean $\mu$ and std $\sigma$ from the network, draw a standard normal variable $\epsilon \sim \mathcal{N}(0, 1)$, and calculate $z = \mu + \sigma\epsilon$. This alternative gradient estimator has much lower variance.
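A minimal sketch of why this helps: the sample becomes a deterministic transform of a fixed-distribution $\epsilon$, so gradients flow straight through $\mu$ and $\sigma$ (the tensor values and the dummy objective below are made up):

```python
import torch

# Hypothetical encoder outputs for one datapoint (values are placeholders)
mu = torch.tensor([0.3, -1.2], requires_grad=True)
log_sigma = torch.tensor([0.0, -0.5], requires_grad=True)

# Reparametrization trick: sample eps from a fixed N(0, I), then
# deterministically transform it, so z is differentiable w.r.t. mu and sigma.
eps = torch.randn(2)
z = mu + torch.exp(log_sigma) * eps

# Any downstream objective (e.g. log p_theta(x|z)) now backpropagates into
# the encoder parameters; here we use a dummy objective for illustration.
loss = (z ** 2).sum()
loss.backward()
print(mu.grad, log_sigma.grad)
```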
First, let's look at a bit more math that gives yet another view of our objective $\mathcal{L}_i$. After reparametrization, we can write the objective as an expectation over the standard normal variable $\epsilon$, and approximate it with a single sample of $\epsilon$. So doing gradient ascent on this objective really is both maximizing the network's ability to recover the data (through the $p_\theta(x|z)$ network) and pushing $q_\phi(z|x)$ to approximate the actual prior $p(z)$ (through the $\phi$ network).
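Written out (a reconstruction in the notation used so far; the single-sample approximation replaces the expectation over $\epsilon$):

```latex
\mathcal{L}_i
  = E_{z \sim q_\phi(z \mid x_i)}\big[\log p_\theta(x_i \mid z)\big]
    - D_{KL}\big(q_\phi(z \mid x_i) \,\|\, p(z)\big)
  \approx \log p_\theta\big(x_i \mid \mu_\phi(x_i) + \epsilon\,\sigma_\phi(x_i)\big)
    - D_{KL}\big(q_\phi(z \mid x_i) \,\|\, p(z)\big),
\qquad \epsilon \sim \mathcal{N}(0, I)
```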
Naming the above-discussed $q_\phi$ network the encoder and the $p_\theta$ network the decoder, we basically have a VAE that's capable of generating new data. If the training data points are images denoted by $x_i$, then once training is done, we can generate a new image very similar to our training data with a 2-step sampling: first sample a $z$ from the prior $p(z)$ (which the KL term has trained $q_\phi$ to match), then use that $z$ to get a new $x$ from the learned decoder $p_\theta(x|z)$.
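A minimal sketch of that 2-step generation, with a placeholder (untrained) decoder architecture standing in for the learned one; the layer sizes and 28x28 image shape are assumptions for illustration:

```python
import torch

# Placeholder decoder: maps z -> parameters of p_theta(x|z).
# In practice this would be the network trained with the ELBO above.
decoder = torch.nn.Sequential(
    torch.nn.Linear(2, 128), torch.nn.ReLU(),
    torch.nn.Linear(128, 784),
    torch.nn.Sigmoid(),  # e.g. per-pixel Bernoulli means for 28x28 images
)

z = torch.randn(1, 2)              # step 1: sample z ~ p(z) = N(0, I)
x_mean = decoder(z)                # step 2: decode z into image space
new_image = x_mean.reshape(28, 28)
```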