3.1.22.12. unit_scaling.functional.residual_add

unit_scaling.functional.residual_add(residual: Tensor, skip: Tensor, tau: float = 1.0) → Tensor

Adds the output of a residual branch to the output of its skip connection, applying a relative weighting tau to the residual branch. Should be used in conjunction with unit_scaling.functional.residual_split().

Parameters:
  • residual (Tensor) – the tensor coming out of the residual connection.

  • skip (Tensor) – the tensor coming out of the skip connection.

  • tau (float, optional) – the ratio of the scale of the residual branch's contribution to that of the skip connection. Larger values favor skip over residual. Defaults to 1 (equal contribution).

Returns:

The weighted sum of the residual and skip tensors.

Return type:

Tensor
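To show why a weighting is needed at all, the pure-Python sketch below compares a plain sum of two unit-variance branches with a sum in which each branch is scaled by 1/sqrt(2). Treating 1/sqrt(2) per branch as the default tau = 1 (equal contribution) case is an assumption made for illustration; equal_weight_residual_add is a hypothetical helper, not the library's implementation.

```python
import math
import random
import statistics


def equal_weight_residual_add(residual: list[float], skip: list[float]) -> list[float]:
    """Hypothetical tau = 1 residual add (illustrative assumption only).

    Each branch is scaled by 1/sqrt(2), so if both inputs have unit variance
    and are uncorrelated, the sum also has unit variance.
    """
    scale = 1.0 / math.sqrt(2.0)
    return [scale * r + scale * s for r, s in zip(residual, skip)]


rng = random.Random(0)
n = 100_000
residual = [rng.gauss(0.0, 1.0) for _ in range(n)]  # unit-variance residual branch
skip = [rng.gauss(0.0, 1.0) for _ in range(n)]      # unit-variance skip branch

naive = [r + s for r, s in zip(residual, skip)]     # unweighted sum
weighted = equal_weight_residual_add(residual, skip)

print(f"plain sum std:    {statistics.pstdev(naive):.3f}")     # roughly sqrt(2), about 1.414
print(f"weighted sum std: {statistics.pstdev(weighted):.3f}")  # roughly 1.0
```

For uncorrelated unit-variance inputs, Var(a·r + b·s) = a² + b², so a plain sum doubles the variance while equal weights of 1/sqrt(2) keep the output unit-scaled.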