3.1.22.13. unit_scaling.functional.residual_apply

unit_scaling.functional.residual_apply(fn: Callable[[Tensor], Tensor], input: Tensor, tau: float = 1.0) → Tensor

Apply a weighted residual branch, maintaining unit scale.

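Combines residual_split() and residual_add() into a single function: input is split into a residual branch and a skip connection, fn is applied to the residual branch, and the two branches are then added with a relative weighting set by tau.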
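The split, apply, add composition can be sketched in plain Python, modelling the forward pass only. This is not the library's implementation: residual_apply_sketch and branch_weights are hypothetical names, lists of floats stand in for tensors, any gradient scaling the library performs is not represented, and the mapping from tau to branch weights is an assumption. It was chosen so that the squared weights sum to 1, tau = 1 weights both branches equally, and larger tau favors the skip connection, as the tau description states.

```python
import math
import random
import statistics
from typing import Callable, List, Tuple

Vector = List[float]  # stand-in for a 1-D tensor


def branch_weights(tau: float = 1.0) -> Tuple[float, float]:
    """Hypothetical (residual_weight, skip_weight) for a given tau.

    Assumed mapping: weights are equal at tau == 1, their squares sum
    to 1, and larger tau shifts weight toward the skip connection.
    """
    if tau <= 0:
        raise ValueError("tau must be positive")
    residual_weight = math.sqrt(1.0 / (1.0 + tau))
    skip_weight = math.sqrt(tau / (1.0 + tau))
    return residual_weight, skip_weight


def residual_apply_sketch(
    fn: Callable[[Vector], Vector], x: Vector, tau: float = 1.0
) -> Vector:
    """Forward-pass-only sketch of split -> fn -> weighted add."""
    # "Split": both the residual branch and the skip connection start from x.
    residual, skip = x, x
    # Apply the residual function to the residual branch only.
    residual_out = fn(residual)
    # Weighted add: if both branches have unit variance and are uncorrelated,
    # the result has variance residual_weight**2 + skip_weight**2 == 1.
    r, s = branch_weights(tau)
    return [r * a + s * b for a, b in zip(residual_out, skip)]


if __name__ == "__main__":
    rng = random.Random(0)
    x = [rng.gauss(0.0, 1.0) for _ in range(100_000)]
    # Toy residual branch: reversing i.i.d. samples gives an output with
    # unit variance that is (almost) uncorrelated with the input.
    y = residual_apply_sketch(lambda v: v[::-1], x, tau=1.0)
    print(f"std(x) = {statistics.pstdev(x):.3f}, std(y) = {statistics.pstdev(y):.3f}")
```

Under these assumptions the printed std(y) stays close to std(x), which is the unit-scale property the function is meant to maintain. Changing tau only redistributes weight between the two branches; it does not change the output scale.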

Parameters:
  • fn (Callable) – the residual function, applied to input to form the residual branch.

  • input (Tensor) – the input tensor, which is also used for the skip connection.

  • tau (float, optional) – the ratio of scale of contributions of the residual branch to the skip connection. Larger values favor skip over residual. Defaults to 1 (equal contribution).
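Why a weighted sum can stay at unit scale is shown by a short forward-pass variance calculation. It assumes that input x and the residual output f(x) both have unit variance and are uncorrelated. The symbols a (residual weight) and b (skip weight) are illustrative; the exact weights the library derives from tau are not given here.

```latex
% Assumptions: Var[x] = Var[f(x)] = 1 and Cov[f(x), x] = 0.
% a and b are illustrative residual and skip weights, not the library's formula.
\[
\operatorname{Var}\big[a\,f(x) + b\,x\big]
  = a^2 \operatorname{Var}[f(x)] + b^2 \operatorname{Var}[x] + 2ab\,\operatorname{Cov}[f(x), x]
  = a^2 + b^2 .
\]
% Unit scale requires a^2 + b^2 = 1; equal contribution (tau = 1) additionally sets a = b:
\[
a^2 + b^2 = 1, \quad a = b \;\Longrightarrow\; a = b = \tfrac{1}{\sqrt{2}} .
\]
```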