Variational inference

In the last chapter, we saw that inference in probabilistic models is often intractable, and we learned about algorithms that provide approximate solutions to the inference problem (e.g. marginal inference) by using subroutines that involve sampling random variables. Most sampling-based inference algorithms are instances of Markov chain Monte Carlo (MCMC); two popular MCMC methods are Gibbs sampling and Metropolis-Hastings.

Unfortunately, these sampling-based methods have several important shortcomings.

- Although they are guaranteed to find a globally optimal solution given enough time, it is difficult to tell how close they are to a good solution given the finite amount of time that they have in practice.
- In order to quickly reach a good solution, MCMC methods require choosing an appropriate sampling technique (e.g. a good proposal in Metropolis-Hastings), and making this choice can be an art in itself.

In this chapter, we are going to look at an alternative approach to approximate inference called the variational family of algorithms.

Inference as optimization

The main idea of variational methods is to cast inference as an optimization problem.

Suppose we are given an intractable probability distribution $p$. Variational techniques will try to solve an optimization problem over a class of tractable distributions $\mathcal{Q}$ in order to find a $q \in \mathcal{Q}$ that is most similar to $p$. We will then query $q$ (rather than $p$) in order to get an approximate solution.

The main differences between sampling and variational techniques are that:

- Unlike sampling-based methods, variational approaches will almost never find the globally optimal solution.
- However, we will always know if they have converged, and in some cases we even have bounds on their accuracy.
- In practice, variational inference methods often scale better and are more amenable to techniques like stochastic gradient optimization, parallelization over multiple processors, and acceleration using GPUs.

Although sampling methods were historically invented first (in the 1940s), variational techniques have been steadily gaining popularity and are currently the more widely used inference technique.

The Kullback-Leibler divergence

To formulate inference as an optimization problem, we need to choose an approximating family $\mathcal{Q}$ and an optimization objective $J(q)$. This objective needs to capture the similarity between $q$ and $p$; the field of information theory provides us with a tool for this called the Kullback-Leibler (KL) divergence.

Formally, the KL divergence between two distributions $q$ and $p$ with discrete support is defined as

$$ KL(q \| p) = \sum_x q(x) \log \frac{q(x)}{p(x)}. $$
In information theory, this function is used to measure differences in information contained within two distributions. The KL divergence has the following properties that make it especially useful in our setting:

- $KL(q \| p) \geq 0$ for all $q, p$;
- $KL(q \| p) = 0$ if and only if $q = p$.
These can be proven as an exercise. Note however that $KL(q \| p) \neq KL(p \| q)$, i.e. the KL divergence is not symmetric. This is why we say that it’s a divergence, but not a distance. We will come back to this distinction shortly.
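As a quick numeric sanity check of these properties, here is a small sketch (the two distributions below are arbitrary examples chosen for illustration):

```python
import numpy as np

def kl_divergence(q, p):
    """KL(q || p) for discrete distributions given as probability vectors.
    Terms where q(x) = 0 contribute zero, by the convention 0 log 0 = 0."""
    q, p = np.asarray(q, dtype=float), np.asarray(p, dtype=float)
    mask = q > 0
    return float(np.sum(q[mask] * np.log(q[mask] / p[mask])))

q = np.array([0.5, 0.3, 0.2])
p = np.array([0.4, 0.4, 0.2])

print(kl_divergence(q, p) >= 0)                    # True: non-negativity
print(kl_divergence(q, q) == 0.0)                  # True: zero when q = p
print(kl_divergence(q, p) == kl_divergence(p, q))  # False: not symmetric
```

The last line illustrates the asymmetry discussed above: swapping the arguments changes the value, so the KL divergence is not a distance.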

The variational lower bound

How do we perform variational inference with a KL divergence? First, let’s fix a form for $p$. We’ll assume that $p$ is a general (discrete, for simplicity) undirected model of the form

$$ p(x_1, \ldots, x_n; \theta) = \frac{\tilde p(x_1, \ldots, x_n; \theta)}{Z(\theta)} = \frac{1}{Z(\theta)} \prod_k \phi_k(x_k; \theta), $$

where the $\phi_k$ are the factors and $Z(\theta)$ is the normalization constant. This formulation captures virtually all the distributions in which we might want to perform approximate inference, such as marginal distributions $p(x \mid e) = p(x, e)/p(e)$ of directed models with evidence $e$.

Given this formulation, optimizing $KL(q \| p)$ directly is not possible because of the potentially intractable normalization constant $Z(\theta)$. In fact, even evaluating $KL(q \| p)$ is not possible, because we need to evaluate $p$.

Instead, we will work with the following objective, which has the same form as the KL divergence, but only involves the unnormalized probability $\tilde p(x) = \prod_k \phi_k(x_k; \theta)$:

$$ J(q) = \sum_x q(x) \log \frac{q(x)}{\tilde p(x)}. $$
This function is not only tractable, it also has the following important property:

$$ J(q) = \sum_x q(x) \log \frac{q(x)}{\tilde p(x)} = \sum_x q(x) \log \frac{q(x)}{p(x)} - \log Z(\theta) = KL(q \| p) - \log Z(\theta). $$

Since $KL(q \| p) \geq 0$, we get by rearranging terms that

$$ \log Z(\theta) = KL(q \| p) - J(q) \geq -J(q). $$

Thus, $-J(q)$ is a lower bound on the log partition function $\log Z(\theta)$. In many cases, $\log Z(\theta)$ has an interesting interpretation. For example, we may be trying to compute the marginal probability $p(x \mid D) = p(x, D)/p(D)$ of variables $x$ given observed data $D$ that plays the role of evidence. We assume that $p$ is directed. In this case, minimizing $J(q)$ amounts to maximizing a lower bound on the log-likelihood $\log p(D)$ of the observed data.

Because of this property, $-J(q)$ is called the variational lower bound or the evidence lower bound (ELBO); it is often written in the form

$$ \log Z(\theta) \geq \mathbb{E}_{q(x)} \left[ \log \tilde p(x) - \log q(x) \right]. $$
Crucially, the difference between $\log Z(\theta)$ and $-J(q)$ is precisely $KL(q \| p)$. Thus, by maximizing the evidence lower bound, we are minimizing $KL(q \| p)$ by “squeezing” it between $-J(q)$ and $\log Z(\theta)$.
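Both facts can be verified numerically on a toy distribution whose partition function is computable by brute force (the unnormalized values below are random and purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# A made-up unnormalized distribution over 10 states; the support is tiny,
# so the partition function Z can be computed exactly for checking.
p_tilde = rng.uniform(0.5, 2.0, size=10)
log_Z = np.log(p_tilde.sum())
p = p_tilde / p_tilde.sum()

# An arbitrary tractable approximation q.
q = rng.uniform(0.5, 2.0, size=10)
q /= q.sum()

elbo = np.sum(q * (np.log(p_tilde) - np.log(q)))  # -J(q), the ELBO
kl = np.sum(q * np.log(q / p))                    # KL(q || p)

print(elbo <= log_Z)                 # True: the ELBO lower-bounds log Z
print(np.isclose(log_Z - elbo, kl))  # True: the gap is exactly KL(q || p)
```

Improving `q` (e.g. by gradient ascent on `elbo`) would shrink the gap, which is the “squeezing” described above.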

On the choice of KL divergence

To recap, we have just defined an optimization objective for variational inference (the variational lower bound) and we have shown that maximizing the lower bound leads to minimizing the divergence $KL(q \| p)$.

Recall how we said earlier that $KL(q \| p) \neq KL(p \| q)$; both divergences equal zero when $q = p$, but assign different penalties when $q \neq p$. This raises the question: why did we choose one over the other and how do they differ?

Perhaps the most important difference is computational: optimizing $KL(q \| p)$ involves an expectation with respect to $q$, while $KL(p \| q)$ requires computing expectations with respect to $p$, which is typically intractable even to evaluate.

However, choosing this particular divergence affects the returned solution when the approximating family $\mathcal{Q}$ does not contain the true $p$. Observe that $KL(q \| p)$, which is called the I-projection or information projection, is infinite if $p(x) = 0$ and $q(x) > 0$:

$$ KL(q \| p) = \sum_x q(x) \log \frac{q(x)}{p(x)}. $$
This means that if $p(x) = 0$ we must have $q(x) = 0$. We say that $KL(q \| p)$ is zero-forcing for $q$ and it will typically under-estimate the support of $p$.

On the other hand, $KL(p \| q)$, known as the M-projection or the moment projection, is infinite if $q(x) = 0$ and $p(x) > 0$. Thus, if $p(x) > 0$ we must have $q(x) > 0$. We say that $KL(p \| q)$ is zero-avoiding for $q$ and it will typically over-estimate the support of $p$.

The figure below illustrates this phenomenon graphically.

Fitting a unimodal approximating distribution q (red) to a multimodal p (blue). Using KL(p||q) leads to a q that tries to cover both modes (a). However, using KL(q||p) forces q to choose one of the two modes of p (b, c).

Due to the properties that we just described, we often call $KL(p \| q)$ the inclusive KL divergence, while $KL(q \| p)$ is the exclusive KL divergence.
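The behavior in the figure can be reproduced numerically. In this sketch (all parameters are illustrative assumptions), we fit a discretized unimodal Gaussian $q$ to a bimodal $p$ by grid search, once under each divergence:

```python
import numpy as np

x = np.arange(100)

# A bimodal p: two well-separated bumps at 25 and 75, plus a small floor
# so that KL(q || p) stays finite for any q in our family.
p = np.exp(-0.5 * ((x - 25) / 4) ** 2) + np.exp(-0.5 * ((x - 75) / 4) ** 2)
p = (p + 1e-6) / (p + 1e-6).sum()

def unimodal_q(mu, sigma):
    q = np.exp(-0.5 * ((x - mu) / sigma) ** 2) + 1e-12
    return q / q.sum()

def kl(a, b):
    return float(np.sum(a * np.log(a / b)))

grid = [(mu, sigma) for mu in range(5, 96) for sigma in range(2, 40)]
_, mu_ex, sig_ex = min((kl(unimodal_q(m, s), p), m, s) for m, s in grid)
_, mu_in, sig_in = min((kl(p, unimodal_q(m, s)), m, s) for m, s in grid)

# The exclusive KL(q||p) fit locks onto a single mode (small sigma, mu near
# one bump); the inclusive KL(p||q) fit spreads q across both modes
# (mu near the middle, large sigma).
print("exclusive KL(q||p):", mu_ex, sig_ex)
print("inclusive KL(p||q):", mu_in, sig_in)
```

This matches the figure: the exclusive divergence is zero-forcing and picks one mode, while the inclusive divergence is zero-avoiding and covers both.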

Mean-field inference

The next step in our development of variational inference concerns the choice of approximating family $\mathcal{Q}$. The machine learning literature contains dozens of proposed ways to parametrize this class of distributions; these include exponential families, neural networks, Gaussian processes, latent variable models, and many other types of models.

However, one of the most widely used classes of distributions is simply the set of fully-factored $q(x) = q_1(x_1) q_2(x_2) \cdots q_n(x_n)$; here each $q_i(x_i)$ is a categorical distribution over a one-dimensional discrete variable, which can be described as a one-dimensional table.

This choice of $\mathcal{Q}$ turns out to be easy to optimize over and works surprisingly well. It is perhaps the most popular choice when optimizing the variational bound; variational inference with this choice of $q$ is called mean-field inference. It consists in solving the following optimization problem:

$$ \min_{q_1, \ldots, q_n} J(q). $$
The standard way of performing this optimization is via coordinate descent over the $q_j$: we iterate over $j = 1, 2, \ldots, n$ and for each $j$ we optimize $KL(q \| p)$ over $q_j$ while keeping the other “coordinates” $q_{-j} = \prod_{i \neq j} q_i$ fixed.

Interestingly, the optimization problem for one coordinate has a simple closed-form solution:

$$ \log q_j(x_j) \leftarrow \mathbb{E}_{q_{-j}} \left[ \log \tilde p(x) \right] + \text{const.} $$
Notice that both sides of the above equation contain univariate functions of $x_j$: we are thus replacing $q_j(x_j)$ with another function of the same form. The constant term is a normalization constant for the new distribution.

Notice also that on the right-hand side, we are taking an expectation of a sum of factors

$$ \log \tilde p(x) = \sum_k \log \phi_k(x_k; \theta). $$
Of these, only factors belonging to the Markov blanket of $x_j$ are a function of $x_j$ (simply by the definition of the Markov blanket); the rest are constant with respect to $x_j$ and can be pushed into the constant term.

This leaves us with an expectation over a much smaller number of factors; if the Markov blanket of $x_j$ is small (as is often the case), we are able to analytically compute $q_j$. For example, if the variables are discrete with $K$ possible values, and there are $F$ factors and $N$ variables in the Markov blanket of $x_j$, then computing the expectation takes $O(K F K^N)$ time: for each value of $x_j$ we sum over all $K^N$ assignments of the $N$ variables, and in each case, we sum over the $F$ factors.

The result of this is a procedure that iteratively fits a fully-factored $q(x) = q_1(x_1) q_2(x_2) \cdots q_n(x_n)$ that approximates $p$ in terms of $KL(q \| p)$. After each step of coordinate descent, we increase the variational lower bound, tightening it around $\log Z(\theta)$.
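To make the procedure concrete, here is a minimal mean-field sketch for a small made-up model: a chain of four binary variables with pairwise factors (an Ising-style MRF; all parameters below are arbitrary). For this particular model, the coordinate update $\log q_j(x_j) \leftarrow \mathbb{E}_{q_{-j}}[\log \tilde p(x)] + \text{const}$ reduces to a logistic function of the neighbors’ current means:

```python
import itertools
import numpy as np

# Toy pairwise MRF (illustrative parameters): a 4-node chain of binary
# variables with p~(x) = exp( sum_j h[j] x_j + sum_{(i,j) in E} w x_i x_j ).
n = 4
edges = [(0, 1), (1, 2), (2, 3)]
h = np.array([0.5, -0.2, 0.3, -0.4])
w = 1.0

neighbors = {j: [] for j in range(n)}
for i, j in edges:
    neighbors[i].append(j)
    neighbors[j].append(i)

# Exact marginals by brute-force enumeration (possible only because n = 4).
states = np.array(list(itertools.product([0, 1], repeat=n)))
log_wt = states @ h + w * sum(states[:, i] * states[:, j] for i, j in edges)
probs = np.exp(log_wt) / np.exp(log_wt).sum()
exact = probs @ states  # exact[j] = p(x_j = 1)

# Mean-field: q(x) = prod_j q_j(x_j), parametrized by mu[j] = q_j(x_j = 1).
# Plugging log p~ into the coordinate update and normalizing gives
#   mu_j <- sigmoid( h_j + w * (sum of neighbors' current means) ),
# since only factors in the Markov blanket of x_j depend on x_j.
mu = np.full(n, 0.5)
for _ in range(50):  # coordinate-ascent sweeps
    for j in range(n):
        field = h[j] + w * sum(mu[i] for i in neighbors[j])
        mu[j] = 1.0 / (1.0 + np.exp(-field))

print("exact marginals      :", np.round(exact, 3))
print("mean-field marginals :", np.round(mu, 3))
```

On this weakly-coupled chain the mean-field marginals land close to the exact ones; with stronger couplings or loopy graphs the fully-factored family fits less well and the gap grows.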

In the end, the factors $q_j(x_j)$ will not quite equal the true marginal distributions $p(x_j)$, but they will often be good enough for many practical purposes, such as determining $\arg\max_{x_j} p(x_j)$.


Variational inference - Volodymyr Kuleshov