3.1.14. unit_scaling.MHSA
- class unit_scaling.MHSA(hidden_size: int, heads: int, is_causal: bool, dropout_p: float = 0.0, mult: float = 1.0)
A unit-scaled implementation of a multi-head self-attention layer.
- Parameters:
hidden_size (int) – the hidden dimension size of the input.
heads (int) – the number of attention heads.
is_causal (bool) – whether to apply causal masking (intended for non-padded sequences).
dropout_p (float, optional) – the probability of post-softmax dropout. Defaults to 0.0.
mult (float, optional) – a multiplier applied to change the shape of a nonlinear function. Defaults to 1.0. Typically, high multipliers (> 1) correspond to a ‘sharper’ (low temperature) function, while low multipliers (< 1) correspond to a ‘flatter’ (high temperature) function.
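To make `heads`, `is_causal` and `mult` concrete, here is a minimal pure-Python sketch of the underlying self-attention mechanics. It is illustrative only and is not the unit_scaling implementation. It omits the learned query/key/value/output projections, dropout and any unit-scaling factors. It assumes `mult` multiplies the attention logits before the softmax (one reading of the ‘sharper’/‘flatter’ description above, not confirmed by this page) and that `hidden_size` is divisible by `heads`. The helper names `softmax`, `attention_weights` and `mhsa` are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row max for numerical stability; exp(-inf) evaluates to 0.0.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q: Matrix, k: Matrix, is_causal: bool, mult: float = 1.0) -> Matrix:
    """Post-softmax attention weights for one head (hypothetical sketch)."""
    d = len(q[0])
    weights = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            if is_causal and j > i:
                # Causal mask: position i may not attend to later positions.
                scores.append(-math.inf)
            else:
                dot = sum(a * b for a, b in zip(qi, kj))
                # Assumption: mult scales the logits, acting as an inverse temperature.
                scores.append(mult * dot / math.sqrt(d))
        weights.append(softmax(scores))
    return weights


def mhsa(x: Matrix, heads: int, is_causal: bool, mult: float = 1.0) -> Matrix:
    """Split the hidden dimension into heads, attend per head, then concatenate."""
    hidden_size = len(x[0])
    assert hidden_size % heads == 0, "hidden_size must be divisible by heads"
    head_dim = hidden_size // heads
    out = [[0.0] * hidden_size for _ in x]
    for h in range(heads):
        cols = slice(h * head_dim, (h + 1) * head_dim)
        # Projections omitted for brevity: queries, keys and values are all xh.
        xh = [row[cols] for row in x]
        w = attention_weights(xh, xh, is_causal, mult)
        for i, wi in enumerate(w):
            for c in range(head_dim):
                # Weighted sum of value vectors, written into this head's columns.
                out[i][h * head_dim + c] = sum(wij * xh[j][c] for j, wij in enumerate(wi))
    return out
```

In this sketch, raising `mult` concentrates each row's weight on its highest-scoring positions, while lowering it spreads the weight more evenly. With `is_causal=True`, row `i` assigns exactly zero weight to every position `j > i`, so the first position can only attend to itself.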