3.5.1. unit_scaling.scale.scale_bwd

unit_scaling.scale.scale_bwd(input: Tensor, scale: float) → Tensor

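Applies a scalar multiplication to a tensor in the backward pass only. The forward pass returns the input values unchanged, while the gradient flowing back through the op is multiplied by scale.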
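To illustrate the mechanism, here is a minimal sketch of an op with these semantics written as a custom torch.autograd.Function. The class name ScaleBwd is hypothetical, and this is not necessarily how unit_scaling implements scale_bwd; it only shows one way to get an identity forward pass and a scaled backward pass.

```python
import torch


class ScaleBwd(torch.autograd.Function):
    """Hypothetical sketch: identity in the forward pass, scaled grad in the backward pass."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale  # remember the factor for use in backward
        return input.view_as(input)  # forward: values are returned unchanged

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Backward: scale the incoming gradient; `scale` itself receives no gradient.
        return grad_output * ctx.scale, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = ScaleBwd.apply(x, 0.5)
y.sum().backward()
# y holds the same values as x, while x.grad is 0.5 everywhere
# (the gradient of sum() is all ones, multiplied by scale).
```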

Parameters:
  • input (Tensor) – the tensor to be scaled.

  • scale (float) – the scale factor applied to the tensor in the backward pass.

Returns:

The input tensor, unchanged in the forward pass, but whose gradient is multiplied by scale in the backward pass.

Return type:

Tensor
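The documented behaviour (values unchanged, gradient multiplied by scale) can also be reproduced without a custom autograd function, using a detach trick. The helper scale_bwd_sketch below is hypothetical and only assumes PyTorch; it is a sketch for checking the semantics, not the library's code.

```python
import torch


def scale_bwd_sketch(input: torch.Tensor, scale: float) -> torch.Tensor:
    """Hypothetical helper mirroring the documented semantics of scale_bwd."""
    # (input - input.detach()) is exactly zero in value but still carries the
    # gradient of input, so the output equals input while d(output)/d(input) == scale.
    return input.detach() + (input - input.detach()) * scale


torch.manual_seed(0)
x = torch.randn(4, requires_grad=True)
y = scale_bwd_sketch(x, 4.0)

# Backpropagate a non-uniform upstream gradient to see the scaling clearly.
upstream = torch.tensor([1.0, -2.0, 0.5, 3.0])
y.backward(upstream)
# Forward: y equals x. Backward: x.grad equals upstream * 4.0.
```

Because the forward value is untouched, an op like this can rescale gradients (for example, to keep them in a well-represented numeric range) without changing the model's outputs.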