mode-covering variational inference is incoherent
Created: March 14, 2022
Modified: March 14, 2022


This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

I have a strong opinion, weakly held, that doesn't seem to be widely shared in the approximate Bayesian inference community: reverse (or 'mode-seeking') KL divergence is the "right" objective for variational inference; other objectives are essentially dead ends.

The temptation of other objectives, such as forward ('mode-covering') KL, Rényi divergences (including $\chi^2$ VI), and various $\alpha$- and $f$-divergences, is that they promise to cover the entire posterior rather than just a single mode. They tend to overestimate uncertainty rather than underestimate it. This might be desirable behavior for many applications (including AI safety) where we prefer an underconfident system to an overconfident one.

A roughly equivalent fact is that mode-seeking objectives produce lower bounds on the model evidence, because the approximate posterior considers only a subset of explanations for the data, while mode-covering objectives produce upper bounds, because the approximate posterior tends to cover all possible explanations for the data and also some impossible ones. It's natural to want both sorts of bounds to "sandwich" the true evidence.
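To make the two behaviors concrete, here is a toy numpy sketch (the bimodal target and the grid search are my own illustrative construction): fitting a single Gaussian to a well-separated two-component mixture, the reverse-KL optimum hugs one mode, while the forward-KL optimum moment-matches and smears across both.

```python
import numpy as np

# Toy problem: fit a single Gaussian q = N(mu, sigma^2) to a bimodal
# target p = 0.5 N(-3, 1) + 0.5 N(3, 1) by brute-force grid search.
xs = np.linspace(-12, 12, 4001)
dx = xs[1] - xs[0]

def norm_pdf(x, mu, sig):
    return np.exp(-0.5 * ((x - mu) / sig) ** 2) / (sig * np.sqrt(2 * np.pi))

p = 0.5 * norm_pdf(xs, -3, 1) + 0.5 * norm_pdf(xs, 3, 1)

def kl(a, b):
    # KL(a || b) on the grid; clip b to avoid log(0) where it underflows.
    mask = a > 1e-12
    return np.sum(a[mask] * np.log(a[mask] / np.maximum(b[mask], 1e-300))) * dx

best_rev = best_fwd = (np.inf, 0.0, 1.0)  # (divergence, mu, sigma)
for mu in np.linspace(-6, 6, 61):
    for sig in np.linspace(0.6, 6.0, 28):
        q = norm_pdf(xs, mu, sig)
        rev = kl(q, p)  # reverse, mode-seeking: E_q[log q/p]
        fwd = kl(p, q)  # forward, mode-covering: E_p[log p/q]
        best_rev = min(best_rev, (rev, mu, sig))
        best_fwd = min(best_fwd, (fwd, mu, sig))

print("reverse-KL optimum: mu=%.1f sigma=%.1f" % best_rev[1:])  # hugs one mode
print("forward-KL optimum: mu=%.1f sigma=%.1f" % best_fwd[1:])  # spans both modes
```

The forward-KL fit centers between the modes with a large variance, placing most of its mass where the target has almost none; that is the overestimated uncertainty described above.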

But there are two fundamental problems with mode-covering objectives. First, they can't be computed or optimized reliably in general. Second, even if they could be, the resulting 'posteriors' are not useful for real-world decision-making.

Both problems stem fundamentally from the nature of reasoning about uncertainty in high-dimensional spaces (i.e., the real world): there are exponentially many ways the world could be, and we don't have the computational resources to consider all of them. The success of methods such as Monte Carlo tree search reflects the importance of selectively attending to high-probability states in order to get anything done; computation is important and we can't afford to waste it on low-probability states.

In high dimensions, a mode-covering $q_\phi$ will put the vast bulk of its mass on hypotheses with near-zero posterior probability. Therefore any finite number of samples we draw will be a very poor representation of the true posterior. They will include a lot of low-probability hypotheses, but may still miss the most explanatory ones because, although covered by $q_\phi$, these modes represent a relatively small fraction of its mass.
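A quick numerical sketch of this mass mismatch (the dimension and variances are arbitrary choices of mine): take a true posterior $N(0, I)$ in 200 dimensions and a mode-covering $q$ that is only 2x overdispersed per dimension. Essentially no $q$-sample ever lands in the posterior's typical set.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 200, 10_000

# Toy setup: true posterior p = N(0, I_d), overdispersed mode-covering
# approximation q = N(0, 2^2 I_d). q covers p's support everywhere, yet
# essentially none of q's mass lies where p's mass actually concentrates.
z_p = rng.standard_normal((n, d))        # samples from p
z_q = 2.0 * rng.standard_normal((n, d))  # samples from q

def log_p(z):
    # log N(z; 0, I) up to the additive constant -d/2 * log(2*pi)
    return -0.5 * np.sum(z ** 2, axis=1)

print("mean log p at p-samples:", np.mean(log_p(z_p)))  # about -d/2 = -100
print("mean log p at q-samples:", np.mean(log_p(z_q)))  # about -2d  = -400
# Fraction of q-samples within a factor of e^50 of a typical p-density:
frac = np.mean(log_p(z_q) > -d / 2 - 50)
print("fraction of q-samples near p's typical set:", frac)  # 0.0
```

A $q$-sample is typically a factor of roughly $e^{300}$ less probable under the posterior than a genuine posterior sample, so no feasible number of draws from $q$ represents $p$ well.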

This failure of finite-sample approximation means that a mode-covering $q_\phi$ can't be stably optimized. In stochastic optimization, our objective (the divergence) must in general be approximated by evaluation at specific sample points. The proposal distribution for these points defines the 'flashlight beam' that we shine onto the density landscape to direct our attention towards a subset of hypotheses; our sample objective can only depend on how much we "like what we see" under that beam. By assumption we can't sample from the true posterior $p(z|x)$, so the natural remaining proposal distribution is our current estimate of the approximate posterior $q_\phi(z)$.

The problem is that mode-covering divergences depend very strongly on the parts of the space that we can't see using the current $q_\phi$. They need to guarantee not just that we have found an area of high density, but that there are no other uncovered high-density areas anywhere in the parameter space: we would need to look outside of the "flashlight beam" in order to know whether we are shining it in the right place. This is inherently impossible. Even if we learn a very diffuse $q_\phi$ that in principle does cover the entire space, we will not actually get such coverage from any finite number of samples. Fundamentally: a mode-covering divergence depends on all points in the space, and a high-dimensional space has exponentially many points, so there is no general way to tractably evaluate such a divergence.
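By contrast, the reverse-KL objective is exactly the kind of quantity the flashlight can estimate: the negative ELBO equals $KL(q \| p(z|x))$ up to the constant $\log p(x)$, and every term is evaluated at samples from $q$ itself, whereas the forward KL $E_p[\log p/q]$ would need samples from $p$, which we don't have. A minimal sketch (toy 1-D model; all names are my own):

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy model whose exact posterior is N(2, 1).
def log_joint(z):
    return -0.5 * (z - 2.0) ** 2  # log p(z, x) up to a constant

def log_q(z, mu, log_sigma):
    sig = np.exp(log_sigma)
    return -0.5 * ((z - mu) / sig) ** 2 - log_sigma - 0.5 * np.log(2 * np.pi)

def neg_elbo(mu, log_sigma, n=5000):
    # Monte Carlo estimate of KL(q || p(z|x)) up to log p(x): every
    # evaluation point comes from q's own 'flashlight beam'.
    z = mu + np.exp(log_sigma) * rng.standard_normal(n)
    return np.mean(log_q(z, mu, log_sigma) - log_joint(z))

# Even a crude grid search over mu drives q toward the true posterior
# mean, because the whole estimate is computable from q's own samples.
mus = np.linspace(-4.0, 6.0, 101)
best_mu = mus[int(np.argmin([neg_elbo(mu, 0.0) for mu in mus]))]
print("best mu:", best_mu)  # close to the true posterior mean, 2.0
```

No analogous sample-based estimator exists for the mode-covering direction: plugging $q$-samples into $E_p[\log p/q]$ silently ignores all the mass of $p$ outside the beam.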

Potential counterexample: generative models

On the other hand, maximum likelihood training of generative models is equivalent to minimizing forward KL, and after a few years of refinement it is starting to work well: diffusion models are matching the sample quality of GANs while covering much more of the data distribution. Does this disprove my argument? We might expect any form of model misspecification to make these models exponentially likely to sample things that are not part of the data distribution, but in fact they seem to fit quite well.

The difference is that because these models are trained on empirical data, the attention of the training process is kept tightly focused on the actual data distribution. The results show that our models really can represent complex high-dimensional distributions, but we need to train them using samples from said distributions. Sampling from an untrained model (as you'd do in a typical variational training procedure) is still exponentially unlikely to attend to anything useful. Note that importance sampling doesn't fix this, since on its own it can only reweight existing samples.
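The importance-sampling point can be quantified with a toy effective-sample-size calculation (my own construction, not from any paper): even a mild half-standard-deviation mismatch per dimension between proposal and target collapses the weights as dimension grows.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 20_000

def ess(d):
    # Effective sample size of importance weights w = p/q, with proposal
    # q = N(0, I_d) and target p = N(0.5 * ones, I_d): a modest 0.5-sigma
    # mismatch in every dimension.
    z = rng.standard_normal((n, d))                    # samples from q
    log_w = np.sum(0.5 * z - 0.5 * 0.5 ** 2, axis=1)   # log p(z) - log q(z)
    w = np.exp(log_w - log_w.max())                    # stabilized weights
    return w.sum() ** 2 / np.sum(w ** 2)               # (sum w)^2 / sum w^2

for d in (1, 10, 100, 400):
    print(f"d = {d:4d}   ESS ~ {ess(d):10.1f} of {n}")
```

In one dimension nearly all samples carry useful weight; by a few hundred dimensions the entire weight budget concentrates on a handful of samples, so reweighting cannot rescue a proposal that attends to the wrong region.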

This implies that we might be able to fit mode-covering surrogate posteriors if only we could generate high-quality posterior samples to train on. This is kind of circular: presumably the reason we're doing VI is because we can't already generate such samples. But it could work to use MCMC steps to direct our attention incrementally towards the true posterior, since we can avoid the curse of dimensionality with gradient-based methods such as HMC. This is essentially the argument for why resample-move particle filtering can avoid degeneracy.
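A minimal sketch of that incremental-attention idea (toy Gaussian target, random-walk Metropolis rather than HMC for brevity; all settings are arbitrary choices of mine): a few hundred cheap MCMC sweeps carry overdispersed proposal samples into the posterior's typical set, producing the kind of training samples a mode-covering fit would need.

```python
import numpy as np

rng = np.random.default_rng(3)
d, n, sweeps, step = 50, 2000, 500, 0.4

def log_p(z):  # stand-in posterior: N(0, I_d), up to a constant
    return -0.5 * np.sum(z ** 2, axis=1)

# Start from an overdispersed 'mode-covering' proposal N(0, 2^2 I_d) ...
z = 2.0 * rng.standard_normal((n, d))
before = float(np.mean(log_p(z)))

# ... then refine each particle with random-walk Metropolis moves
# targeting p, accepting with probability min(1, p(prop)/p(z)).
for _ in range(sweeps):
    prop = z + step * rng.standard_normal((n, d))
    accept = np.log(rng.random(n)) < log_p(prop) - log_p(z)
    z[accept] = prop[accept]
after = float(np.mean(log_p(z)))

print("mean log p before MCMC:", before)  # about -2d = -100
print("mean log p after MCMC: ", after)   # near -d/2 = -25
```

The moves only ever require evaluating the target density at proposed points near the current particles, so attention shifts toward the posterior incrementally rather than demanding global coverage up front.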

TODO: where do wake-sleep methods fit into this?

Thoughts

Q: what about generative flow networks?

None of this is to deny that overdispersed posteriors and upper-bound minimization might be useful in 'tame' applications with simple probabilistic models, where the posterior is low-dimensional and/or unimodal. But if VI is to be useful for general intelligence (and I'm not sure that it will, again because computation is important and that argues against explicit models of uncertainty), then seeking overdispersed posteriors is not going to be tractable.