Almost-scaled dot-product attention
TL;DR: Scaled dot product attention isn’t properly scaled, and that’s a good thing!
Notebook: almost-scaled dot-product attention
Transformers seem to be all you need, but we don’t fully understand why they work so well. While working on unit scaling, we noticed something surprising about how the outputs of attention, the heart of the transformer architecture, are scaled.
Many deep learning modules are designed and initialised to roughly preserve variance in the forward and/or backward (gradient) passes. This is a useful property, as the behaviour of many modules depends on the scale of their inputs (e.g. saturating nonlinearities). Dot-product attention explicitly includes a scaling factor for this, dividing the logits by $\sqrt{d_{head}}$ (where $d_{head}$ is the head dimension) so that the variance going into the softmax is stable:
$$\mathrm{ScaledDotProductAttention}(Q,K,V)=\mathrm{Softmax}\left(\frac{QK^{\top}}{\sqrt{d_{head}}}\right)V$$
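As a quick numerical check of what the pre-scaling does and does not fix, the NumPy sketch below draws a single head's $Q$, $K$ and $V$ with i.i.d. $N(0, 1)$ entries (an assumption; real activations are not this well-behaved), using illustrative sizes: sequence length $d_{seq}=1024$ and head dimension $d_{head}=64$. Dividing by $\sqrt{d_{head}}$ brings the logits to roughly unit scale, but the attention output remains far below unit scale, close to $(e/d_{seq})^{1/2}$.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d_seq, d_head = 1024, 64
q, k, v = (rng.standard_normal((d_seq, d_head)) for _ in range(3))

raw_logits = q @ k.T                          # std grows like sqrt(d_head)
scaled_logits = raw_logits / np.sqrt(d_head)  # A': roughly unit std
z = softmax(scaled_logits) @ v                # attention output Z

print(f"raw logits std:    {raw_logits.std():.2f}  (theory: sqrt(d_head) = {np.sqrt(d_head):.0f})")
print(f"scaled logits std: {scaled_logits.std():.2f}")
print(f"output std:        {z.std():.3f}  vs (e / d_seq)^(1/2) = {np.sqrt(np.e / d_seq):.3f}")
```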
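But this is insufficient for the attention operation as a whole. We have derived a post-scaling factor for attention to correct this:
$$\mathrm{FullyScaledDotProductAttention}(Q,K,V)=\left(\frac{d_{seq}}{e}\right)^{1/2}\mathrm{ScaledDotProductAttention}(Q,K,V)$$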
Where $d_{seq}$ is the sequence length. For example, this gives the following scaling behaviour:
In this post, we’ll look at the variance-scaling behaviour of attention and explain this scaling factor, before seeing that it makes training dynamics worse, not better. The post is a condensed summary of our almost-scaled dot-product attention notebook.
Where does $(d_{seq}/e)^{1/2}$ come from?
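Attention contains the expression $Z=\mathrm{Softmax}(A^{\prime})V$, where $A^{\prime}$ are the scaled attention logits. If we modify this slightly to introduce a temperature $t$, $Z=\mathrm{Softmax}(A^{\prime}/t)V$, we can think about three cases (assuming $V \sim N(0, 1)$ and unit-variance logits $A^{\prime}$, as the pre-scaling intends):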
- $t\to \infty$: the scale of $Z$ is $d_{seq}^{-1/2}$. The softmax output is flat, with all values equal to $d_{seq}^{-1}$, and the sum over $d_{seq}$ uncorrelated values then scales this up by $d_{seq}^{1/2}$, giving $d_{seq}^{-1}\cdot d_{seq}^{1/2}=d_{seq}^{-1/2}$
- $t\to 0$: the scale of $Z$ is $1$. The softmax output is a single one-hot spike, so attention selects a single element of $V$
- $t \gt 1/2$: the scale of $Z$ is $(e^{t^{-2}}/d_{seq})^{1/2}$. Under some assumptions, the softmax output follows a log-normal distribution; we explain this further in the companion notebook
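The log-normal case can be sketched as follows, assuming the logits before the temperature are i.i.d. $N(0, 1)$, so that $x_j = A^{\prime}_{ij}/t \sim N(0, t^{-2})$ for a fixed query $i$; that $V$ is i.i.d. $N(0, 1)$ and independent of the logits; and that $d_{seq}$ is large enough for the softmax denominator to concentrate around its mean. It uses the Gaussian identity $\mathbb{E}[e^{kx}] = e^{k^2 t^{-2}/2}$ for $x \sim N(0, t^{-2})$. This is a back-of-the-envelope version under those assumptions and may differ from the notebook's treatment:

$$
\begin{aligned}
\sum_{k} e^{x_k} &\approx d_{seq}\,\mathbb{E}\left[e^{x}\right] = d_{seq}\,e^{t^{-2}/2}, \\
p_j = \frac{e^{x_j}}{\sum_k e^{x_k}} &\approx \frac{e^{x_j}}{d_{seq}\,e^{t^{-2}/2}}, \\
\mathrm{Var}\left[Z_{ic}\right] = \mathbb{E}\Big[\sum_j p_j^2\Big] &\approx \frac{d_{seq}\,\mathbb{E}\left[e^{2x}\right]}{d_{seq}^2\,e^{t^{-2}}} = \frac{e^{2t^{-2}}}{d_{seq}\,e^{t^{-2}}} = \frac{e^{t^{-2}}}{d_{seq}}.
\end{aligned}
$$

Since each $p_j$ is approximately a fixed multiple of $e^{x_j}$, it is log-normal, and taking the square root of the variance gives the scale $(e^{t^{-2}}/d_{seq})^{1/2}$ quoted above.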
At $t=1$, the log-normal rule gives a scale of $(e/d_{seq})^{1/2}$. We find that this rule works well for temperatures near 1, so we propose multiplying by its inverse, i.e. scaling the attention output by $(d_{seq}/e)^{1/2}$.
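The three regimes and the proposed post-scale can be checked with a small Monte Carlo simulation. This NumPy sketch assumes i.i.d. $N(0, 1)$ logits $A^{\prime}$ and values $V$, draws fresh logits and values for every sample row so that rows are independent samples of one output element, and uses $d_{seq}=1024$. The helpers `attention_output_std` and `lognormal_scale` are hypothetical, written for illustration rather than taken from the notebook.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention_output_std(t, d_seq, n_rows=2048, seed=0):
    """Empirical std of one element of Z = Softmax(A'/t) V.

    Hypothetical helper: each row gets independent logits and values,
    both assumed i.i.d. N(0, 1), so the n_rows outputs are independent samples.
    """
    rng = np.random.default_rng(seed)
    logits = rng.standard_normal((n_rows, d_seq))  # stand-in for A'
    values = rng.standard_normal((n_rows, d_seq))  # one value column per row
    z = (softmax(logits / t) * values).sum(axis=-1)
    return z.std()

def lognormal_scale(t, d_seq):
    """Hypothetical helper: predicted std of Z, (e^{t^-2} / d_seq)^{1/2}."""
    return np.sqrt(np.exp(t ** -2.0) / d_seq)

d_seq = 1024
results = {}
for t in [0.01, 1.0, 100.0]:
    results[t] = attention_output_std(t, d_seq)
    print(f"t={t:<6} measured std={results[t]:.4f}")

print(f"spike limit (t -> 0):   {1.0:.4f}")
print(f"log-normal rule at t=1: {lognormal_scale(1.0, d_seq):.4f}")
print(f"flat limit (t -> inf):  {d_seq ** -0.5:.4f}")

# Fully scaled attention: multiply the t=1 output by (d_seq / e)^{1/2}.
post_scale = np.sqrt(d_seq / np.e)
print(f"post-scale={post_scale:.2f}  fully scaled std={results[1.0] * post_scale:.3f}")
```

At $t=1$ the measured scale should sit close to $(e/d_{seq})^{1/2}\approx 0.052$, and multiplying by $(d_{seq}/e)^{1/2}\approx 19.4$ brings it to roughly 1. The log-normal rule is deliberately not evaluated at $t=0.01$, where $e^{t^{-2}}$ overflows and the rule no longer applies.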
Does it work? …No!
We tested this change, introducing “fully scaled attention” in a full transformer model—a small autoregressive character language model trained on Shakespeare. This is what we saw from a learning rate sweep:
This is most unfortunate. It seems that under-scaled tensors coming out of the attention block are important and helpful for transformer training dynamics. It isn’t just tiny Shakespeare models: we’ve also seen this effect when training BERT. We don’t yet have an explanation for this difference, but find it intriguing that such a (presumed) accident of under-scaling turns out to be helpful for training dynamics!
Unit scaling has a solution for this, allowing unit-scaled tensors while retaining the original training dynamics. The bad training behaviour must come from scale-dependent operations, in particular where attention’s residual output is added to the skip connection. We found that we can reproduce the original model’s dynamics by applying a relative weight to the residual vs skip connections.
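One simple way to see how a relative weight can recover the original behaviour: if the attention branch is post-scaled by $(d_{seq}/e)^{1/2}$, weighting that branch by the inverse factor in the residual add reproduces the original forward pass exactly, while the tensor leaving the attention block stays unit-scaled. Because the overall function is unchanged, parameter gradients are unchanged too (in exact arithmetic). This is a minimal illustration, not necessarily the weighting used in our experiments; `weighted_residual_add` is a hypothetical helper, and the random arrays are assumed stand-ins for real activations.

```python
import numpy as np

def weighted_residual_add(skip, residual, residual_weight):
    """Hypothetical helper: skip connection plus a relatively weighted residual branch."""
    return skip + residual_weight * residual

rng = np.random.default_rng(0)
d_seq, d_model = 256, 64

# Random stand-ins for real activations (assumption, for illustration only).
skip = rng.standard_normal((d_seq, d_model))
# Standard attention output: under-scaled, std about (e / d_seq)^{1/2}.
attn_out = rng.standard_normal((d_seq, d_model)) * np.sqrt(np.e / d_seq)

# Fully scaled attention output: multiplied by (d_seq / e)^{1/2}, so std about 1.
post_scale = np.sqrt(d_seq / np.e)
attn_out_fully_scaled = attn_out * post_scale

original = skip + attn_out
reweighted = weighted_residual_add(skip, attn_out_fully_scaled, 1.0 / post_scale)

print(f"std of fully scaled attention output: {attn_out_fully_scaled.std():.3f}")
print(np.allclose(original, reweighted))  # True: the forward pass is unchanged
```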
Conclusion
It is helpful to think through the scales of tensors in deep learning models. Indeed, careful reasoning about scale is the core principle underpinning unit scaling (which also considers the scale of gradients, not just activations).
In the above example, we saw how to “fix” attention’s scaling behaviour by multiplying the outputs by $(d_{seq}/e)^{1/2}$, so that the outputs are unit-variance. However, we also saw that this change can make training dynamics worse, not better. Why this happens is, as far as we know, an open question.
If you’re interested in finding out more, check out our accompanying notebook and unit scaling paper.
With thanks to Charlie Blake for help & feedback.
— Douglas Orr (douglaso@graphcore.ai), October 2023