xLSTM: Extended Long Short-Term Memory

The key idea

Recurrent neural networks based on Long Short-Term Memory (LSTM) units were the backbone of NLP models before the advent of the now-ubiquitous transformer. This work seeks to close the gap between LSTMs and transformers in the crucial model-scaling regime of LLMs. The authors do this by extending the LSTM in two ways to create sLSTM and mLSTM, then incorporating these layers into a deep residual architecture called xLSTM.

Figure: scaling trends for two variants of xLSTM (xLSTM[7:1] and xLSTM[1:0]) versus Llama, Mamba and RWKV-4, for models of 125M to 1.3B parameters. The two xLSTM lines are similar to each other and roughly parallel to those of Mamba, Llama and RWKV (listed in descending order of performance).

Their method

We’ll focus on the mLSTM variant, as the sLSTM variant is omitted from many of the best-performing models in their results. I think the best way to understand the architecture is to stare at a wall of maths for a while:

Figure: the mLSTM definition. The input, forget and output gates are linear functions of the input plus a bias, followed by an activation: exponential for the input gate, either sigmoid or exponential for the forget gate, and sigmoid for the output gate. Query, key and value are linear functions of the input. The cell state is updated as the forget gate times the previous state plus the input gate times the value-key outer product. A normaliser state is updated in the same way, but with the key alone in place of the outer product. The output is the output gate times the cell-query product, divided by the maximum of the magnitude of the normaliser-query inner product and 1.
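
Written out, here is my reading of the full recurrence (the paper additionally scales the key by $1/\sqrt{d}$ and allows either a sigmoid or exponential forget gate):

$$
\begin{aligned}
\mathbf{q}_t &= \mathbf{W}_q \mathbf{x}_t + \mathbf{b}_q, \qquad
\mathbf{k}_t = \tfrac{1}{\sqrt{d}} \mathbf{W}_k \mathbf{x}_t + \mathbf{b}_k, \qquad
\mathbf{v}_t = \mathbf{W}_v \mathbf{x}_t + \mathbf{b}_v, \\
i_t &= \exp(\mathbf{w}_i^\top \mathbf{x}_t + b_i), \qquad
f_t = \sigma(\mathbf{w}_f^\top \mathbf{x}_t + b_f) \;\text{or}\; \exp(\mathbf{w}_f^\top \mathbf{x}_t + b_f), \qquad
\mathbf{o}_t = \sigma(\mathbf{W}_o \mathbf{x}_t + \mathbf{b}_o), \\
\mathbf{C}_t &= f_t \, \mathbf{C}_{t-1} + i_t \, \mathbf{v}_t \mathbf{k}_t^\top, \qquad
\mathbf{n}_t = f_t \, \mathbf{n}_{t-1} + i_t \, \mathbf{k}_t, \\
\mathbf{h}_t &= \mathbf{o}_t \odot \frac{\mathbf{C}_t \mathbf{q}_t}{\max\{\lvert \mathbf{n}_t^\top \mathbf{q}_t \rvert,\, 1\}}.
\end{aligned}
$$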

To give an intuition for this, there’s:

  • Inputs $\mathbf{x}$ and parameters $\mathbf{W}_{q,k,v,o}$, $\mathbf{b}_{q,k,v,o}$, $\mathbf{w}_{i,f}$, $b_{i,f}$.
  • Six linear + activation ops, each depending only on the current input: $\textbf{q}, \textbf{k}, \textbf{v}, i, f, \textbf{o}$. The $f$ (forget) and $\textbf{o}$ (output) gates have sigmoid activations, giving outputs in the range $[0, 1]$ (the paper also allows an exponential forget gate), but $i$ (input) has an exponential activation. $\textbf{q}, \textbf{k}, \textbf{v}$ are purely linear.
  • A “cell” $\textbf{C}$: a decayed and weighted sum of $\textbf{v} \textbf{k}^\top$ (which I’ll call KV mapping) over time. At each step, the state is decayed according to the forget gate $f$ and the KV mapping is weighted according to the input gate $i$. The cell maps queries to values by matching them against keys.
  • A normaliser $\textbf{n}$: similar, but sums just $\textbf{k}$ instead of the KV mapping.
  • An output $\textbf{h}$: the product of the cell $\textbf{C}$ and the query $\textbf{q}$, divided by the magnitude of the inner product of $\textbf{q}$ and the normaliser (clamped to a minimum of 1), and multiplied by the output gate $\textbf{o}$ (see the sketch after this list).
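
To make the recurrence concrete, here is a minimal single-timestep sketch in NumPy. It reflects my own reading of the definition above (single head, my own parameter names and shapes), not the paper's reference implementation:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mlstm_step(x, C, n, params):
    """One mLSTM timestep: input vector x, recurrent state (C, n)."""
    Wq, bq, Wk, bk, Wv, bv, Wo, bo, wi, bi, wf, bf = params
    d = Wk.shape[0]                          # key/query dimension

    q = Wq @ x + bq                          # query
    k = (Wk @ x + bk) / np.sqrt(d)           # key (scaled, as in the paper)
    v = Wv @ x + bv                          # value
    i = np.exp(wi @ x + bi)                  # input gate: exponential, scalar
    f = sigmoid(wf @ x + bf)                 # forget gate: sigmoid variant, scalar
    o = sigmoid(Wo @ x + bo)                 # output gate: vector

    C = f * C + i * np.outer(v, k)           # cell: decayed sum of KV mappings
    n = f * n + i * k                        # normaliser: decayed sum of keys

    h = o * (C @ q) / max(abs(n @ q), 1.0)   # gated, normalised readout
    return h, C, n
```

Looping this over a sequence gives the recurrent, fixed-state form used at inference time.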

Like softmax dot-product self-attention, this involves a normalised sum of exponentials; a key difference is that the input to the exponential depends only on the "source" (key, value) position, not on the "target" (query). It bears some similarities to linear attention, Mamba and RWKV, permitting a parallel scan over the inputs since the time dependency is linear. It retains the RNN's advantage of summarising the context in a fixed-size representation, $\textbf{C}$, for efficient autoregressive inference.
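
To see the difference explicitly, unroll the cell recurrence (taking $\mathbf{C}_0 = \mathbf{0}$):

$$
\mathbf{C}_t \mathbf{q}_t = \sum_{s=1}^{t} \Big( \prod_{r=s+1}^{t} f_r \Big) \, i_s \, (\mathbf{k}_s^\top \mathbf{q}_t) \, \mathbf{v}_s ,
$$

so each value $\mathbf{v}_s$ is weighted by a plain inner product $\mathbf{k}_s^\top \mathbf{q}_t$ times a gating factor $\big(\prod_{r=s+1}^{t} f_r\big) i_s$ built only from past inputs, whereas softmax attention exponentiates $\mathbf{q}_t^\top \mathbf{k}_s$ itself.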

In the xLSTM architecture, this is used in a custom residual block that performs a position-wise up-projection before the multi-head mLSTM.
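
As a rough picture of that block, here is a heavily simplified sketch assuming standard pre-norm residual wiring, with hypothetical `layer_norm`, `up_proj`, `multi_head_mlstm` and `down_proj` callables; the paper's block also contains a causal convolution, a learnable skip connection and output gating, which are omitted here:

```python
def mlstm_block(x_seq, layer_norm, up_proj, multi_head_mlstm, down_proj):
    """Simplified pre-up-projection residual block: (seq_len, d_model) -> (seq_len, d_model)."""
    z = layer_norm(x_seq)       # pre-norm on the residual stream
    z = up_proj(z)              # position-wise up-projection to a wider dimension
    z = multi_head_mlstm(z)     # sequence mixing with the multi-head mLSTM
    z = down_proj(z)            # position-wise projection back to the residual width
    return x_seq + z            # residual connection
```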

Results

Downstream results for LLMs of up to 1.3B parameters, trained on 300B SlimPajama tokens:

Table: results for xLSTM with a pure-mLSTM configuration and a 7:1 mLSTM:sLSTM ratio, against RWKV, Llama and Mamba baselines on multiple downstream tasks, for models up to 1.3B parameters. The pure-mLSTM version of xLSTM performs best in most cases, across SlimPajama validation perplexity, LAMBADA, HellaSwag, PIQA, ARC and WinoGrande.

(I haven't been able to confirm whether these are zero-shot or few-shot results.) Here, xLSTM[1:0] uses only the mLSTM layer described above, while xLSTM[7:1] includes seven mLSTM layers for every sLSTM layer. These results appear to demonstrate that mLSTM alone is sufficient for LLMs at this scale. The paper also includes a helpful set of ablations and synthetic tasks.

Takeaways

It’s refreshing to see non-transformer LLMs trained at scale, and that the xLSTM architecture appears competitive with transformers. More research could help us understand the benefits of these alternatives, and whether the scaling properties are robust.
