4.1.22.14. unit_scaling.functional.residual_split
- unit_scaling.functional.residual_split(input: Tensor, tau: float = 1.0) → Tuple[Tensor, Tensor]
Splits a tensor into a residual tensor and a skip tensor before they are used in a residual layer, with a relative weighting tau applied to the residual branch. Should be used in conjunction with unit_scaling.functional.residual_add(). This pairing is necessary because unit scaling delays the residual-branch scaling in the backward pass (from residual_add() back to this split), so that gradients within the residual branch remain unit-scaled.
A relative weighting between the two branches (tau) is needed because unit scaling normalises both branches to the same scale. In non-unit-scaled models the two branches may have different scales, which can be beneficial to training. The tau factor lets a unit-scaled model behave as though its branches have different scales.
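As a worked illustration only, assume the branches are recombined with weights chosen so that tau is the ratio of the residual branch's variance contribution to the skip branch's, and that the residual output r and skip tensor s are independent with unit variance. The exact weights used by unit_scaling are not stated in this section; this form is an assumption. Under it, the output keeps unit variance for any tau, while the two contributions stand in the ratio tau : 1 (tau = 1 gives both branches a weight of 1/√2):

```latex
y = \sqrt{\frac{\tau}{1+\tau}}\, r + \sqrt{\frac{1}{1+\tau}}\, s,
\qquad
\operatorname{Var}(y) = \frac{\tau}{1+\tau}\operatorname{Var}(r) + \frac{1}{1+\tau}\operatorname{Var}(s)
= \frac{\tau}{1+\tau} + \frac{1}{1+\tau} = 1.
```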
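- Parameters:
input (Tensor): the tensor to be split into residual and skip tensors before a residual layer is applied.
tau (float, optional): the relative weighting applied to the residual branch with respect to the skip branch. Defaults to 1.0.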
- Returns:
resulting tensors in the order: residual, skip.
- Return type:
Tuple[Tensor, Tensor]