gated MLP: Nonlinear Function
Created: March 03, 2024
Modified: March 03, 2024

gated MLP

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.


Where a standard MLP layer is a linear transformation followed by a nonlinearity:

f(x) = \sigma(xW + b)

a gated MLP does two linear transformations, applies a nonlinearity to one of them, and multiplies the results (creating a multiplicative interaction):

f(x) = \sigma(xW + b) \otimes (xV + c)

Equivalently, it does one double-wide transformation, then multiplies one half of the result by a nonlinear transformation of the other half.
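A minimal sketch of both forms (PyTorch; layer names and dimensions are mine, not from any particular paper):

```python
import torch
import torch.nn as nn

class GatedMLPLayer(nn.Module):
    """Gated first layer: sigma(x W + b) * (x V + c), elementwise."""
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.gate = nn.Linear(d_in, d_hidden)   # W, b
        self.value = nn.Linear(d_in, d_hidden)  # V, c

    def forward(self, x):
        return torch.sigmoid(self.gate(x)) * self.value(x)

class FusedGatedMLPLayer(nn.Module):
    """Same computation as one double-wide projection, split in half."""
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.proj = nn.Linear(d_in, 2 * d_hidden)  # [W ; V] stacked

    def forward(self, x):
        u, v = self.proj(x).chunk(2, dim=-1)
        return torch.sigmoid(u) * v
```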

In an LSTM-style gate we would effectively have nonlinearities on both factors, e.g., f(x) = \tanh(xW + b) \otimes \sigma(xV + c). Compared to this, the gated MLP can be motivated as having better gradients because it allows for a linear path.
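A quick way to see the linear-path point (my own sketch): writing u = xW + b and v = xV + c, the Jacobian of the gated MLP with respect to the linear branch is

\frac{\partial}{\partial v}\left[\sigma(u) \otimes v\right] = \operatorname{diag}(\sigma(u)),

so gradients flowing through v are only rescaled by the gate values, never multiplied by a saturating derivative. For the LSTM-style gate the same term is

\frac{\partial}{\partial v}\left[\tanh(u) \otimes \sigma(v)\right] = \operatorname{diag}(\tanh(u) \otimes \sigma'(v)),

and \sigma'(v) vanishes when the gate saturates, so both branches can choke off the gradient.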

In the context of transformers, Noam Shazeer (https://arxiv.org/abs/2002.05202, 2020) showed that using a gated MLP for the first layer of the feedforward block can improve performance. This works even without the nonlinearity (he calls this a 'bilinear layer'), but the best performance comes from a ReLU or a SiLU/Swish nonlinearity.

Why would we expect gated layers to work well? Noam Shazeer famously doesn't speculate: "We offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence."
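Concretely, a sketch of how the SiLU/Swish-gated layer slots into the transformer FFN block (PyTorch; names and bias-free convention are my assumptions, and the paper also shrinks the hidden width to keep parameter counts matched, which I omit here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """FFN block with a SiLU/Swish-gated first layer (SwiGLU-style)."""
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.gate = nn.Linear(d_model, d_hidden, bias=False)   # W
        self.value = nn.Linear(d_model, d_hidden, bias=False)  # V
        self.out = nn.Linear(d_hidden, d_model, bias=False)    # W_2

    def forward(self, x):
        return self.out(F.silu(self.gate(x)) * self.value(x))

class BilinearFFN(nn.Module):
    """The 'bilinear' variant: same structure, no nonlinearity on the gate."""
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.a = nn.Linear(d_model, d_hidden, bias=False)
        self.b = nn.Linear(d_model, d_hidden, bias=False)
        self.out = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        return self.out(self.a(x) * self.b(x))
```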


Traditional transformer FFN blocks can be interpreted as key-value stores. The first layer xW_1 computes dot products of the input 'query' x with keys stored as columns of the weight matrix W_1; the output of the block is then a linear combination of 'value' vectors stored in W_2, with weights given by the key-query similarity scores (effectively 'normalized' with some sort of nonlinearity). Does using a gated layer change this intuition?
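One way to write this out (my notation, just expanding the matrix product and ignoring biases): with keys k_i as the columns of W_1 and values v_i as the rows of W_2,

\text{FFN}(x) = \sigma(xW_1)\,W_2 = \sum_i \sigma(x \cdot k_i)\, v_i.

With a gated first layer the scalar weight on each value becomes \sigma(x \cdot k_i)\,(x \cdot q_i), where q_i is the corresponding column of V, i.e. a product of a gated score and an unbounded linear score rather than a single squashed similarity.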


Observation: since \text{GeLU}(x) = x\Phi(x), a gated MLP with GeLU nonlinearity can decompose as

\begin{align*}
f(x) &= \text{GeLU}(xW + b) \otimes (xV + c)\\
&= \Phi(xW + b) \otimes (xW + b) \otimes (xV + c)
\end{align*}
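A quick numerical check of the identity \text{GeLU}(x) = x\Phi(x) (PyTorch; F.gelu defaults to the exact erf-based form):

```python
import math
import torch
import torch.nn.functional as F

x = torch.randn(1000)
phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))  # standard normal CDF
assert torch.allclose(F.gelu(x), x * phi, atol=1e-5)
```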