3.1.20. unit_scaling.TransformerLayer
- class unit_scaling.TransformerLayer(hidden_size: int, heads: int, mhsa_tau: float, mlp_tau: float, is_causal: bool, dropout_p: float = 0.0)
A unit-scaled implementation of a PreNorm transformer layer (see https://arxiv.org/abs/2002.04745).
Warning: using constraint=None here will likely give incorrect gradients.
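The sketch below is a minimal, framework-free illustration of the PreNorm ordering named above. Each branch receives a normalised copy of its input, and the skip path is left un-normalised. It is not the library's implementation. The `weighted_residual_add` rule treats `tau` as the ratio of branch variance to skip variance and rescales so that independent, unit-variance inputs give a unit-variance output. That rule is an assumption made for illustration, not the library's documented formula. `layer_norm`, `weighted_residual_add`, `prenorm_layer`, and the `attn`/`mlp` callables are all hypothetical names.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalise a vector to zero mean and unit variance (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def weighted_residual_add(skip, branch, tau):
    """Hypothetical tau-weighted residual add (assumed rule, not the library's).

    Treats tau as the branch-to-skip variance ratio and divides by
    sqrt(1 + tau) so independent unit-variance inputs stay unit-variance.
    """
    scale = 1.0 / math.sqrt(1.0 + tau)
    return [scale * (s + math.sqrt(tau) * b) for s, b in zip(skip, branch)]


def prenorm_layer(x, attn, mlp, mhsa_tau, mlp_tau):
    """PreNorm ordering: normalise each branch's input, never the skip path."""
    x = weighted_residual_add(x, attn(layer_norm(x)), mhsa_tau)
    x = weighted_residual_add(x, mlp(layer_norm(x)), mlp_tau)
    return x
```

With `tau = 0`, both branches are switched off and the layer reduces to the identity on the skip path. Larger values of `mhsa_tau` or `mlp_tau` shift more of the output variance onto the corresponding branch.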
- Parameters:
hidden_size (int) – the hidden dimension size of the input.
heads (int) – the number of attention heads.
mhsa_tau (float) – the weighting of the multi-head self-attention (MHSA) branch relative to the skip connection.
mlp_tau (float) – the weighting of the MLP branch relative to the skip connection.
is_causal (bool) – whether to apply causal masking, so that each position attends only to itself and to earlier positions (intended for non-padded sequences).
dropout_p (float, optional) – the dropout probability used for both residual dropout and post-softmax dropout. Defaults to 0.0 (no dropout).
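The following sketch shows what causal masking means for the post-softmax attention weights that `is_causal` controls. It is a plain-Python illustration, and `causal_softmax` is a hypothetical helper, not part of the library. Query position `i` may only attend to key positions `j <= i`, so every masked entry gets exactly zero weight. The mask depends only on position, which suggests why the option is described as being for non-padded sequences: it does not, on its own, stop real tokens from attending to padding placed earlier in the sequence.

```python
import math


def causal_softmax(scores):
    """Row-wise softmax over a square score matrix with a causal mask.

    Query position i may only attend to key positions j <= i; masked entries
    receive exactly zero weight.
    """
    n = len(scores)
    weights = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]  # keys 0..i are allowed
        m = max(visible)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        # Normalise the visible entries and zero-fill the masked (future) keys.
        weights.append([e / total for e in exps] + [0.0] * (n - i - 1))
    return weights
```

For uniform scores, `causal_softmax([[0.0] * 3] * 3)` gives rows of `[1, 0, 0]`, `[0.5, 0.5, 0]`, and `[1/3, 1/3, 1/3]`. The first position can only see itself, and each later position spreads its weight evenly over itself and the positions before it.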