3.8. unit_scaling.utils

Utility functions for developing unit-scaled models.

Functions

analyse_module(module, inputs[, backward, ...])

Given an nn.Module and dummy forward and backward tensors, generates code representing the module, annotated with the scale (standard deviation) of each tensor in both the forward and backward passes.

Classes

ScalePair([forward, backward])

Dataclass containing a pair of scalars, intended to represent the standard deviation of an arbitrary tensor in the forward and backward passes.
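As a rough illustration of the data ScalePair is meant to hold, here is a minimal stdlib-only sketch. It is not the library's definition: the `Optional[float] = None` defaults and the helper `std_of` are assumptions made for this example. It uses the population standard deviation (`statistics.pstdev`); note that PyTorch's `Tensor.std` uses the Bessel-corrected sample estimate by default.

```python
import statistics
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ScalePair:
    """Sketch: forward/backward standard deviations of one tensor.

    The None defaults (meaning "not yet recorded") are an assumption.
    """

    forward: Optional[float] = None
    backward: Optional[float] = None


def std_of(values: Sequence[float]) -> float:
    # Hypothetical helper: population standard deviation of flat values.
    return statistics.pstdev(values)


pair = ScalePair()
pair.forward = std_of([1.0, -1.0, 1.0, -1.0])   # e.g. activations
pair.backward = std_of([0.5, -0.5, 0.5, -0.5])  # e.g. gradients
print(pair)  # ScalePair(forward=1.0, backward=0.5)
```

In a unit-scaled model one would hope to see both fields close to 1.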

ScaleTracker(*args, **kwargs)

Given a torch.Tensor, records its standard deviation in the forward and backward passes in the supplied ScalePair.
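One way to record a tensor's scale in both passes is an identity `torch.autograd.Function` that measures its input on the way forward and the incoming gradient on the way back. The sketch below assumes that approach; `ScaleTrackerSketch` is a hypothetical name and this is not necessarily how the library implements `ScaleTracker`. `ScalePair` is redefined locally so the block is self-contained.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor


@dataclass
class ScalePair:
    forward: Optional[float] = None
    backward: Optional[float] = None


class ScaleTrackerSketch(torch.autograd.Function):
    """Identity op that records the std of its input and of its gradient."""

    @staticmethod
    def forward(ctx, t: Tensor, scale_pair: ScalePair) -> Tensor:
        scale_pair.forward = t.std().item()
        ctx.scale_pair = scale_pair  # keep a reference for the backward pass
        return t.view_as(t)  # identity: values pass through unchanged

    @staticmethod
    def backward(ctx, grad: Tensor):
        ctx.scale_pair.backward = grad.std().item()
        # Identity gradient for `t`; no gradient for the ScalePair argument.
        return grad, None


torch.manual_seed(0)
x = torch.randn(1000, requires_grad=True)
grad_out = 2.0 * torch.randn(1000)
pair = ScalePair()
y = ScaleTrackerSketch.apply(x, pair)
y.backward(grad_out)
print(pair)  # forward close to 1.0, backward close to 2.0
```

Because the op is an identity, inserting it anywhere in a model leaves the computation unchanged while still observing both passes.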

ScaleTrackingInterpreter(module)

Wraps an fx.GraphModule such that, when executed, it records the standard deviation of every intermediate torch.Tensor in the forward and backward passes.
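The behaviour described above can be approximated with `torch.fx.Interpreter`, whose `run_node` method is the documented hook for observing each node's result. The sketch below records each tensor's forward std and registers a tensor hook to capture its gradient's std during backward. `ScaleTrackingInterpreterSketch` and its `scales` dictionary (keyed by node name) are hypothetical, not the library's API.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
import torch.fx as fx
from torch import Tensor, nn


@dataclass
class ScalePair:
    forward: Optional[float] = None
    backward: Optional[float] = None


class ScaleTrackingInterpreterSketch(fx.Interpreter):
    def __init__(self, module: fx.GraphModule):
        super().__init__(module)
        self.scales: Dict[str, ScalePair] = {}

    def run_node(self, n: fx.Node) -> Any:
        result = super().run_node(n)
        if isinstance(result, Tensor) and result.is_floating_point():
            pair = ScalePair(forward=result.detach().std().item())
            self.scales[n.name] = pair
            if result.requires_grad:
                # The hook receives dLoss/d(result) during the backward pass.
                def record(grad: Tensor, pair: ScalePair = pair) -> None:
                    pair.backward = grad.std().item()

                result.register_hook(record)
        return result


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
gm = fx.symbolic_trace(model)
interp = ScaleTrackingInterpreterSketch(gm)
x = torch.randn(64, 16, requires_grad=True)
out = interp.run(x)
grad_out = torch.randn_like(out)
(out * grad_out).sum().backward()  # gradient w.r.t. `out` equals grad_out
for name, pair in interp.scales.items():
    print(f"{name:>8}: fwd={pair.forward:.3f} bwd={pair.backward:.3f}")
```

Every floating-point node, including the input placeholder and the final output, ends up with both a forward and a backward scale.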