sparse mixture of experts: Nonlinear Function
Created: February 13, 2023
Modified: February 13, 2023

sparse mixture of experts

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

References:

A sparse-gated mixture of experts layer has 'expert networks' $E_1, \ldots, E_n$ and a 'gating network' $G$. The output of the layer is

y=iG(x)iEi(x).y = \sum_i G(x)_i E_i(x).

The gating network in Shazeer et al. (2017) is a softmax applied to noisy logits that have been explicitly sparsified by a 'KeepTopK' operation, which keeps the $K$ largest logits and sets the rest to $-\infty$. The whole layer is trained by ordinary backpropagation. Sparsification blocks the gradient to every expert outside the top $K$, but this is allegedly not a big deal in practice, since gradients still pass through to the top-$K$ experts. Because only a few experts fire for any particular input, an additional 'load-balancing' loss term is added to encourage all experts to have equal gating values on average (in effect, this encourages exploration of less-commonly-used experts). This avoids a 'rich-get-richer' effect, where the most commonly selected experts receive more training and so become even more likely to be selected in the future, while other experts go unused.
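A minimal plain-Python sketch of the noisy top-$K$ gate, the layer output $y = \sum_i G(x)_i E_i(x)$, and a load-balancing term. It follows my reading of Shazeer et al. (2017): noisy logits $H(x)_i = (x W_g)_i + \epsilon_i \cdot \mathrm{softplus}((x W_{noise})_i)$ with $\epsilon_i \sim \mathcal{N}(0, 1)$, then KeepTopK, then softmax. To keep it short, the gate takes the two logit vectors directly instead of computing them from $x$. The balancing term penalizes the squared coefficient of variation of each expert's summed gate values over a batch (the paper's 'importance' loss); using the population variance there is my assumption. All function names are hypothetical.

```python
import math
import random


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def noisy_top_k_gate(clean_logits, noise_logits, k, rng=random):
    """Gate values with at most k nonzero entries.

    clean_logits[i] stands in for (x W_g)_i and noise_logits[i] for
    (x W_noise)_i; computing them from x is left out of this sketch.
    """
    noisy = [c + rng.gauss(0.0, 1.0) * softplus(n)
             for c, n in zip(clean_logits, noise_logits)]
    # KeepTopK: keep the k largest noisy logits, set the rest to -inf.
    top = set(sorted(range(len(noisy)), key=lambda i: noisy[i], reverse=True)[:k])
    kept = [v if i in top else -math.inf for i, v in enumerate(noisy)]
    # Softmax over the kept logits; exp(-inf) == 0.0, so dropped experts get zero weight.
    m = max(kept)
    exps = [math.exp(v - m) for v in kept]
    total = sum(exps)
    return [e / total for e in exps]


def moe_output(x, experts, gates):
    # y = sum_i G(x)_i E_i(x), calling only the experts with a nonzero gate.
    return sum(g * expert(x) for g, expert in zip(gates, experts) if g > 0.0)


def importance_loss(batch_gates, weight=1.0):
    # Importance of expert i = sum of its gate values over the batch.
    # Penalize the squared coefficient of variation of those sums.
    n = len(batch_gates[0])
    importance = [sum(g[i] for g in batch_gates) for i in range(n)]
    mean = sum(importance) / n
    var = sum((v - mean) ** 2 for v in importance) / n  # population variance
    return weight * var / mean ** 2
```

Calling `noisy_top_k_gate([0.1, 2.0, -1.0, 1.5], [0.0] * 4, k=2)` returns four gate values with exactly two nonzero entries summing to 1; which two depends on the noise draw, which is where the exploration comes from.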

Hard vs soft gating: if the gate probabilities were not sparse, we would need to either compute the full weighted output (evaluating every expert) or sample a single expert from the gating distribution, treating the sample as a discrete latent variable and using REINFORCE / policy gradient methods to update the gating probabilities. The cool thing about Shazeer's sparse gating is that it pulls a sort of reparameterization trick: the randomness (Gaussian noise added to the logits before sparsifying them) is data-independent, so we don't need the complexity of policy gradient methods --- we are in a sense still technically doing 'soft' gating --- but we still get the computational benefit of the discrete choice from the sparsity. This is also an example of a perturb-and-MAP method.
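To make the reparameterization point concrete, here is a small sketch that draws the noise $\epsilon$ outside the gate, so that for a fixed $\epsilon$ the layer output is a deterministic function of the gating logits, and then checks its gradient with central finite differences. It assumes a unit noise scale (the paper learns an input-dependent scale) and scalar expert outputs, and all names are hypothetical. Logits of experts outside the top $K$ get exactly zero gradient, since a small nudge does not change which experts are kept, while the kept logits get a nonzero gradient through the softmax weights.

```python
import math


def gate_from_noise(logits, eps, k):
    # The noise eps is sampled elsewhere, so given eps this is deterministic.
    # Unit noise scale is assumed for simplicity.
    noisy = [l + e for l, e in zip(logits, eps)]
    top = set(sorted(range(len(noisy)), key=lambda i: noisy[i], reverse=True)[:k])
    kept = [v if i in top else -math.inf for i, v in enumerate(noisy)]
    m = max(kept)
    exps = [math.exp(v - m) for v in kept]
    total = sum(exps)
    return [e / total for e in exps]


def layer_output(logits, eps, expert_outputs, k):
    # y = sum_i G_i * E_i, with each expert's output given as a fixed scalar.
    gates = gate_from_noise(logits, eps, k)
    return sum(g * e for g, e in zip(gates, expert_outputs))


def finite_diff_grad(f, logits, h=1e-6):
    # Central-difference estimate of df/dlogit_i for each i.
    grad = []
    for i in range(len(logits)):
        up, down = list(logits), list(logits)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad
```

With `logits = [0.5, 1.0, -0.3, 0.2]`, `eps = [0.1, -0.2, 0.05, 0.3]`, and $K=2$, experts 1 and 0 are kept; the estimated gradient is zero for logits 2 and 3 and nonzero for logits 0 and 1.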

Is there a free lunch in this switch / do we actually avoid the variance from hard gating?

Parallelization: it is tricky to parallelize a mixture-of-experts model since different batch elements will need different experts. The approach of Shazeer et al. (2017) is essentially to reshuffle the inputs after the gating network, sending each batch element to the device(s) that 'host' the appropriate expert(s). Each expert lives on a specific device, and all inputs that require that expert are routed to that device so that the expert can be applied to a batch of inputs.
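A single-process sketch of the reshuffling step: group batch elements by the expert(s) they were routed to, call each expert once on its own sub-batch, then scatter the gate-weighted results back to their original positions. In a real system each per-expert call would run on the device hosting that expert, with the gathers and scatters done as cross-device communication; here everything is simulated with Python lists, scalar outputs are assumed, and the function name and `(expert_id, gate)` routing format are hypothetical.

```python
from collections import defaultdict


def dispatch_and_combine(batch, routing, experts):
    """batch: list of inputs.
    routing[b]: list of (expert_id, gate) pairs chosen for batch[b].
    experts: dict mapping expert_id -> function from a list of inputs to a
    list of outputs, so each expert is applied to a whole sub-batch at once.
    Returns the combined outputs and the number of inputs each expert saw.
    """
    # Dispatch: collect (position, gate, input) triples for each expert.
    buckets = defaultdict(list)
    for b, pairs in enumerate(routing):
        for expert_id, gate in pairs:
            buckets[expert_id].append((b, gate, batch[b]))

    outputs = [0.0] * len(batch)  # scalar outputs assumed for simplicity
    for expert_id, items in buckets.items():
        # One batched call per expert (on its host device, in a real system).
        ys = experts[expert_id]([x for _, _, x in items])
        # Combine: scatter the gate-weighted outputs back to their positions.
        for (b, gate, _), y in zip(items, ys):
            outputs[b] += gate * y
    loads = {expert_id: len(items) for expert_id, items in buckets.items()}
    return outputs, loads
```

The returned `loads` makes the balancing problem visible: if the gate sends most inputs to one expert, that expert's sub-batch (and its device) becomes the bottleneck.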

The switch transformer (Fedus, Zoph, Shazeer, 2022) applies this approach to transformer language models, where the 'experts' are the feedforward layers of the transformer. While previous work conjectured that $K \ge 2$ experts are necessary to get a useful training signal for the gating network, the switch transformer finds that $K=1$ actually works well.
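One way to see why $K=1$ is delicate: if the softmax is applied after KeepTopK with $K=1$, the single surviving expert always gets a gate value of exactly 1, so the gate value carries no information about the logits. As I understand the switch transformer paper, it instead takes the softmax over all experts first and scales the chosen expert's output by that probability, which keeps the router's output sensitive to every logit. A minimal sketch of the two orderings (function names are my own):

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def keep_top_k_then_softmax(logits, k):
    # Sparsify first, then normalize over the surviving experts only.
    top = set(sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k])
    return softmax([v if i in top else -math.inf for i, v in enumerate(logits)])


def switch_top1(logits):
    # Normalize over all experts first, then route to the most likely one.
    # The chosen expert's output would be scaled by this probability, which
    # is generally below 1 and varies smoothly with every logit.
    probs = softmax(logits)
    expert = max(range(len(probs)), key=lambda i: probs[i])
    return expert, probs[expert]
```

For logits `[0.2, 1.3, -0.5]`, `keep_top_k_then_softmax(logits, 1)` gives `[0.0, 1.0, 0.0]` no matter how the logits move (as long as expert 1 stays on top), while `switch_top1` returns expert 1 with a probability strictly between 0 and 1.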

The Generalist Language Model (GLaM) uses a similar architecture to the switch transformer, but with $K=2$ experts activated per token (and it trains a decoder-only language model rather than an encoder-decoder sequence-to-sequence model). Every other feedforward layer is a mixture of experts. The largest model uses 64 experts per mixture-of-experts layer on top of a 'base dense size' of 64B parameters; because half of the feedforward layers are mixture-of-experts layers with two active experts, about 96B parameters are activated per token. There are a bunch of details about how this is distributed and balanced, which seem to account for much of the effort, since these models are more complex to scale than dense models. The claim is that this achieves better performance than GPT-3 with 1/3 of the training energy consumption.

Expert choice routing (Zhou et al., 2022) tries to do a better job of keeping all experts at full utilization. Instead of choosing the top $k$ experts for each token (which can result in some experts getting more tokens than others), it chooses the top $k$ tokens for each expert. This guarantees that the experts are balanced by construction. It's weird, however, because it means that the number (and identity) of experts used by a given token is not fixed; it depends on the other tokens in the batch.
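A minimal sketch of the routing rule, assuming a precomputed token-by-expert score matrix (in Zhou et al. these scores come from a softmax over experts; here they are simply given). Each expert independently takes its $k$ highest-scoring tokens, so every expert processes exactly $k$ tokens, while a token can end up with several experts or none. The function name and data layout are hypothetical.

```python
def expert_choice(scores, k):
    """scores[t][e]: affinity of token t for expert e.

    Returns (chosen, per_token): chosen[e] lists the k tokens expert e picks,
    and per_token[t] lists the experts that picked token t.
    """
    n_tokens, n_experts = len(scores), len(scores[0])
    chosen = {}
    for e in range(n_experts):
        # Each expert ranks all tokens by its own column of scores.
        ranked = sorted(range(n_tokens), key=lambda t: scores[t][e], reverse=True)
        chosen[e] = ranked[:k]
    # Invert the assignment to see how many experts each token received.
    per_token = [[] for _ in range(n_tokens)]
    for e, tokens in chosen.items():
        for t in tokens:
            per_token[t].append(e)
    return chosen, per_token
```

With three tokens, three experts, $k=1$, and scores `[[0.5, 0.4, 0.1], [0.1, 0.1, 0.8], [0.3, 0.3, 0.4]]`, token 0 is picked by experts 0 and 1, token 1 by expert 2, and token 2 by no expert at all: every expert is exactly full, but per-token compute varies with the rest of the batch.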