Control as Inference
At a high level, this section introduces a new decision-making model that accounts for occasionally suboptimal actions, much like real-world human behavior. Using this new model, we can then derive optimal control and RL algorithms. As shown below, the new nodes are called optimality variables; each one represents an "intention", a binary variable denoting whether the action at that timestep is optimal or not. Note that when we write conditional probabilities below such as $p(\cdot \mid \mathcal{O}_{1:T})$, this means conditioning on all those optimality variables being true.
Key idea: under optimality, the probability of a trajectory is proportional to the exponential of its total reward.
Below, let's first look at three quantities we can compute/infer from this model, and then use them to connect this inference problem with RL and control.
We can think of computing backward messages as inferring the probability of being optimal from the current timestep onward, given the current state and action, i.e. the probability that all onward optimality variables equal 1: $\beta_t(s_t, a_t) = p(\mathcal{O}_{t:T} \mid s_t, a_t)$. After expanding this definition we can factor the conditional probability into a product of several terms, and notice how the onward optimality probability conditioned on the state alone, $\beta_t(s_t) = p(\mathcal{O}_{t:T} \mid s_t)$, can be written as an expectation over actions. The math can be a little tricky to wrap your head around, but this derivation lets us calculate backward messages recursively, from the last timestep back to the first:
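To make the recursion concrete, here is a minimal tabular sketch (not the lecture's code): it assumes discrete states and actions, a hypothetical reward table `R[s, a]`, a transition tensor `P[s, a, s']`, negative rewards so that `exp(R)` stays in (0, 1], and a uniform action prior.

```python
import numpy as np

def backward_messages(R, P, T):
    """Tabular backward pass (sketch).

    R: reward table, shape (S, A)          -- hypothetical
    P: transition tensor, shape (S, A, S)  -- hypothetical, P[s, a, s'] = p(s' | s, a)
    T: horizon
    Returns beta_sa[t] = p(O_{t:T} | s_t, a_t) and beta_s[t] = p(O_{t:T} | s_t),
    assuming a uniform action prior (which can always be folded into the reward).
    """
    S, A = R.shape
    beta_sa = np.zeros((T, S, A))
    beta_s = np.zeros((T, S))
    for t in reversed(range(T)):
        if t == T - 1:
            # Last step: only p(O_T | s_T, a_T) = exp(r(s_T, a_T)) remains.
            beta_sa[t] = np.exp(R)
        else:
            # beta_t(s, a) = p(O_t | s, a) * E_{s' ~ p(.|s,a)}[ beta_{t+1}(s') ]
            beta_sa[t] = np.exp(R) * (P @ beta_s[t + 1])
        # beta_t(s) = E_{a ~ uniform}[ beta_t(s, a) ]
        beta_s[t] = beta_sa[t].mean(axis=1)
    return beta_sa, beta_s
```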
I'm omitting some more math here that basically shows the action prior distribution doesn't affect the formulation, because it can always be folded into the reward. So above we assumed a uniform action prior without loss of generality.
The forward message is defined as the probability of a state given that all actions up to now have been optimal: $\alpha_t(s_t) = p(s_t \mid \mathcal{O}_{1:t-1})$. It can be expanded out, again using the chain rule of conditional probability, so that it can be calculated recursively starting from the initial state. The first set of long equations below essentially shows how we can use already-known quantities to calculate the forward message; and using both forward and backward messages, we can calculate the probability of a state under overall optimality (i.e. the state marginals), which is proportional to their product: $p(s_t \mid \mathcal{O}_{1:T}) \propto \beta_t(s_t)\,\alpha_t(s_t)$.
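Continuing the same hypothetical tabular setup (reward table `R`, dynamics `P`, plus an initial state distribution `p0`), a minimal sketch of the forward recursion and the resulting state marginals might look like:

```python
import numpy as np

def forward_messages(R, P, p0, T):
    """alpha[t, s] = p(s_t | O_{1:t-1}), computed recursively (tabular sketch)."""
    S, A = R.shape
    alpha = np.zeros((T, S))
    alpha[0] = p0  # no optimality evidence yet at the first timestep
    for t in range(1, T):
        # Weight each previous (s, a) by its optimality likelihood exp(r(s, a)),
        # push it through the dynamics, then renormalize to a distribution.
        weighted = alpha[t - 1][:, None] * np.exp(R)   # shape (S, A)
        alpha[t] = np.einsum('sa,sap->p', weighted, P)
        alpha[t] /= alpha[t].sum()
    return alpha

def state_marginals(alpha, beta_s):
    """p(s_t | O_{1:T}) is proportional to alpha_t(s_t) * beta_t(s_t)."""
    m = alpha * beta_s
    return m / m.sum(axis=1, keepdims=True)
```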
To begin with, let's see why and how variational inference, as introduced in the previous lecture/note, can help us recover the optimal policy under the new model discussed above.
Recall how we've been setting the optimality variables to all true and treating them as given evidence when calculating posterior action or state probabilities. While this "evidence" lets us calculate the best action under optimality, it also affects the state transition dynamics:
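Concretely, conditioning on the optimality evidence skews the posterior dynamics toward next states with a higher chance of future optimality; using the backward messages from above, this is usually written as

$$
p(s_{t+1} \mid s_t, a_t, \mathcal{O}_{1:T}) \;\propto\; p(s_{t+1} \mid s_t, a_t)\,\beta_{t+1}(s_{t+1}) \;\neq\; p(s_{t+1} \mid s_t, a_t).
$$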
This makes sense from an inference perspective: given that you are acting under the optimal policy, high-reward next states are more likely to come up. But it is wrong for control: what we want is to select the best actions while assuming the state transition dynamics stay the same. Recalling the idea of variational inference, we can learn a model that approximates this posterior distribution.
Notice how this q-distribution is supposed to do two things: it gives the posterior probability of any trajectory under optimality, and, when conditioned on a current state-action pair, it yields the unchanged transition dynamics. To achieve both, we choose a form for it:
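For reference, the standard choice keeps the true initial-state distribution and dynamics and only learns the per-step action terms:

$$
q(s_{1:T}, a_{1:T}) \;=\; p(s_1)\prod_{t} p(s_{t+1} \mid s_t, a_t)\, q(a_t \mid s_t).
$$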
So now we can draw a new transition model as below. Notice how the q distribution preserves the transition dynamics of the original optimality model, but lets us omit the script-O nodes: the conditioning on optimality is already baked in, so the per-step action terms $q(a_t \mid s_t)$ form the (approximately) optimal policy.
Furthermore, this soft value iteration has variants that allow discounting of the expected V values and an explicit temperature that weights the V-function towards a hard max, so the stochasticity can be controlled as desired:
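A minimal sketch of such a soft value iteration, in the same hypothetical tabular setup as above, with a discount `gamma` and a `temperature` knob (as the temperature goes to 0 the soft max approaches a hard max and the policy becomes nearly deterministic):

```python
import numpy as np
from scipy.special import logsumexp  # numerically stable log-sum-exp

def soft_value_iteration(R, P, T, gamma=0.99, temperature=1.0):
    """Soft value iteration sketch (hypothetical tabular setup).

    Uses the "fixed" backup E_{s'}[V(s')] rather than the optimistic
    log E[exp(V(s'))], plus a discount and an explicit temperature.
    """
    S, A = R.shape
    V = np.zeros(S)
    for _ in range(T):
        # Q(s, a) = r(s, a) + gamma * E_{s'}[ V(s') ]
        Q = R + gamma * (P @ V)
        # V(s) = tau * log sum_a exp(Q(s, a) / tau)  -- the soft max over actions
        V = temperature * logsumexp(Q / temperature, axis=1)
    policy = np.exp((Q - V[:, None]) / temperature)  # proportional to exp(advantage / tau)
    policy /= policy.sum(axis=1, keepdims=True)
    return Q, V, policy
```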
I'll stop the notes here for now, but in the rest of this lecture we can see modified RL algorithms with soft optimality added to the original RL objectives, and stochastic models for learning control.
All included screenshots credit to Lecture 14
Additionally, given a state-action pair, we define the probability of optimality to be the exponential of its current reward, $p(\mathcal{O}_t \mid s_t, a_t) = \exp(r(s_t, a_t))$ (assuming negative rewards so the exponential stays between 0 and 1). Thus we can write the probability of any trajectory under all-time optimality as proportional to a feasibility probability $p(\tau)$, multiplied by the accumulated exponentiated rewards along that trajectory. So if the agent is always operating under optimality, the probability of experiencing a high-reward trajectory is higher.
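Written out, this is the trajectory distribution the paragraph describes:

$$
p(\tau \mid \mathcal{O}_{1:T}) \;\propto\; p(\tau)\,\exp\!\Big(\sum_{t} r(s_t, a_t)\Big),
$$

where $p(\tau)$ is the feasibility of the trajectory under the initial-state distribution and dynamics.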
Model the suboptimal decision-making process. We add a new binary node here to represent the agent's "intention": the hidden behavioral logic that largely follows expected rewards but also has some stochasticity.
And, if we take a closer look at the log of these variables, there's an uncanny resemblance to the value iteration algorithms we've seen before. But here, the value function is a log of an expectation over all possible current actions. Taking the log of a sum (the integration here is roughly a sum) of exponentials lets the big Q values dominate, so unlike value iteration, $V$ only approaches the max Q value over actions at the current state rather than equaling it exactly.
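A tiny numeric illustration of this "soft max" behavior, with made-up Q values for three actions:

```python
import numpy as np
from scipy.special import logsumexp

Q = np.array([1.0, 2.0, 10.0])  # hypothetical Q values for three actions at one state
print(Q.max())        # 10.0     -> the hard max used by standard value iteration
print(logsumexp(Q))   # ~10.0005 -> log of the sum of exponentials, dominated by the largest Q
```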
Another nuance to notice here is that, in this new definition of the Q value, because we are conditioning on optimality, the transition dynamics make a difference: if the state transition is deterministic, there is only one possible next state given the current state-action pair, so the log-expectation term simply equals V of that next state; but if the transition is stochastic, then since we are assuming optimality everywhere, the log-expectation term will be biased towards next states with higher values.
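A quick numeric example of that bias, with two made-up next states under a 50/50 stochastic transition:

```python
import numpy as np

V_next = np.array([0.0, 10.0])  # hypothetical values of the two possible next states
p = np.array([0.5, 0.5])        # stochastic dynamics: both next states equally likely

print(p @ V_next)                  # 5.0  -> the ordinary expected value of the next state
print(np.log(p @ np.exp(V_next)))  # ~9.3 -> log E[exp(V)] is pulled toward the lucky high-value state
```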
Under this new model with optimality nodes added, the probability of an action under a policy is now conditioned on both the state and $\mathcal{O}$. For the optimal policy at each timestep, we can assume all past and onward optimality variables are true, then calculate the probability of the current action. After a lot of Bayes' rule derivations, we can conveniently write this policy as a ratio of backward messages, or equivalently as the exponential of an "advantage" under the new Q and V definitions:
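In symbols, the policy being described is

$$
\pi(a_t \mid s_t) \;=\; \frac{\beta_t(s_t, a_t)}{\beta_t(s_t)} \;=\; \exp\big(Q(s_t, a_t) - V(s_t)\big) \;=\; \exp\big(A(s_t, a_t)\big).
$$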
Now we are ready to solve this reformulated problem by maximizing the variational lower bound on the probability of all-time optimality. With some more derivations skipped, we can show that this lower bound looks a lot like a reinforcement learning objective, only with an entropy term on q added; and the backward pass that recursively calculates the backward messages, as discussed above, can be seen as a soft value iteration that no longer has the optimism problem and takes a soft max (log-sum-exp) of Q values.
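For reference, the lower bound being maximized takes the familiar maximum-entropy RL form:

$$
\log p(\mathcal{O}_{1:T}) \;\ge\; \sum_{t=1}^{T}\, \mathbb{E}_{(s_t, a_t)\sim q}\Big[\, r(s_t, a_t) + \mathcal{H}\big(q(a_t \mid s_t)\big) \Big].
$$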