Learning in latent variable models
Up to now, we have assumed that when learning a directed or an undirected model, we are given examples of every single variable that we are trying to model.
However, that may not always be the case. Consider for example a probabilistic language model of news articles. (A language model assigns probabilities to sequences of words $x = (x_1, \dots, x_n)$. We can, among other things, sample from $p(x)$ to generate various kinds of sentences.) Each article $x$ typically focuses on a specific topic $t$, e.g. finance, sports, or politics. Using this prior knowledge, we may build a more accurate model $p(x \mid t) p(t)$, in which we have introduced an additional, unobserved variable $t$. This model can be more accurate because we can now learn a separate $p(x \mid t)$ for each topic, rather than trying to model everything with one $p(x)$.
However, since $t$ is unobserved, we cannot directly use the learning methods that we have seen so far. In fact, the unobserved variables make learning much more difficult; in this chapter, we will look at how to use and how to learn models that involve latent variables.
Latent variable models
More formally, a latent variable model (LVM) $p$ is a probability distribution over two sets of variables $x, z$:
$$p(x, z; \theta),$$
where the $x$ variables are observed at learning time in a dataset $D$ and the $z$ are never observed.
The model may be either directed or undirected. There exist both discriminative and generative LVMs, although here we will focus on the latter (the key ideas hold for discriminative models as well).
Example: Gaussian mixture models
Gaussian mixture models (GMMs) are a latent variable model that is also one of the most widely used models in machine learning.
Figure: Example of a dataset that is best fit with a mixture of two Gaussians. Mixture models allow us to model clusters in the dataset.
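In a GMM, each data point is a tuple $(x_i, z_i)$ with $x_i \in \mathbb{R}^d$ and $z_i \in \{1, 2, \dots, K\}$ ($z_i$ is discrete). The joint $p$ is a directed model
$$p(x, z) = p(x \mid z) p(z),$$
where $p(z = k) = \pi_k$ for some vector of class probabilities $\pi = (\pi_1, \dots, \pi_K)$ and
$$p(x \mid z = k) = \mathcal{N}(x; \mu_k, \Sigma_k)$$
is a multivariate Gaussian with mean $\mu_k$ and covariance $\Sigma_k$.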
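This model postulates that our observed data consists of $K$ clusters with proportions specified by $\pi_1, \dots, \pi_K$; the distribution within each cluster is a Gaussian. We can see that $p(x)$ is a mixture by explicitly writing out this probability:
$$p(x) = \sum_{k=1}^K p(x \mid z = k) p(z = k) = \sum_{k=1}^K \pi_k \mathcal{N}(x; \mu_k, \Sigma_k).$$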
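As a quick numerical illustration of the mixture formula, here is a minimal sketch that evaluates $p(x)$ for a one-dimensional GMM. It assumes scalar data, so each $\Sigma_k$ becomes a variance $\sigma_k^2$; the helper names and the parameter values are hypothetical choices for this example.

```python
import math

def gaussian_pdf(x, mean, var):
    """Density of the 1-D Gaussian N(x; mean, var)."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

def gmm_density(x, pis, means, variances):
    """Marginal p(x) = sum_k pi_k * N(x; mu_k, sigma_k^2) of a 1-D GMM."""
    return sum(pi * gaussian_pdf(x, m, v) for pi, m, v in zip(pis, means, variances))

# Hypothetical parameters: 30% of the mass around -2, 70% around 3.
pis, means, variances = [0.3, 0.7], [-2.0, 3.0], [1.0, 0.5]
density_at_zero = gmm_density(0.0, pis, means, variances)
```

Because each component integrates to one and the weights sum to one, the mixture is itself a valid density.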
To generate a new data point, we sample a cluster $k \sim \pi$ and then sample from its Gaussian, $x \sim \mathcal{N}(\mu_k, \Sigma_k)$.
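This two-stage generative process is known as ancestral sampling. Below is a small sketch of it for a one-dimensional GMM using only Python's `random` module; `sample_gmm` and the parameter values are illustrative assumptions, and each draw keeps its latent label so you can see which cluster produced it.

```python
import math
import random

def sample_gmm(n, pis, means, variances, seed=0):
    """Draw n pairs (x, k): first k ~ Categorical(pi), then x ~ N(mu_k, sigma_k^2)."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        # Step 1: pick a cluster index k with probability pi_k.
        k = rng.choices(range(len(pis)), weights=pis)[0]
        # Step 2: sample from that cluster's Gaussian (gauss expects a standard deviation).
        x = rng.gauss(means[k], math.sqrt(variances[k]))
        samples.append((x, k))
    return samples

samples = sample_gmm(10000, [0.3, 0.7], [-2.0, 3.0], [1.0, 0.5])
```

Discarding the labels `k` leaves samples from the marginal $p(x)$; with real data we only ever see those unlabeled $x$ values, which is exactly what makes $z$ latent.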
Why are latent variable models useful?
There are two reasons why we might want to use latent variable models.
The simplest reason is that some data might be naturally unobserved. For example, if we are modeling a clinical trial, then some patients may drop out, and we won’t have their measurements. The methods in this chapter can be used to learn with this kind of missing data.
However, the most important reason for studying LVMs is that they enable us to leverage our prior knowledge when defining a model. Our topic modeling example from the introduction illustrates this. We know that our set of news articles is actually a mixture of $K$ distinct distributions (one for each topic); LVMs allow us to design a model that captures this.
LVMs can also be viewed as increasing the expressive power of our model. In the case of GMMs, the distribution that we can model using a mixture of $K$ Gaussian components is much more expressive than what we could have modeled using a single component.
Marginal likelihood training
How do we train an LVM? Our goal is still to fit the marginal distribution $p(x)$ over the visible variables $x$ to that observed in our dataset $D$. Hence our previous discussion about KL divergences applies here as well, and by the same argument, we should be maximizing the marginal log-likelihood of the data
$$\log p(D) = \sum_{x \in D} \log p(x) = \sum_{x \in D} \log \left( \sum_z p(x \mid z) p(z) \right).$$
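The sum inside the log is also where numerical trouble shows up: each $p(x \mid z) p(z)$ can underflow to zero. A common remedy, sketched below for a one-dimensional GMM, is to work with log-probabilities and combine them with the log-sum-exp trick. The function names and the 1-D restriction are assumptions made for this illustration.

```python
import math

def log_gaussian(x, mean, var):
    """log N(x; mean, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))

def marginal_log_likelihood(data, pis, means, variances):
    """log p(D) = sum_x log sum_z p(x | z) p(z); note the sum over z inside the log."""
    total = 0.0
    for x in data:
        log_joint = [math.log(pi) + log_gaussian(x, m, v)
                     for pi, m, v in zip(pis, means, variances)]
        total += logsumexp(log_joint)
    return total

ll = marginal_log_likelihood([-2.1, 0.0, 3.3], [0.3, 0.7], [-2.0, 3.0], [1.0, 0.5])
```

Even for a point far from every component (say $x = 1000$), this returns a finite value, whereas exponentiating each term first would give $\log 0$.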
This optimization objective is considerably more difficult than the regular log-likelihood, even for directed graphical models. For one, we can see that the summation inside the log makes it impossible to decompose $p(x)$ into a sum of log-factors. Hence, even if the model is directed, we can no longer derive a simple closed-form expression for the parameters.
Figure: Exponential family distributions (gray lines) have concave log-likelihoods. However, a weighted mixture of such distributions is no longer concave (black line).
Looking closer at the distribution of a data point $x$, we also see that it is actually a mixture
$$p(x) = \sum_z p(x \mid z) p(z)$$
of distributions $p(x \mid z)$ with weights $p(z)$. Whereas a single exponential family distribution $p(x)$ has a log-likelihood that is concave in its natural parameters (as we have seen in our discussion of undirected models), the log of a weighted mixture of such distributions is in general neither concave nor convex.
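One concrete way to see the non-concavity is label switching: swapping the parameters of two components leaves the likelihood unchanged, so every fit has an equally good label-swapped twin, and a concave function could not be lower at their midpoint. The sketch below checks this numerically for a hypothetical equal-weight, unit-variance mixture whose only free parameters are the two means $\mu_1, \mu_2$; the toy dataset is made up for the example.

```python
import math

def log_gaussian(x, mean, var):
    """log N(x; mean, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def mixture_loglik(data, mu1, mu2):
    """Log-likelihood of 0.5 N(mu1, 1) + 0.5 N(mu2, 1), as a function of the means only."""
    total = 0.0
    for x in data:
        a = math.log(0.5) + log_gaussian(x, mu1, 1.0)
        b = math.log(0.5) + log_gaussian(x, mu2, 1.0)
        top = max(a, b)
        total += top + math.log(math.exp(a - top) + math.exp(b - top))
    return total

data = [-2.0, -1.8, -2.2, 2.0, 1.9, 2.1]   # made-up points forming two clusters
ll_a = mixture_loglik(data, -2.0, 2.0)     # component 1 on the left cluster
ll_b = mixture_loglik(data, 2.0, -2.0)     # the same fit with the labels swapped
ll_mid = mixture_loglik(data, 0.0, 0.0)    # midpoint of the two parameter vectors
# A concave function would satisfy ll_mid >= (ll_a + ll_b) / 2; here it does not.
```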
This non-convexity requires the development of specialized learning algorithms.
Learning latent variable models
Since the objective is non-convex, we will resort to approximate learning algorithms. These methods are widely used in practice and are quite effective. (Note, however, that, quite surprisingly, many latent variable models, like GMMs, admit algorithms that can compute the global optimum of the maximum likelihood objective, even though it is not convex. Such methods are covered at the end of CS229T.)
The Expectation-Maximization algorithm
The Expectation-Maximization (EM) algorithm is a hugely important and widely used algorithm for learning directed latent-variable graphical models $p(x, z; \theta)$ with parameters $\theta$ and latent $z$.
The EM algorithm relies on two simple observations.
- If the latent $z$ were fully observed, then we could optimize the log-likelihood exactly using our previously seen closed-form solution for $p(x, z)$.
- Knowing the weights, we can often efficiently compute the posterior $p(z \mid x; \theta)$ (this is an assumption; it is not true for some models).
EM follows a simple iterative two-step strategy: given an estimate $\theta_t$ of the weights, compute $p(z \mid x)$ and use it to “hallucinate” values for $z$. Then, find a new $\theta_{t+1}$ by optimizing the resulting tractable objective. This process will eventually converge.
We haven’t exactly defined what we mean by “hallucinating” the data. The full definition is a bit technical, but its instantiation is very intuitive in most models like GMMs.
By “hallucinating” the data, we mean computing the expected log-likelihood
$$\mathbb{E}_{z \sim p(z \mid x)} \log p(x, z; \theta).$$
This expectation is what gives the EM algorithm half of its name. If $z$ is not too high-dimensional (e.g. in GMMs it is a one-dimensional categorical variable), then we can compute this expectation.
Since the summation is now outside the log, we can maximize the expected log-likelihood. In particular, when $p$ is a directed model, $\log p(x, z)$ again decomposes into a sum of log-CPD terms that can be optimized independently, as discussed in the chapter on directed graphical models.
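For a GMM, the expected log-likelihood can be written down directly. The sketch below evaluates $\sum_{x \in D} \mathbb{E}_{z} \log p(x, z; \theta)$ for one-dimensional data, where `resp[i][k]` holds the posterior weight $p(z = k \mid x_i)$ computed with the previous parameters. Because the sum over $z$ is outside the log, every term splits into a $\log \pi_k$ part and a $\log \mathcal{N}$ part. The function name and values are illustrative assumptions.

```python
import math

def log_gaussian(x, mean, var):
    """log N(x; mean, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def expected_complete_loglik(data, resp, pis, means, variances):
    """sum_i sum_k resp[i][k] * [log pi_k + log N(x_i; mu_k, var_k)] for a 1-D GMM.

    The prior term (log pi_k) and the CPD term (log N) appear separately,
    so they can be maximized independently in the M-step.
    """
    total = 0.0
    for x, r in zip(data, resp):
        for k, weight in enumerate(r):
            total += weight * (math.log(pis[k]) + log_gaussian(x, means[k], variances[k]))
    return total

q_value = expected_complete_loglik(
    [0.0, 5.0], [[1.0, 0.0], [0.0, 1.0]], [0.5, 0.5], [0.0, 5.0], [1.0, 1.0])
```

With one-hot weights, as in this usage, the expression reduces to the ordinary complete-data log-likelihood of a fully observed $z$.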
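We can formally define the EM algorithm as follows. Let $D$ be our dataset.
- Starting at an initial $\theta_0$, repeat until convergence for $t = 0, 1, 2, \dots$:
  - E-Step: For each $x \in D$, compute the posterior $p(z \mid x; \theta_t)$.
  - M-Step: Compute new weights via
  $$\theta_{t+1} = \arg\max_\theta \sum_{x \in D} \mathbb{E}_{z \sim p(z \mid x; \theta_t)} \log p(x, z; \theta).$$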
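The loop is independent of the particular model, so it can be written once as a higher-order function. In the minimal sketch below, `run_em`, `e_step`, `m_step` and `log_lik` are hypothetical names: the model-specific pieces are passed in, and the marginal log-likelihood is tracked only to decide when to stop. The toy usage fits just the mixing weight of a two-component mixture with fixed, known Gaussians $\mathcal{N}(-2, 1)$ and $\mathcal{N}(2, 1)$, for which the M-step is simply the average responsibility.

```python
import math

def run_em(theta0, e_step, m_step, log_lik, max_iters=100, tol=1e-8):
    """Generic EM driver; the model supplies e_step, m_step and log_lik.

    e_step(theta)      -> posteriors p(z | x; theta) for every x in D
    m_step(posteriors) -> argmax over theta of the expected complete-data log-likelihood
    log_lik(theta)     -> marginal log-likelihood, used here only to test convergence
    """
    theta = theta0
    history = [log_lik(theta)]
    for _ in range(max_iters):
        posteriors = e_step(theta)   # E-step
        theta = m_step(posteriors)   # M-step
        history.append(log_lik(theta))
        if history[-1] - history[-2] < tol:  # improvement has become negligible
            break
    return theta, history

# Toy usage (hypothetical model): theta is the weight pi of the N(-2, 1) component.
def normal_pdf(x, mean):
    return math.exp(-0.5 * (x - mean) ** 2) / math.sqrt(2 * math.pi)

data = [-2.1, -1.9, -2.3, 2.0, 2.2, 1.8, 2.1, 1.9]

def e_step(pi):
    return [pi * normal_pdf(x, -2) / (pi * normal_pdf(x, -2) + (1 - pi) * normal_pdf(x, 2))
            for x in data]

def m_step(resp):
    return sum(resp) / len(resp)  # new pi = average responsibility of the N(-2, 1) component

def log_lik(pi):
    return sum(math.log(pi * normal_pdf(x, -2) + (1 - pi) * normal_pdf(x, 2)) for x in data)

pi_hat, history = run_em(0.5, e_step, m_step, log_lik)
```

Stopping when the improvement drops below `tol` is one common convergence test; monitoring the change in parameters is another.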
Example: Gaussian mixture models
Let’s look at this algorithm in the context of GMMs. Suppose we have a dataset $D$. In the E-step, we may compute the posterior for each data point $x$ as follows:
$$p(z = k \mid x) = \frac{p(z = k, x)}{p(x)} = \frac{p(x \mid z = k) p(z = k)}{\sum_{l=1}^K p(x \mid z = l) p(z = l)} = \frac{\pi_k \mathcal{N}(x; \mu_k, \Sigma_k)}{\sum_{l=1}^K \pi_l \mathcal{N}(x; \mu_l, \Sigma_l)}.$$
Note that each $p(z = k \mid x)$ is simply the probability that $x$ originates from component $k$ given the current set of parameters $\theta$. After normalization, these form the $K$-dimensional vector of probabilities $p(z \mid x)$.
Recall that in the original model, $z$ is an indicator variable that chooses a component for $x$; we may view this as a “hard” assignment of $x$ to one component. The result of the E step is a $K$-dimensional vector (whose components sum to one) that specifies a “soft” assignment to components. In that sense, we have “hallucinated” a “soft” instantiation of $z$; this is what we meant earlier by an “intuitive interpretation” for $p(z \mid x)$.
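A direct translation of the E-step formula for one-dimensional data is sketched below, with variances in place of the $\Sigma_k$; the function name and sample values are assumptions. It normalizes in log space, which yields the same responsibilities as the formula but avoids underflow when a point is far from every component.

```python
import math

def log_gaussian(x, mean, var):
    """log N(x; mean, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def e_step(data, pis, means, variances):
    """Return resp[i][k] = p(z = k | x_i; theta) for a 1-D GMM."""
    resp = []
    for x in data:
        # Unnormalized log posteriors: log pi_k + log N(x; mu_k, var_k).
        logs = [math.log(pi) + log_gaussian(x, m, v)
                for pi, m, v in zip(pis, means, variances)]
        # Subtract the max before exponentiating, then normalize to sum to one.
        top = max(logs)
        weights = [math.exp(l - top) for l in logs]
        norm = sum(weights)
        resp.append([w / norm for w in weights])
    return resp

resp = e_step([-2.0, 0.5, 3.0], [0.5, 0.5], [-2.0, 3.0], [1.0, 1.0])
```

Here the point at 0.5 is equidistant from both means, so it receives a soft assignment of one half to each component, while the other two points are assigned almost entirely to their nearest cluster.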
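At the M-step, we optimize the expected log-likelihood of our model.
$$\begin{aligned} \theta_{t+1} &= \arg\max_\theta \sum_{x \in D} \mathbb{E}_{z \sim p(z \mid x; \theta_t)} \log p(x, z; \theta) \\ &= \arg\max_\theta \left[ \sum_{k=1}^K \sum_{x \in D} p(z = k \mid x; \theta_t) \log p(x \mid z = k; \theta) + \sum_{k=1}^K \sum_{x \in D} p(z = k \mid x; \theta_t) \log p(z = k; \theta) \right]. \end{aligned}$$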
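We can optimize each of these terms separately. We will start with $p(x \mid z = k) = \mathcal{N}(x; \mu_k, \Sigma_k)$. We have to find $\mu_k, \Sigma_k$ that maximize
$$\sum_{x \in D} p(z = k \mid x; \theta_t) \log p(x \mid z = k; \theta) = c_k \cdot \mathbb{E}_{x \sim Q_k(x)} \log p(x \mid z = k; \theta),$$
where $c_k = \sum_{x \in D} p(z = k \mid x; \theta_t)$ is a constant that does not depend on $\theta$ and $Q_k(x)$ is a probability distribution defined over $D$ as
$$Q_k(x) = \frac{p(z = k \mid x; \theta_t)}{\sum_{x' \in D} p(z = k \mid x'; \theta_t)}.$$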
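Now we know that $\mathbb{E}_{x \sim Q_k(x)} \log p(x \mid z = k; \theta)$ equals $-KL(Q_k \,\|\, p(x \mid z = k; \theta))$ plus a constant (the negative entropy of $Q_k$), so maximizing it amounts to making $p(x \mid z = k; \theta)$ as close as possible to $Q_k(x)$, as discussed in the section on learning directed models. Moreover, since $p(x \mid z = k; \theta)$ is in the exponential family, the best such fit is the one whose sufficient statistics match those of $Q_k(x)$ (recall our discussion of exponential families in the section on learning undirected models). Thus, we may set the mean and covariance to those of $Q_k(x)$, which are
$$\mu_k = \mathbb{E}_{x \sim Q_k(x)}[x] = \frac{\sum_{x \in D} p(z = k \mid x; \theta_t)\, x}{\sum_{x \in D} p(z = k \mid x; \theta_t)},$$
$$\Sigma_k = \mathbb{E}_{x \sim Q_k(x)}\left[(x - \mu_k)(x - \mu_k)^T\right] = \frac{\sum_{x \in D} p(z = k \mid x; \theta_t) (x - \mu_k)(x - \mu_k)^T}{\sum_{x \in D} p(z = k \mid x; \theta_t)}.$$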
Note how these are just the mean and covariance of the data, weighted by their cluster affinities! Similarly, we may find that the class priors are
$$\pi_k = \frac{1}{|D|} \sum_{x \in D} p(z = k \mid x; \theta_t).$$
Although we have derived these results using general facts about exponential families, it’s equally possible to derive them using standard calculus techniques.
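These closed-form updates translate into a few weighted sums. The sketch below implements them for one-dimensional data given a responsibility table with `resp[i][k]` $= p(z = k \mid x_i; \theta_t)$. The function name is hypothetical, and the code does not guard against a component whose responsibilities are all zero or whose variance collapses to zero, which a practical implementation would need to handle.

```python
def m_step(data, resp):
    """Closed-form M-step for a 1-D GMM given responsibilities resp[i][k]."""
    n, K = len(data), len(resp[0])
    pis, means, variances = [], [], []
    for k in range(K):
        c_k = sum(r[k] for r in resp)  # effective number of points in cluster k
        mu_k = sum(r[k] * x for r, x in zip(resp, data)) / c_k  # weighted mean
        var_k = sum(r[k] * (x - mu_k) ** 2 for r, x in zip(resp, data)) / c_k  # weighted variance
        pis.append(c_k / n)  # class prior: average responsibility
        means.append(mu_k)
        variances.append(var_k)
    return pis, means, variances

# Hard (one-hot) responsibilities reduce to per-cluster sample statistics.
pis, means, variances = m_step([1.0, 3.0, 10.0, 14.0],
                               [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
```

With soft responsibilities, every point contributes to every component in proportion to its cluster affinity.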
EM as variational inference
Why exactly does EM converge? We can understand the behavior of EM by casting it in the framework of variational inference.
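Consider the posterior inference problem for $p(z \mid x)$, where the $x$ variables are held fixed as evidence. We may apply our variational inference framework by taking $p(x, z)$ to be the unnormalized distribution; in that case, $p(x)$ will be the normalization constant.
Recall that variational inference maximizes the evidence lower bound (ELBO)
$$\mathcal{L}(p, q) = \sum_z q(z) \log p(x, z; \theta) - \sum_z q(z) \log q(z)$$
over distributions $q$. The ELBO satisfies the equation
$$\log p(x; \theta) = KL(q(z) \,\|\, p(z \mid x; \theta)) + \mathcal{L}(p, q).$$
Hence, $\mathcal{L}(p, q)$ is maximized when $q = p(z \mid x)$; in that case the KL term becomes zero and the lower bound is tight:
$$\log p(x; \theta) = \mathcal{L}(p, q).$$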
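The identity $\log p(x; \theta) = KL(q(z) \,\|\, p(z \mid x; \theta)) + \mathcal{L}(p, q)$ is easy to check numerically when $z$ is a single categorical variable. The sketch below does so for one data point under a hypothetical one-dimensional, two-component GMM and an arbitrary choice of $q$; plugging in the exact posterior instead drives the KL term to zero and makes the bound tight.

```python
import math

def log_gaussian(x, mean, var):
    """log N(x; mean, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

# Hypothetical 1-D GMM parameters and a single observed point x.
pis, means, variances = [0.4, 0.6], [-1.0, 2.0], [1.0, 0.5]
x = 0.3

# log p(x, z = k) for each k, the evidence log p(x), and the exact posterior p(z | x).
log_joint = [math.log(p) + log_gaussian(x, m, v) for p, m, v in zip(pis, means, variances)]
log_evidence = math.log(sum(math.exp(lj) for lj in log_joint))
posterior = [math.exp(lj - log_evidence) for lj in log_joint]

def elbo(q):
    """L(p, q) = sum_z q(z) log p(x, z) - sum_z q(z) log q(z)."""
    return sum(qk * (lj - math.log(qk)) for qk, lj in zip(q, log_joint) if qk > 0)

def kl(q, p):
    """KL(q || p) between two discrete distributions over the same K outcomes."""
    return sum(qk * math.log(qk / pk) for qk, pk in zip(q, p) if qk > 0)

q = [0.9, 0.1]  # an arbitrary variational distribution over z
gap = log_evidence - (elbo(q) + kl(q, posterior))  # zero up to rounding, by the identity
```

For this `q`, `elbo(q)` is strictly below `log_evidence`, while `elbo(posterior)` matches it, which is the tightness claim above.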
The EM algorithm can be seen as iteratively optimizing the ELBO over $q$ (at the E step) and over $\theta$ (at the M step).
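Starting at some $\theta_t$, we compute the posterior $p(z \mid x; \theta_t)$ at the E step. We evaluate the ELBO for $q = p(z \mid x; \theta_t)$; this makes the ELBO tight:
$$\log p(x; \theta_t) = \mathbb{E}_{z \sim p(z \mid x; \theta_t)} \log p(x, z; \theta_t) - \mathbb{E}_{z \sim p(z \mid x; \theta_t)} \log p(z \mid x; \theta_t).$$
Next, we optimize the ELBO over $\theta$, holding $q$ fixed. We solve the problem
$$\theta_{t+1} = \arg\max_\theta \left[ \mathbb{E}_{z \sim p(z \mid x; \theta_t)} \log p(x, z; \theta) - \mathbb{E}_{z \sim p(z \mid x; \theta_t)} \log p(z \mid x; \theta_t) \right].$$
Note that this is precisely the optimization problem solved at the M step of EM: the second expectation in the above equation does not depend on $\theta$, so it is just an additive constant.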
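Solving this problem increases the ELBO (or at least does not decrease it). However, since we fixed $q$ to $p(z \mid x; \theta_t)$, the ELBO evaluated at the new $\theta_{t+1}$ is no longer tight. But the ELBO is always a lower bound, so $\log p(x; \theta_{t+1})$ is at least the new ELBO, which is at least the old ELBO, which in turn equals $\log p(x; \theta_t)$. Hence the true log-likelihood cannot have decreased.
We now repeat this procedure, computing $p(z \mid x; \theta_{t+1})$ (the E-step), plugging it into the ELBO (which makes the ELBO tight again), and maximizing the resulting expression over $\theta$. Every step therefore increases the marginal likelihood $\log p(x; \theta_t)$ or leaves it unchanged, which is what we wanted to show.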
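The argument predicts that the marginal log-likelihood never goes down from one EM iteration to the next. The sketch below runs plain EM (no safeguards) on synthetic one-dimensional data drawn from two Gaussians and records $\log p(D; \theta_t)$ after every iteration; the data-generating parameters, the initialization, and the iteration count are arbitrary choices for illustration.

```python
import math
import random

def log_gaussian(x, mean, var):
    """log N(x; mean, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def logsumexp(values):
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))

def log_joints(x, theta):
    """[log p(x, z = k; theta) for each k], with theta = (pis, means, variances)."""
    pis, means, variances = theta
    return [math.log(p) + log_gaussian(x, m, v) for p, m, v in zip(pis, means, variances)]

def log_likelihood(data, theta):
    """Marginal log-likelihood log p(D; theta)."""
    return sum(logsumexp(log_joints(x, theta)) for x in data)

def em_step(data, theta):
    """One EM iteration for a 1-D GMM: E-step responsibilities, then closed-form M-step."""
    resp = []
    for x in data:
        joint = log_joints(x, theta)
        norm = logsumexp(joint)
        resp.append([math.exp(j - norm) for j in joint])
    n, K = len(data), len(theta[0])
    pis, means, variances = [], [], []
    for k in range(K):
        c_k = sum(r[k] for r in resp)
        mu_k = sum(r[k] * x for r, x in zip(resp, data)) / c_k
        var_k = sum(r[k] * (x - mu_k) ** 2 for r, x in zip(resp, data)) / c_k
        pis.append(c_k / n)
        means.append(mu_k)
        variances.append(var_k)
    return pis, means, variances

# Synthetic data from two hypothetical clusters, centered at -2 and 3.
rng = random.Random(1)
data = [rng.gauss(-2.0, 1.0) for _ in range(150)] + [rng.gauss(3.0, 0.7) for _ in range(100)]

theta = ([0.5, 0.5], [-0.5, 0.5], [1.0, 1.0])  # arbitrary initialization theta_0
history = [log_likelihood(data, theta)]
for _ in range(50):
    theta = em_step(data, theta)
    history.append(log_likelihood(data, theta))
```

Checking that `history` is non-decreasing (up to floating-point noise) is a cheap and useful sanity test when debugging an EM implementation.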
Properties of EM
From our above discussion, it follows that EM has the following properties:
- The marginal likelihood never decreases after an EM cycle.
- Since the marginal likelihood is bounded above by its global maximum and never decreases, the sequence of likelihood values must eventually converge.
However, since we are optimizing a non-convex objective, there is no guarantee of finding the global optimum. In fact, in practice EM almost always converges to a local optimum, and that optimum depends heavily on the choice of initialization. Different initial $\theta_0$ can lead to very different solutions, so it is very common to run the algorithm with multiple restarts and choose the best result at the end. In fact, EM is so sensitive to the choice of initial parameters that techniques for choosing them are still an active area of research.
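Restarting from several initializations is easy to add on top of any EM routine. The sketch below, again for one-dimensional data, runs EM for a three-component GMM from several random initial means (drawn from the data points) and keeps the run with the highest final log-likelihood. The function names, the synthetic data, and the two numerical floors in the M-step (a tiny constant added to each cluster weight and a minimum variance) are hypothetical safeguards added for this example; they are not part of EM as defined above.

```python
import math
import random

def log_gaussian(x, mean, var):
    """log N(x; mean, var) for scalar x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def logsumexp(values):
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))

def fit_gmm(data, init_means, iters=50, min_var=1e-3):
    """EM for a 1-D GMM from the given initial means; returns (log-likelihood, params)."""
    K, n = len(init_means), len(data)
    pis, means, variances = [1.0 / K] * K, list(init_means), [1.0] * K
    for _ in range(iters):
        # E-step: responsibilities p(z = k | x) under the current parameters.
        resp = []
        for x in data:
            lj = [math.log(p) + log_gaussian(x, m, v) for p, m, v in zip(pis, means, variances)]
            norm = logsumexp(lj)
            resp.append([math.exp(l - norm) for l in lj])
        # M-step with hypothetical floors against empty or collapsing components.
        pis, means, variances = [], [], []
        for k in range(K):
            c = sum(r[k] for r in resp) + 1e-12
            mu = sum(r[k] * x for r, x in zip(resp, data)) / c
            var = sum(r[k] * (x - mu) ** 2 for r, x in zip(resp, data)) / c
            pis.append(c / n)
            means.append(mu)
            variances.append(max(var, min_var))
    ll = sum(logsumexp([math.log(p) + log_gaussian(x, m, v)
                        for p, m, v in zip(pis, means, variances)]) for x in data)
    return ll, (pis, means, variances)

def fit_with_restarts(data, K, n_restarts, seed=0):
    """Run EM from several random initializations and keep the best log-likelihood."""
    rng = random.Random(seed)
    runs = [fit_gmm(data, rng.sample(data, K)) for _ in range(n_restarts)]
    return max(runs, key=lambda run: run[0]), runs

# Synthetic data from three hypothetical, well-separated clusters.
rng = random.Random(2)
data = ([rng.gauss(-4.0, 0.5) for _ in range(60)] + [rng.gauss(0.0, 0.5) for _ in range(60)]
        + [rng.gauss(4.0, 0.5) for _ in range(60)])
(best_ll, best_params), runs = fit_with_restarts(data, K=3, n_restarts=5)
```

Comparing the final log-likelihoods in `runs` shows how much the outcome can depend on the starting point; libraries often expose the number of restarts as a parameter for exactly this reason.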
In summary, the EM algorithm is a very popular technique for optimizing latent variable models that is also often very effective. Its main downside is its difficulty with local optima.