3.1.22.14. unit_scaling.functional.residual_split

unit_scaling.functional.residual_split(input: Tensor, tau: float = 1.0) → Tuple[Tensor, Tensor]

Splits a tensor into a residual tensor and a skip tensor, prior to their use in a residual layer, with a relative weighting tau applied to the residual branch. It should be used together with unit_scaling.functional.residual_add().

This split is necessary because, in the backward pass, unit scaling delays the residual-branch scaling until the point where the branches split, so that gradients inside the residual branch remain unit-scaled.
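The delayed scaling can be illustrated with plain Python scalars instead of tensors. The sketch below is a hypothetical model, not the library's implementation: it assumes the add weights the branches by `sqrt(tau / (1 + tau))` (residual) and `sqrt(1 / (1 + tau))` (skip), and it uses a toy residual branch `f(x) = w * x`. The helper names `branch_scales`, `standard_backward` and `delayed_backward` are illustrative only. It compares ordinary backpropagation, where the add shrinks the gradient entering the residual branch, with a backward pass that leaves that gradient unscaled and applies the residual scale at the split instead.

```python
import math
from typing import Tuple


def branch_scales(tau: float) -> Tuple[float, float]:
    """Hypothetical weighting (an assumption, not taken from the library):
    residual and skip scales whose squares sum to 1, with a residual:skip
    variance ratio of tau."""
    return math.sqrt(tau / (1.0 + tau)), math.sqrt(1.0 / (1.0 + tau))


def standard_backward(x: float, w: float, g: float, tau: float) -> Tuple[float, float]:
    """Ordinary backprop through y = r * f(x) + s * x, with f(x) = w * x.
    Returns (gradient w.r.t. w, gradient w.r.t. x) for upstream gradient g."""
    r, s = branch_scales(tau)
    grad_h = r * g               # the add scales the residual-branch gradient by r
    grad_w = x * grad_h          # so the weight gradient inside f is shrunk by r
    grad_x = w * grad_h + s * g  # residual path + skip path
    return grad_w, grad_x


def delayed_backward(x: float, w: float, g: float, tau: float) -> Tuple[float, float]:
    """Same forward pass, but the backward residual scale is moved to the split."""
    r, s = branch_scales(tau)
    grad_h = g                         # the add passes g through unscaled
    grad_w = x * grad_h                # f receives the incoming gradient g unchanged
    grad_x = r * (w * grad_h) + s * g  # r and s applied where the branches split
    return grad_w, grad_x


print(standard_backward(2.0, 3.0, 1.0, tau=1.0))
print(delayed_backward(2.0, 3.0, 1.0, tau=1.0))
```

With `x = 2.0`, `w = 3.0`, `g = 1.0` and `tau = 1.0`, both versions give the same input gradient (about 2.83), but the weight gradient inside the branch is 2.0 in the delayed version instead of about 1.41. Because backpropagation through any branch is linear in the incoming gradient, moving the scale from the add to the split leaves the gradient reaching the input unchanged; it only changes the scale seen inside the branch.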

The relative weighting between the two branches (tau) is needed because unit scaling normalises the scales of both branches. In non-unit-scaled models the two branches may have different scales, which can benefit training. The tau factor lets a unit-scaled model behave as though the branches had different scales.
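To make the effect of tau concrete, the following sketch assumes one weighting consistent with tau being the residual-to-skip ratio: the residual and skip branches are multiplied by `sqrt(tau / (1 + tau))` and `sqrt(1 / (1 + tau))` respectively, so their variances are in the ratio tau, while the sum of two independent unit-variance branches keeps unit variance. This weighting, the independence of the branches, and the names `branch_scales` and `combined_variance` are assumptions for illustration; the section does not state the library's exact formula.

```python
import math
import random
import statistics
from typing import Tuple


def branch_scales(tau: float) -> Tuple[float, float]:
    """Hypothetical scales: residual:skip variance ratio is tau, total variance is 1."""
    residual_scale = math.sqrt(tau / (1.0 + tau))
    skip_scale = math.sqrt(1.0 / (1.0 + tau))
    return residual_scale, skip_scale


def combined_variance(tau: float, n: int = 20_000, seed: int = 0) -> float:
    """Empirical variance of r * residual + s * skip, where the two branches are
    modelled as independent standard-normal samples."""
    rng = random.Random(seed)
    r, s = branch_scales(tau)
    samples = [r * rng.gauss(0.0, 1.0) + s * rng.gauss(0.0, 1.0) for _ in range(n)]
    return statistics.pvariance(samples)


for tau in (0.25, 1.0, 4.0):
    r, s = branch_scales(tau)
    print(f"tau={tau}: residual={r:.3f}, skip={s:.3f}, var~{combined_variance(tau):.3f}")
```

Under this assumed weighting, `tau = 4.0` gives scales of about 0.894 (residual) and 0.447 (skip), `tau = 0.25` reverses them, and the combined variance stays close to 1 in every case.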

Parameters:
  • input (Tensor) – the tensor to which the residual layer is to be applied.

  • tau (float, optional) – the ratio of the scale of the residual branch's contribution to that of the skip connection. Values larger than one favor the residual branch over the skip connection. Defaults to 1 (equal contribution).

Returns:

resulting tensors in the order: residual, skip.

Return type:

Tuple[Tensor, Tensor]