3.1.20. unit_scaling.TransformerLayer

class unit_scaling.TransformerLayer(hidden_size: int, heads: int, mhsa_tau: float, mlp_tau: float, is_causal: bool, dropout_p: float = 0.0)[source]

A unit-scaled implementation of a PreNorm transformer layer (see https://arxiv.org/abs/2002.04745).

Warning: using constraint=None here will likely give incorrect gradients.

Parameters:
  • hidden_size (int) – the hidden dimension size of the input.

  • heads (int) – the number of attention heads.

  • mhsa_tau (float) – the weighting of the multi-head self-attention branch relative to the skip connection.

  • mlp_tau (float) – the weighting of the MLP branch relative to the skip connection.

  • is_causal (bool) – whether to apply causal masking (for non-padded sequences).

  • dropout_p (float, optional) – the dropout probability applied to the residual branches and the post-softmax attention weights.
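
A minimal usage sketch follows. It assumes the conventional `import unit_scaling as uu` alias and that the layer's forward pass takes a single (batch, sequence, hidden_size) tensor and returns one of the same shape; the specific tau, dropout and size values are illustrative only, not recommendations from this page:

    import torch
    import unit_scaling as uu

    # Construct a causal, unit-scaled PreNorm transformer layer.
    layer = uu.TransformerLayer(
        hidden_size=512,
        heads=8,
        mhsa_tau=0.5,   # weighting of the attention branch vs. the skip connection
        mlp_tau=0.5,    # weighting of the MLP branch vs. the skip connection
        is_causal=True,
        dropout_p=0.1,
    )

    # Assumed input layout: (batch, sequence, hidden_size).
    x = torch.randn(2, 128, 512)
    y = layer(x)        # expected to preserve the input shape
    print(y.shape)      # torch.Size([2, 128, 512])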