3.5.2. unit_scaling.scale.scale_fwd

unit_scaling.scale.scale_fwd(input: Tensor, scale: float) → Tensor

Applies a scalar multiplication to a tensor in the forward pass only; the gradient is not scaled in the backward pass.

Parameters:
  • input (Tensor) – the tensor to be scaled.

  • scale (float) – the scale factor applied to the tensor in the forward pass.

Returns:

The input tensor, scaled by scale in the forward pass, but with its original (unscaled) gradient in the backward pass.

Return type:

Tensor
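As an illustration of the forward-only behaviour described above, the following is a minimal sketch written with torch.autograd.Function. It is an assumption about one way to obtain these semantics, not the library's own implementation; the names _ScaleFwd and scale_fwd_sketch are hypothetical.

```python
import torch


class _ScaleFwd(torch.autograd.Function):
    """Hypothetical sketch: multiply by `scale` forward, identity backward."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, scale: float) -> torch.Tensor:
        # Forward pass: ordinary scalar multiplication.
        return input * scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Backward pass: pass the incoming gradient through unscaled.
        # `None` is returned for the float `scale` argument, which has no gradient.
        return grad_output, None


def scale_fwd_sketch(input: torch.Tensor, scale: float) -> torch.Tensor:
    # Hypothetical wrapper mirroring the documented signature.
    return _ScaleFwd.apply(input, scale)


x = torch.ones(3, requires_grad=True)
y = scale_fwd_sketch(x, 4.0)
y.sum().backward()
# y holds 4.0 in every element, while x.grad holds 1.0 in every element.
```

For comparison, a plain x * 4.0 would also give 4.0 in the forward pass, but x.grad would then hold 4.0 in every element, because autograd scales the gradient by the same factor.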