3.6.7. unit_scaling.transforms.track_scales

unit_scaling.transforms.track_scales(module: M) → M

Returns a version of the input module which tracks internal tensor metrics.

When the forward() and backward() methods of the resulting module are called, various metrics (such as scale) are recorded internally for each intermediate tensor. These can be accessed via an additional method, module.scales_graph(), which is added to the module. It returns an instance of torch.fx.Graph in which each node representing a floating-point tensor has an associated node.meta["metrics"] object of type unit_scaling.transforms.Metrics. Note that tensor metrics are not available until forward() and backward() have been called.
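To make the mechanism above more concrete, here is a minimal, forward-only sketch of recording a per-tensor metric on an FX graph. It is not the library's implementation. It uses the public torch.fx.Interpreter API, records only the forward standard deviation, and stores it under a hypothetical node.meta["fwd_std"] key rather than the library's node.meta["metrics"] Metrics object. The model and shapes are arbitrary assumptions.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class ForwardScaleRecorder(fx.Interpreter):
    """Runs a traced module node by node, recording the std of each float output.

    Forward-only simplification: the real transform also tracks backward metrics.
    """

    def run_node(self, n: fx.Node):
        result = super().run_node(n)
        # Only floating-point tensor outputs get the hypothetical "fwd_std" entry
        if isinstance(result, torch.Tensor) and result.is_floating_point():
            n.meta["fwd_std"] = result.detach().std().item()
        return result


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
traced = fx.symbolic_trace(model)  # GraphModule exposing .graph
ForwardScaleRecorder(traced).run(torch.randn(32, 64))

# Print the recorded forward scale of every node that produced a float tensor
for node in traced.graph.nodes:
    if "fwd_std" in node.meta:
        print(f"{node.name}: std={node.meta['fwd_std']:.3f}")
```

Recording backward metrics would additionally require capturing gradients (for example via tensor hooks). This sketch omits that.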

The unit scaling library also provides a way to visualise FX graphs via the unit_scaling.analysis.plot() function. This is intended to be used as follows:

from unit_scaling.transforms import track_scales
from unit_scaling.analysis import plot

inpt = ...  # input tensor(s) for the model
model = ...  # an nn.Module whose forward() returns a scalar loss

model = track_scales(model)
loss = model(inpt)
loss.backward()

graph = model.scales_graph()
plot(graph)

Any input tensor(s) passed to a model transformed by track_scales() automatically have requires_grad_() applied (this is required for full scale tracking in the backward pass), so the user need not set this themselves.
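Plain PyTorch autograd shows why inputs need requires_grad_(). If a leaf input does not require gradients, no gradient is computed with respect to it, so there is no backward-pass tensor for the input whose scale could be recorded. The sketch below uses an arbitrary nn.Linear layer and a hypothetical helper, input_grad_std(), purely for illustration.

```python
import torch


def input_grad_std(requires_grad: bool):
    """Hypothetical helper: std of the gradient reaching the input, or None if none is computed."""
    torch.manual_seed(0)
    layer = torch.nn.Linear(16, 16)
    x = torch.randn(8, 16)
    if requires_grad:
        x.requires_grad_()  # mirrors what track_scales() does automatically
    loss = layer(x).sum()
    loss.backward()
    # x.grad stays None for a leaf tensor that does not require gradients
    return None if x.grad is None else x.grad.std().item()


print(input_grad_std(False))  # None: no input gradient, so no backward scale to track
print(input_grad_std(True))   # a positive float
```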

track_scales() can be combined with the other graph transforms the library provides, but it should always be the final transform in a chain (i.e. applied outermost), e.g.:

from unit_scaling.transforms import simulate_fp8, track_scales, unit_scale

model = track_scales(unit_scale(simulate_fp8(model)))

For analysis purposes, the full FX graph returned by this transform may contain more information than the user requires. For this reason, the functions unit_scaling.transforms.prune_non_float_tensors() and unit_scaling.transforms.prune_same_scale_tensors() are provided; in practice, these tend to reduce the graph to only the key tensors.
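The idea behind the two pruning functions can be illustrated with a toy, pure-Python model of a tracked graph. The TrackedNode class and prune() helper below are hypothetical. The criterion used is an assumption made for illustration: drop non-float tensors, and drop tensors whose scale matches that of their single input within a relative tolerance. It is not necessarily the exact rule applied by prune_non_float_tensors() or prune_same_scale_tensors().

```python
import math
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TrackedNode:
    """Hypothetical stand-in for an FX node: a name, its input, and a recorded scale."""
    name: str
    parent: Optional[str]
    scale: Optional[float]  # None models a non-float tensor (e.g. integer indices)


def prune(nodes: List[TrackedNode], rel_tol: float = 1e-6) -> List[str]:
    """Return names of float nodes whose scale differs from their parent's (assumed rule)."""
    scales = {n.name: n.scale for n in nodes}
    kept = []
    for n in nodes:
        if n.scale is None:
            continue  # analogous to pruning non-float tensors
        parent_scale = scales.get(n.parent) if n.parent else None
        if parent_scale is not None and math.isclose(n.scale, parent_scale, rel_tol=rel_tol):
            continue  # analogous to pruning same-scale tensors (e.g. a reshape)
        kept.append(n.name)
    return kept


graph = [
    TrackedNode("x", None, 1.0),
    TrackedNode("idx", None, None),
    TrackedNode("reshape", "x", 1.0),
    TrackedNode("matmul", "reshape", 8.0),
    TrackedNode("scaled", "matmul", 1.0),
]
print(prune(graph))  # ['x', 'matmul', 'scaled']
```

Only the tensors where the scale actually changes survive. This is the kind of reduction that makes a plotted graph easier to read.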

Parameters:

module (M) – the input module to be tracked.

Returns:

a new version of the input module which tracks tensor metrics when used.

Return type:

M