5.5.10. unit_scaling.functional.residual_split

unit_scaling.functional.residual_split(input: Tensor, tau: float = 0.5) → Tuple[Tensor, Tensor]

Splits a tensor into a residual tensor and a skip tensor, prior to their use in a residual layer, with a relative weighting tau applied to the residual branch. Should be used in conjunction with unit_scaling.functional.residual_add().

This is necessary because unit scaling delays the scaling of the residual branch until the backward pass, such that residual gradients are still unit-scaled.
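A minimal usage sketch is given below. Here branch_fn is a hypothetical stand-in for the residual computation (e.g. an MLP or self-attention block), and the (residual, skip, tau) argument order for residual_add() is an assumption based on its pairing with this function; consult the residual_add() documentation for the exact signature.

    import torch
    import torch.nn.functional as F
    import unit_scaling.functional as U

    def residual_layer(x: torch.Tensor, branch_fn, tau: float = 0.5) -> torch.Tensor:
        # Split into residual and skip tensors; tau weights the residual branch.
        residual, skip = U.residual_split(x, tau=tau)
        # Apply the residual computation (e.g. an MLP or attention block).
        residual = branch_fn(residual)
        # Recombine with the same tau, so the delayed backward-pass scaling
        # matches the split above. Assumed (residual, skip, tau) signature.
        return U.residual_add(residual, skip, tau=tau)

    x = torch.randn(16, 256, requires_grad=True)
    out = residual_layer(x, branch_fn=F.gelu, tau=0.5)
    out.sum().backward()  # residual gradients remain unit-scaled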

The need for a relative weighting (tau) between the two branches arises because unit scaling normalises the scales of the two branches. In non-unit-scaled models the two branches may have different scales, which can be beneficial to training. The tau factor allows unit scaling to behave as though the branches have different scales.

For MLP layers tau=0.5 is recommended, and for self-attention layers tau=0.01.

These values reflect the relative scales of the skip-vs-residual branches in a standard transformer. Empirically, training is fairly insensitive to the self-attention tau (e.g. tau=0.1 or tau=0.001 also work well), but the default tau=0.5 causes significant degradation.
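As an illustration, a transformer layer might apply the two recommended values as follows. This is a sketch, not the library's own layer implementation: attn_fn and mlp_fn are hypothetical stand-ins for the self-attention and MLP sub-layers, and residual_add() is assumed to take (residual, skip, tau) as above.

    import torch
    import unit_scaling.functional as U

    def transformer_layer(x: torch.Tensor, attn_fn, mlp_fn) -> torch.Tensor:
        # Self-attention sub-layer: tau=0.01 reflects the small relative
        # scale of the attention branch in a standard transformer.
        residual, skip = U.residual_split(x, tau=0.01)
        x = U.residual_add(attn_fn(residual), skip, tau=0.01)
        # MLP sub-layer: the recommended tau=0.5.
        residual, skip = U.residual_split(x, tau=0.5)
        return U.residual_add(mlp_fn(residual), skip, tau=0.5)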

Parameters:
  • input (Tensor) – the tensor to which the residual layer is to be applied.

  • tau (float, optional) – the weighting of the residual branch relative to the skip connection. Defaults to 0.5.

Returns:

the resulting tensors, in the order (residual, skip).

Return type:

Tuple[Tensor, Tensor]