Latent Models and Variational Inference
Latent Variable Models
A latent variable model is a subclass of probabilistic models in machine learning. Here is one simple motivating example:
(Figure: a scatter plot of data points forming three clusters, colored red, green, and blue.)
Say we want to model only the red data points, call it x_red; a Gaussian distribution is likely good enough: just find a mean and variance that best fit the red cluster. Formally, we want a probabilistic model p(x_red), with p ∼ N(μ_red, σ_red). However, if we instead want to model all three colors' data, each cluster is clearly Gaussian conditioned on its color, but a single Gaussian can't fit the overall p(x_all) distribution very well. If we still want to use Gaussians to model the clusters, we can make it a "2-step process": let z be a variable following some distribution over colors (it's "latent" because we can't directly observe it in the data), then x conditioned on z can be a Gaussian whose parameters (mean and std) are decided by z. Now, if we want to know just p(x), we can sum the conditional probabilities over all three colors.
Thus the motivation for latent variable models: when it's difficult to fit data x with one single distribution, we can instead model a simple latent variable z, then use z to parametrize a simple conditional distribution on x. In deep learning we use neural nets so that, although x∣z and z each have simple distributions, the mapping from z to the parameters of x∣z can be arbitrarily complex. Now we can write the probability of a single x as an integral of the product p(x∣z) p(z) over all possible "source" z's:

$$p(x) = \int p(x \mid z)\, p(z)\, dz$$
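As a minimal sketch of this 2-step process (in 1-D; the cluster weights, means, and stds below are made-up illustrative values, not from the lecture):

```python
# A mixture of three Gaussians: p(x) = sum_z p(x|z) p(z).
import numpy as np
from scipy.stats import norm

p_z = np.array([0.3, 0.4, 0.3])      # p(z): how likely each color is
mus = np.array([-4.0, 0.0, 5.0])     # mean of p(x|z) for each color
sigmas = np.array([1.0, 0.7, 1.5])   # std of p(x|z) for each color

def p_x(x):
    """p(x): sum the conditional probabilities over all three colors."""
    return sum(p_z[k] * norm.pdf(x, mus[k], sigmas[k]) for k in range(3))

def sample_x():
    """2-step sampling: draw a color z first, then draw x from p(x|z)."""
    z = np.random.choice(3, p=p_z)
    return np.random.normal(mus[z], sigmas[z])
```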

Now we can look at another, more RL-related example: in model-based RL, when we only get observations o and don't have full information about the state (here denoted by x, not s), we can view it as modeling p(o) with the latent variable x and the observation model p(o∣x). And since our actions/controls (denoted by u here) have a direct impact on the states x, we'd also like to model transitions between x's, which requires structure in the latent space.
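Written out, one standard way to put that structure in the latent space (consistent with the notation above) is a state-space factorization:

$$p(o_{1:T} \mid u_{1:T}) = \int p(x_1) \left[\prod_{t=1}^{T-1} p(x_{t+1} \mid x_t, u_t)\right] \left[\prod_{t=1}^{T} p(o_t \mid x_t)\right] dx_{1:T}$$

where p(o_t∣x_t) is the observation model and p(x_{t+1}∣x_t, u_t) is the transition model.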

Variational Approximation
Now let's discuss how exactly we can train latent variable models like those above, when we only have data from x. One big issue is that obtaining p(x) requires an integration over z. The idea behind variational inference is this: ultimately we want to fit p(x) by maximizing log p(x_i) on each datapoint x_i, but because of the latent variables we can instead start with the posterior p(z∣x), approximate it with a simpler distribution q(z), use q to write a lower bound for log p(x), and maximize this lower bound instead.

With Jensen's inequality (which moves the log inside the expectation) and the nice properties of log functions, we can take any distribution q(z) and derive a lower bound for log p(x):
$$\log p(x) = \log \int p(x \mid z)\, p(z)\, dz = \log \mathbb{E}_{z \sim q}\!\left[\frac{p(x \mid z)\, p(z)}{q(z)}\right] \;\ge\; \mathbb{E}_{z \sim q}\!\left[\log p(x \mid z) + \log p(z)\right] - \mathbb{E}_{z \sim q}\!\left[\log q(z)\right]$$
As a quick review, entropy measures how random a distribution is, i.e. the negated log probability of the distribution in expectation under itself, defined as:
$$\mathcal{H}(q) = -\mathbb{E}_{z \sim q}\!\left[\log q(z)\right]$$
And KL-divergence measures the difference between two distributions, or how small the log probability of one distribution is in expectation under the other, minus the entropy:
$$D_{\mathrm{KL}}(q \,\|\, p) = \mathbb{E}_{z \sim q}\!\left[\log \frac{q(z)}{p(z)}\right] = -\mathbb{E}_{z \sim q}\!\left[\log p(z)\right] - \mathcal{H}(q)$$
Now with the two definitions above, we can 1. rewrite the last term in the lower bound as the entropy of q (the first equation below), and 2. write log p(x) in terms of an arbitrary q and the KL-divergence between q and p(z∣x) (the second equation below). Since KL-divergence is nonnegative, L(p,q) really is a lower bound; maximizing L(p,q) w.r.t. q minimizes the divergence between q and p(z∣x), bringing the bound closer to log p(x), while maximizing it w.r.t. p pushes up log p(x) itself. So we finally have a clearer objective: maximize L(p,q). Note the subscript i: each x_i is a datapoint we have, and fitting the model means we fit these points point-wise.
$$\mathcal{L}_i(p, q_i) = \mathbb{E}_{z \sim q_i}\!\left[\log p(x_i \mid z) + \log p(z)\right] + \mathcal{H}(q_i)$$

$$\log p(x_i) = D_{\mathrm{KL}}\!\left(q_i(z) \,\|\, p(z \mid x_i)\right) + \mathcal{L}_i(p, q_i)$$
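To make the bound concrete, here is a minimal NumPy sketch that estimates L_i by Monte Carlo, assuming a 1-D Gaussian q_i; `log_p_x_given_z` and `log_p_z` are hypothetical user-supplied log-densities, not anything from the lecture:

```python
import numpy as np

def elbo(x_i, q_mu, q_sigma, log_p_x_given_z, log_p_z, n_samples=1000):
    """Monte Carlo estimate of L_i = E_{z~q_i}[log p(x_i|z) + log p(z)] + H(q_i)."""
    z = np.random.normal(q_mu, q_sigma, size=n_samples)        # z ~ q_i
    # (assumes both log-density functions are vectorized over z)
    expectation = np.mean(log_p_x_given_z(x_i, z) + log_p_z(z))
    entropy = 0.5 * np.log(2 * np.pi * np.e * q_sigma ** 2)    # closed-form Gaussian entropy
    return expectation + entropy
```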
Variational Inference: Standard vs. Amortized
Now that we have a new objective, the standard variational inference method works as shown below. It learns the parameters θ of the conditional distribution p(x∣z) by gradient ascent on an approximate (sampled) gradient. Because p_θ is a conditional probability, we first need to sample a z_i from q; but because q approximates the conditional p(z∣x), we end up with a different distribution q_i for each datapoint x_i, and we sample z_i from that q_i. This is why the "sample z" step can be a little confusing, and it is also why, to learn a proper q_i, we need to gradient-update each q_i's parameters separately. In practice, we can assume each q_i is Gaussian and parametrize it with a mean and a variance, then do gradient ascent on those two values.
For each datapoint x_i (or minibatch):
1. sample z ∼ q_i(z);
2. estimate ∇_θ L_i ≈ ∇_θ log p_θ(x_i ∣ z);
3. take a gradient step θ ← θ + α ∇_θ L_i;
4. update q_i's parameters (μ_i, σ_i) by gradient ascent on L_i.
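Here is a minimal PyTorch sketch of this procedure on a toy problem. All choices are illustrative assumptions (2-D data, 1-D z, a unit-variance Gaussian decoder, a standard normal prior), and the q_i updates already use the reparametrized sampling discussed in the next paragraph:

```python
import torch

X = torch.randn(100, 2)                                   # toy dataset of 100 datapoints x_i
decoder = torch.nn.Linear(1, 2)                           # mean of p_theta(x|z), with z 1-dimensional
mu = torch.zeros(100, 1, requires_grad=True)              # one mean per datapoint
log_sigma = torch.zeros(100, 1, requires_grad=True)       # one log-std per datapoint
opt = torch.optim.Adam([*decoder.parameters(), mu, log_sigma], lr=1e-2)

for step in range(1000):
    eps = torch.randn_like(mu)
    z = mu + eps * log_sigma.exp()                        # z_i ~ q_i (reparametrized sample)
    log_p_x_given_z = -0.5 * ((X - decoder(z)) ** 2).sum(-1)
    log_p_z = -0.5 * (z ** 2).sum(-1)                     # prior p(z) = N(0, 1), up to constants
    entropy = log_sigma.sum(-1)                           # H(q_i) for a Gaussian, up to constants
    elbo = (log_p_x_given_z + log_p_z + entropy).mean()   # average of L_i over datapoints
    opt.zero_grad()
    (-elbo).backward()
    opt.step()
```

Note the scaling problem already visible here: `mu` and `log_sigma` each have one row per datapoint.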
However, a big problem with the above method is that it requires a separate mean and std for each datapoint x_i, so the number of parameters scales with the dataset size. Amortized variational inference solves this by learning a single neural network for q instead, giving us a second set of parameters φ to be learned. Instead of maintaining separate q_i's, we now sample z's by feeding x_i into this q network. Now, to actually train p and q, there's one more catch called the reparametrization trick, which is meant to reduce the variance of the gradients w.r.t. φ. To see why it's needed, we can draw an equivalence between the variational lower bound and policy gradient:
Treating r(x_i, z) = log p_θ(x_i∣z) + log p(z) as a "reward", the expectation part of the bound is J(φ) = E_{z∼q_φ(z∣x_i)}[r(x_i, z)], which admits the familiar score-function (REINFORCE) estimator:

$$\nabla_\phi J \approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\phi \log q_\phi(z_m \mid x_i)\, r(x_i, z_m), \qquad z_m \sim q_\phi(z \mid x_i)$$

Just like policy gradient, this estimator is unbiased but high-variance.
What reparametrization means is that we don't directly sample a value for z; instead we obtain a mean and std from the q network and use a standard normal variable ϵ to compute z:
$$z = \mu_\phi(x_i) + \epsilon\, \sigma_\phi(x_i), \qquad \epsilon \sim \mathcal{N}(0, I)$$

Since the randomness now lives entirely in ϵ, gradients can flow through μ_φ and σ_φ directly, and this alternative gradient estimate is much lower in variance.
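A minimal sketch contrasting the two estimators, with a hypothetical stand-in reward r(z) in place of log p_θ(x_i∣z) + log p(z):

```python
import torch

mu = torch.tensor([0.5], requires_grad=True)
log_sigma = torch.tensor([0.0], requires_grad=True)

def r(z):                     # hypothetical "reward", standing in for log p(x|z) + log p(z)
    return -(z - 2.0) ** 2

# Score-function (REINFORCE) estimator: differentiate log q(z) only.
z = torch.normal(mu.detach(), log_sigma.exp().detach())    # sample without tracking gradients
log_q = -0.5 * ((z - mu) / log_sigma.exp()) ** 2 - log_sigma
(log_q * r(z).detach()).sum().backward()                   # high-variance gradient estimate

mu.grad = None
log_sigma.grad = None

# Reparametrized estimator: differentiate through the sample itself.
eps = torch.randn(1)
z = mu + eps * log_sigma.exp()
r(z).sum().backward()                                      # much lower-variance gradient
```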
Variational Auto-encoder (VAE)
First, let's look at a bit more math that gives yet another view of our objective L. After reparametrization, we can rewrite the expectation as being over the standard normal variable ϵ, approximate it with a single sample of ϵ, and collect the prior terms into a KL-divergence (below). Doing gradient ascent on this objective then does two things at once: it maximizes the network's ability to reconstruct the data x_i (through the p_θ network), and it keeps the q_φ network's output close to the actual prior p(z).
$$\mathcal{L}_i = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}\!\left[\log p_\theta\!\left(x_i \mid \mu_\phi(x_i) + \epsilon\, \sigma_\phi(x_i)\right)\right] - D_{\mathrm{KL}}\!\left(q_\phi(z \mid x_i)\,\|\, p(z)\right)$$

The expectation is approximated with a single sampled ϵ, and when both q_φ and p(z) are Gaussian, the KL term has a closed form.
Naming the q network the encoder and the p network the decoder, we basically have a VAE capable of generating new data. If the training datapoints are images denoted by x, then once training is done, we can generate a new image x′ similar to our training data with a 2-step sampling: first sample a z from the prior p(z) (a standard normal, which q was pushed toward by the KL term above), then feed that z through the learned decoder p to obtain x′.
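Putting everything together, here is a minimal VAE sketch in PyTorch; the MLP architecture, layer sizes, and the unit-variance Gaussian decoder (which turns −log p(x∣z) into a squared error) are illustrative assumptions, not the lecture's exact setup:

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=8, hidden=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 2 * z_dim))   # outputs mu and log_sigma
        self.dec = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, x_dim))

    def loss(self, x):
        mu, log_sigma = self.enc(x).chunk(2, dim=-1)
        eps = torch.randn_like(mu)
        z = mu + eps * log_sigma.exp()                   # reparametrization trick
        recon = 0.5 * ((x - self.dec(z)) ** 2).sum(-1)   # -log p(x|z), up to a constant
        # Closed-form KL( N(mu, sigma^2) || N(0, I) ):
        kl = 0.5 * (mu ** 2 + (2 * log_sigma).exp() - 2 * log_sigma - 1).sum(-1)
        return (recon + kl).mean()                       # negative ELBO, averaged over the batch

    def generate(self, n):
        z = torch.randn(n, self.dec[0].in_features)      # step 1: z ~ p(z) = N(0, I)
        return self.dec(z)                               # step 2: decode z into x'
```

Training just loops `vae.loss(x).backward()` plus an optimizer step over minibatches of x; `vae.generate` implements the 2-step sampling described above.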

All equations and figures adapted from Lecture 13 (fall18 recording)