
Latent Models and Variational Inference


Latent Variable Models

Latent variable models are a subclass of probabilistic models in machine learning. A simple motivating example:

Suppose we want to model only the red data points, call them $x_{red}$; a single Gaussian distribution is likely good enough: just find the mean and variance that best fit the red cluster. Formally, we want a probabilistic model $p(x_{red})$ with $p \sim N(\mu_{red}, \sigma_{red})$. However, if we instead want to model all three colors' data, each cluster is clearly Gaussian given its color, but a single Gaussian cannot fit the overall distribution $p(x_{all})$ very well. If we still want to use Gaussians to model the clusters, we can make it a "2-step process": let $z$ be a variable following some distribution over colors (it is 'latent' because we cannot directly obtain it from the data), then $x$ conditioned on $z$ can be a Gaussian whose parameters (mean and standard deviation) are determined by $z$. Now, if we want just $p(x)$, we can sum the conditional probabilities over all three colors.
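Written out for this example (a standard mixture-of-Gaussians form; the slide notation may differ slightly), the "2-step process" gives:

$$
p(x) = \sum_{z \in \{\text{red},\, \text{green},\, \text{blue}\}} p(z)\, \mathcal{N}\!\big(x \mid \mu_z, \sigma_z\big)
$$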

Thus the motivation for latent variable models: when it is difficult to fit the data $x$ with one single distribution, we can instead model a simple latent variable $z$, then use $z$ to parametrize a simple conditional distribution over $x$. In deep learning we use neural networks so that, although $x \mid z$ and $z$ are individually simple distributions, the mapping from $z$ to the parameters of $x \mid z$ can be complex. Now we can write the probability of a single $x$ as an integral of the product $p(x \mid z)\, p(z)$ over all possible "source" $z$'s.
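In symbols (a standard identity, not a reproduction of the slide), the marginal likelihood is:

$$
p(x) = \int p(x \mid z)\, p(z)\, dz, \qquad \text{e.g. } p(x \mid z) = \mathcal{N}\!\big(x \mid \mu_{nn}(z), \sigma_{nn}(z)\big)
$$

where $\mu_{nn}$ and $\sigma_{nn}$ denote the neural network that maps $z$ to the parameters of $x \mid z$.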

Now let's look at another, more RL-related example: in model-based RL, when we only get observations $o$ and do not have full information about the state (denoted below by $x$, not $s$), we can view this as modeling $p(o)$ with the latent variable $x$ and the observation model $p(o \mid x)$. And since our action/control (denoted by $u$ here) has a direct impact on the state $x$, we would also like to model transitions between $x$'s, which requires structure in the latent space.
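One common way to write such a latent state-space model (a sketch of the standard factorization; the lecture slides may present it slightly differently):

$$
p(o_{1:T} \mid u_{1:T-1}) = \int p(x_1) \prod_{t=1}^{T-1} p(x_{t+1} \mid x_t, u_t) \prod_{t=1}^{T} p(o_t \mid x_t)\, dx_{1:T}
$$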

Variational Approximation

Now let's discuss how exactly we can train latent variable models like those above, when we only have data for $x$. One big issue is that to obtain $p(x)$ we need an integration over $z$. The idea behind variational inference is this: ultimately we want to model $p(x)$ by maximizing $p(x_i)$ on each datapoint $x_i$, but because of the latent variable we instead start with $p(z \mid x)$, approximate it with a simpler distribution $q(z)$, use it to write a lower bound for $\log p(x)$, and maximize this lower bound instead.

With Jensen's Inequality (which moves the log inside the expectation) and nice properties of the log function, we can take any distribution $q(z)$ and derive a lower bound for $\log p(x)$:
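Since the derivation screenshot is not reproduced here, this is the standard chain of steps it shows:

$$
\log p(x_i) = \log \int p(x_i \mid z)\, p(z)\, dz
= \log \mathbb{E}_{z \sim q(z)}\!\left[\frac{p(x_i \mid z)\, p(z)}{q(z)}\right]
\ge \mathbb{E}_{z \sim q(z)}\!\big[\log p(x_i \mid z) + \log p(z)\big] - \mathbb{E}_{z \sim q(z)}\!\big[\log q(z)\big]
$$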

As a quick review, entropy measures how random a distribution is, i.e., how low its log probability is in expectation under itself, defined as:
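The definition (standard form, since the slide is not included):

$$
\mathcal{H}(q) = -\,\mathbb{E}_{z \sim q(z)}\big[\log q(z)\big] = -\int q(z)\, \log q(z)\, dz
$$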

And KL-divergence measures the difference between two distributions, or how small the log probability of one distribution is in expectation under the other, minus the entropy of that other distribution:
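Standard form (slide not included):

$$
D_{KL}\big(q(z)\,\|\,p(z)\big) = \mathbb{E}_{z \sim q(z)}\!\left[\log \frac{q(z)}{p(z)}\right] = -\,\mathbb{E}_{z \sim q(z)}\big[\log p(z)\big] - \mathcal{H}(q)
$$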

Now with the two definitions above, we can (1) rewrite the last term in the lower bound above as the entropy of $q$ (the first equation below), and (2) write $\log p(x)$ in terms of an arbitrary $q$ and the KL-divergence between $q$ and $p(z \mid x)$ (the second equation below). Since the KL-divergence is nonnegative, minimizing it brings $\log p(x)$ closer to its lower bound $\mathcal{L}(p, q)$. So finally, we have a clearer objective: maximize $\mathcal{L}(p, q)$ w.r.t. $q$; this both increases $\log p$ and decreases the divergence between $q$ and $p(z \mid x)$. Note the subscript $i$: each $x_i$ is a datapoint we have, and fitting the model means fitting these points one by one.
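Reconstructed in their standard form (the original screenshots are not included), the two equations are:

$$
\mathcal{L}_i(p, q_i) = \mathbb{E}_{z \sim q_i(z)}\!\big[\log p(x_i \mid z) + \log p(z)\big] + \mathcal{H}(q_i)
$$

$$
\log p(x_i) = \mathcal{L}_i(p, q_i) + D_{KL}\big(q_i(z)\,\|\,p(z \mid x_i)\big)
$$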

Variational Inference: Standard vs. Amortized

Now that we have a new objective, the standard variational inference method works as shown below. It learns the parameters $\theta$ of the conditional distribution $p(x \mid z)$ by taking gradient steps on an approximate gradient of $\mathcal{L}_i$. Because $p_\theta$ is a conditional probability, we need to first sample a $z_i$ from $q$, and because $q$ approximates the conditional probability $p(z \mid x)$, we need $x_i$ to sample $z_i$ from $q$. This is why the "sample $z$" line can be a little confusing: because $q$ is conditioned, we have a different distribution $q_i$ for each datapoint $x_i$. And this is why, to also learn a proper $q_i$, we need to gradient-update each $q_i$'s parameters as well. In practice, we can assume a Gaussian and parametrize each $q_i$ with a mean and a variance, so we do gradient ascent on those two values instead.
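A minimal PyTorch sketch of this per-datapoint loop, assuming a 1-D toy dataset and a tiny decoder network; all names and sizes are illustrative, not the lecture's code. For simplicity it already uses reparametrized sampling (`rsample`) so that gradients flow into each $q_i$'s mean and log-std (the variance-reduction trick discussed below):

```python
import torch

# Toy decoder p_theta(x | z): a small net outputting the mean of a
# unit-variance Gaussian over x. Purely illustrative.
decoder = torch.nn.Sequential(torch.nn.Linear(1, 32), torch.nn.ReLU(),
                              torch.nn.Linear(32, 1))
prior = torch.distributions.Normal(0.0, 1.0)

data = torch.randn(100, 1) * 2.0 + 1.0       # stand-in dataset {x_i}
# Standard (non-amortized) VI: one (mean, log-std) pair PER datapoint x_i.
q_mean = torch.zeros(100, 1, requires_grad=True)
q_logstd = torch.zeros(100, 1, requires_grad=True)

opt = torch.optim.Adam([q_mean, q_logstd, *decoder.parameters()], lr=1e-2)

for step in range(1000):
    i = torch.randint(0, 100, (1,))          # pick a datapoint x_i
    x_i = data[i]
    q_i = torch.distributions.Normal(q_mean[i], q_logstd[i].exp())
    z = q_i.rsample()                        # sample z_i ~ q_i(z)
    log_px_z = torch.distributions.Normal(decoder(z), 1.0).log_prob(x_i)
    elbo = log_px_z + prior.log_prob(z) - q_i.log_prob(z)   # estimate of L_i
    loss = -elbo.sum()                       # ascend L_i = descend -L_i
    opt.zero_grad(); loss.backward(); opt.step()
```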

However, a big problem with the above method is that it requires a separate mean and standard deviation for each datapoint $x_i$, which means the number of variational parameters grows with the size of the dataset. Amortized variational inference tries to solve this problem by learning a neural network for $q$ instead, so now we have a second set of parameters $\phi$ to be learned. Instead of separate $q_i$'s, we now sample $z$ by feeding $x_i$ into this $q$ network. Now, to actually train $p$ and $q$, there is one more catch called the reparametrization trick, which is meant to reduce the variance of the gradients with respect to $\phi$. We can actually draw an equivalence between the variational lower bound and the policy gradient:
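The analogy, written here in its standard score-function form since the slide is not reproduced: treating $r(x_i, z) = \log p_\theta(x_i \mid z) + \log p(z)$ as a "reward", the naive gradient estimator for $\phi$ looks just like REINFORCE:

$$
\nabla_\phi\, \mathbb{E}_{z \sim q_\phi(z \mid x_i)}\big[r(x_i, z)\big]
= \mathbb{E}_{z \sim q_\phi(z \mid x_i)}\big[\nabla_\phi \log q_\phi(z \mid x_i)\; r(x_i, z)\big]
$$

and, like the policy gradient, this estimator tends to have high variance.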

What reparametrization means is that we do not directly sample a value for $z$; instead, we obtain a mean and standard deviation from the $q$ network and use a standard normal variable $\epsilon$ to compute $z$. This alternative gradient estimator has much lower variance.
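In symbols (standard form):

$$
z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

so the randomness is isolated in $\epsilon$, and gradients with respect to $\phi$ flow deterministically through $\mu_\phi$ and $\sigma_\phi$.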

Variational Auto-encoder (VAE)

First, let's look at a bit more math that gives yet another view of our objective $\mathcal{L}$. After reparametrization, we can write the expectation as being over the standard normal variable $\epsilon$, and approximate it with a single sample of $\log p(x_i \mid \text{sampled } z)$. So doing gradient ascent on this objective really does two things: it maximizes the network's ability to recover the data $x_i$ (using the $p_\theta$ network), and it keeps the $q_\phi$ network's output close to the actual prior $p(z)$.
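Reconstructed in its standard form (the screenshot is not included), the reparametrized objective for one datapoint is:

$$
\mathcal{L}_i \approx \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}\!\Big[\log p_\theta\big(x_i \mid \mu_\phi(x_i) + \sigma_\phi(x_i)\,\epsilon\big)\Big] - D_{KL}\big(q_\phi(z \mid x_i)\,\|\,p(z)\big)
$$

where the first term is the reconstruction term and the KL term has a closed form when $q_\phi$ and $p(z)$ are both Gaussian.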

Naming the $q$ network discussed above the encoder and the $p$ network the decoder, we basically have a VAE that is capable of generating new data. If the training data points are images denoted by $x$, then once training is done we can generate a new image $x'$ very similar to our training data with a 2-step sampling procedure: first feed an image to the encoder and sample from $q$ to obtain a $z$, then use that $z$ to get an $x'$ from the learned decoder $p$.
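A minimal PyTorch sketch of such a VAE and the 2-step sampling described above; the architecture, sizes, and data here are illustrative assumptions, not the lecture's implementation:

```python
import torch

D, Z = 784, 16                                 # flattened 28x28 images, 16-d latent
encoder = torch.nn.Sequential(torch.nn.Linear(D, 128), torch.nn.ReLU(),
                              torch.nn.Linear(128, 2 * Z))   # outputs mu and log-std
decoder = torch.nn.Sequential(torch.nn.Linear(Z, 128), torch.nn.ReLU(),
                              torch.nn.Linear(128, D))       # outputs mean of p(x|z)

def encode(x):
    mu, logstd = encoder(x).chunk(2, dim=-1)
    eps = torch.randn_like(mu)                 # reparametrization trick
    return mu + logstd.exp() * eps, mu, logstd

opt = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=1e-3)
x = torch.rand(64, D)                          # stand-in batch of training images

for step in range(100):
    z, mu, logstd = encode(x)
    recon = torch.distributions.Normal(decoder(z), 1.0).log_prob(x).sum(-1)
    kl = torch.distributions.kl_divergence(
        torch.distributions.Normal(mu, logstd.exp()),
        torch.distributions.Normal(0.0, 1.0)).sum(-1)
    loss = -(recon - kl).mean()                # negative ELBO
    opt.zero_grad(); loss.backward(); opt.step()

# 2-step sampling: encode a reference image, sample z from q, then decode.
with torch.no_grad():
    z_new, _, _ = encode(x[:1])
    x_new = decoder(z_new)                     # an x' similar to the reference image
```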

All included screenshots are credited to Lecture 13 (Fall 2018 recording).