3.1.14. unit_scaling.MHSA

class unit_scaling.MHSA(hidden_size: int, heads: int, is_causal: bool, dropout_p: float = 0.0, mult: float = 1.0)

A unit-scaled implementation of a multi-head self-attention layer.

Parameters:
  • hidden_size (int) – the hidden dimension size of the input.

  • heads (int) – the number of attention heads.

  • is_causal (bool) – whether to apply causal masking (for non-padded sequences).

  • dropout_p (float, optional) – the probability of dropout applied to the post-softmax attention weights.

  • mult (float, optional) – a multiplier applied to change the shape of the layer’s nonlinear function. High multipliers (> 1) correspond to a ‘sharper’ (low-temperature) function, while low multipliers (< 1) correspond to a ‘flatter’ (high-temperature) function.
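
Example — a minimal usage sketch. The call convention (a torch.nn.Module whose forward maps a (batch, seq_len, hidden_size) tensor to the same shape) is an assumption based on the parameter descriptions above, not confirmed by this page:

    import torch
    import unit_scaling as uu

    # Assumed: MHSA acts as a standard torch.nn.Module whose forward maps
    # (batch, seq_len, hidden_size) -> (batch, seq_len, hidden_size).
    mhsa = uu.MHSA(hidden_size=256, heads=8, is_causal=True, dropout_p=0.1)

    x = torch.randn(2, 128, 256)  # (batch, seq_len, hidden_size)
    y = mhsa(x)
    print(y.shape)  # expected: torch.Size([2, 128, 256])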