3.6.8. unit_scaling.transforms.unit_scale

unit_scaling.transforms.unit_scale(module: M, replace: Dict[Callable[[...], Any], Callable[[...], Any]] = {}) → M[source]

[Experimental] Returns a unit-scaled version of the input model.

Uses TorchDynamo to trace and transform the user-supplied module. This transformation identifies all torch.nn.functional uses in the input module, and replaces them with their unit-scaled equivalents, should they exist.

The tracing procedure automatically recurses into modules (whether defined in libraries, or by the user), identifying inner calls to any torch.nn.functional operations, to build a graph of fundamental operations. Unit scaling is then applied as a transformation on top of this graph.
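As a toy illustration of this recursive flattening (not the real tracer, which operates on TorchDynamo/FX graphs), consider collecting the functional ops of a nested module structure into a single flat list — the module layout and op names here are hypothetical:

```python
def collect_ops(module):
    # Recursively walk a nested "module" (a plain dict in this sketch),
    # gathering the ops of all children before the module's own ops --
    # mimicking how tracing flattens a model into fundamental operations.
    ops = []
    for child in module.get("children", []):
        ops += collect_ops(child)
    ops += module.get("ops", [])
    return ops

inner = {"ops": ["linear", "gelu"]}          # e.g. a library-defined MLP block
model = {"children": [inner, inner], "ops": ["softmax"]}
print(collect_ops(model))  # ['linear', 'gelu', 'linear', 'gelu', 'softmax']
```

Once the model is flat, unit scaling can be applied uniformly to each op, regardless of whether it originally came from user code or a library module.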

This transformation proceeds in five stages:

  1. Replacement of user-defined functions according to the supplied replace dictionary.

  2. Replacement of all functions with unit-scaled equivalents defined in unit_scaling.functional.

  3. Identification & replacement of all add operations that represent residual-adds. The identification of residual connections is done via a dependency analysis on the graph. Residual-adds require special scaling compared with regular adds (see paper / User Guide for details).

  4. Unconstraining of all operations after the final residual layer. By default all unit scaled operations have their scaling factors constrained in the forward and backward pass to give valid gradients. This is not required in these final layers (see paper for proof), and hence we can unconstrain the operations to give better scaling.

  5. Unit-scaling of all weights and zero-initialisation of all biases.
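For intuition on step 3: a unit-scaled residual add combines the skip and residual branches with coefficients whose squares sum to one, so that two independent unit-variance inputs yield a unit-variance output. A minimal scalar sketch — the interpolation parameter and its name `tau` are illustrative, not necessarily the scaling rule unit_scaling uses:

```python
import math

def scaled_residual_add(skip, residual, tau=0.5):
    # Weight the branches so that, for independent unit-variance inputs,
    # Var(a * skip + b * residual) = a**2 + b**2 = 1.
    a = math.sqrt(1 - tau)  # skip-branch coefficient
    b = math.sqrt(tau)      # residual-branch coefficient
    return a * skip + b * residual
```

A plain `skip + residual` would instead produce variance 2, which compounds over a deep stack of residual layers; this is why residual-adds need scaling distinct from ordinary adds.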

Note that by using TorchDynamo, unit_scale() is able to trace a much larger set of modules / operations than with previous PyTorch tracing approaches. This enables the process of unit scaling to be expressed as a generic graph transform that can be applied to arbitrary modules.

Note that the current version of TorchDynamo (or torch.compile(), which is a wrapper around TorchDynamo) doesn’t support nested transforms, so we implement our own transform system here, which makes it easy to nest transforms:

from unit_scaling.transforms import compile, simulate_fp8, unit_scale

module = compile(simulate_fp8(unit_scale(module)))

However, these transforms are not interoperable with the standard torch.compile() interface.

In some cases users may have a model definition that uses a custom implementation of a basic operation. In this case, unit_scale() can be told explicitly to substitute the layer for an equivalent, using the replace dictionary:

import torch.nn as nn

import unit_scaling.functional as U
from unit_scaling.transforms import unit_scale

def new_gelu(x):
    ...

class Model(nn.Module):
    def forward(self, x):
        ...
        x = new_gelu(x)
        ...

model = unit_scale(Model(), replace={new_gelu: U.gelu})

This can also be used to substitute a particular function for a user-defined unit-scaled function not provided by unit_scaling.functional.
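For example, a user could supply their own unit-scaled activation by dividing by the output standard deviation that the function produces on a unit-normal input. The scalar sketch below is hypothetical (the real unit_scaling.functional implementations operate on tensors and may scale differently); the derived constant is just the analytic std of ReLU applied to a standard normal:

```python
import math

def my_relu(x):
    # user's custom (scalar) implementation of ReLU
    return max(x, 0.0)

def unit_relu(x):
    # Hypothetical unit-scaled variant: for x ~ N(0, 1), relu(x) has
    # std sqrt(1/2 - 1/(2*pi)), so divide by it to restore unit scale.
    scale = math.sqrt(0.5 - 1 / (2 * math.pi))
    return my_relu(x) / scale

# This pair would then be registered via:
#     model = unit_scale(Model(), replace={my_relu: unit_relu})
```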

Note: unit_scale() is experimental and has not yet been widely tested on a range of models. The standard approach to unit scaling a model is still to manually substitute the layers/operations in a model with their unit-scaled equivalents. Having said this, unit_scale() is implemented in a sufficiently generic way that we anticipate many users will ultimately be able to rely on this graph transform alone.

Parameters:
  • module (nn.Module) – the input module to be unit scaled.

  • replace (Dict[Callable, Callable], optional) – a dictionary where keys represent functions to be replaced by the corresponding value-functions. Note that these substitutions take priority over the standard unit scaling substitutions. Defaults to dict().

Returns:

the unit-scaled module (with an independent copy of parameters)

Return type:

nn.Module