3.6.7. unit_scaling.transforms.track_scales
- unit_scaling.transforms.track_scales(module: M) → M
Returns a version of the input module which tracks internal tensor metrics.
When the forward() and backward() methods of the resulting module are called, internally various metrics (such as scale) are recorded for each intermediate tensor used. These can be accessed using an additional method module.scales_graph() which is added to the module. The returned object is an instance of torch.fx.Graph, where each node representing a floating-point tensor now has a node.meta["metrics"] object of type unit_scaling.transforms.Metrics associated with it. Note that if forward() or backward() are not called, tensor metrics will not be available.

The unit scaling library also provides a method to visualise FX Graphs, via the unit_scaling.analysis.plot() function. This is intended to be used as follows:

    from unit_scaling.transforms import track_scales
    from unit_scaling.analysis import plot

    inpt = ...
    model = ...

    model = track_scales(model)
    loss = model(inpt)
    loss.backward()

    graph = model.scales_graph()
    plot(graph)
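Besides plotting, the recorded metrics can be read directly from the graph's nodes. The following is a minimal sketch of that pattern, not part of the library: collect_metrics is a hypothetical helper, and the stand-in graph uses plain dicts with a made-up fwd_scale field in place of real Metrics objects so that the example runs without a model. It relies only on the graph exposing .nodes, and on each node having .name and .meta, as torch.fx.Graph does.

```python
from types import SimpleNamespace


def collect_metrics(graph):
    """Map node name -> metrics object for every node that carries one.

    `graph` is expected to behave like a torch.fx.Graph: it exposes `.nodes`,
    and each node has a `.name` and a `.meta` dict.
    """
    return {
        node.name: node.meta["metrics"]
        for node in graph.nodes
        if "metrics" in node.meta
    }


# Stand-in graph for illustration only; in real use, pass model.scales_graph().
# The dicts and the `fwd_scale` key are assumptions, not the real Metrics type.
fake_graph = SimpleNamespace(
    nodes=[
        SimpleNamespace(name="x", meta={"metrics": {"fwd_scale": 1.0}}),
        SimpleNamespace(name="idx", meta={}),  # e.g. a non-float tensor
        SimpleNamespace(name="linear", meta={"metrics": {"fwd_scale": 0.58}}),
    ]
)

for name, metrics in collect_metrics(fake_graph).items():
    print(name, metrics)
```

Nodes without a "metrics" entry (for example those that do not represent floating-point tensors) are skipped rather than raising a KeyError.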
The inpt tensor(s) provided to any model transformed by track_scales() will automatically have inpt.requires_grad_() set (this is required for full scale tracking in the backward pass), so the user need not specify this.

track_scales() can be used in conjunction with other graph transforms provided, but should always be the final transform in a chain. E.g.

    from unit_scaling.transforms import simulate_fp8, track_scales, unit_scale

    model = track_scales(unit_scale(simulate_fp8(model)))
The full FX graph returned by this transform may contain more information than the user requires for the sake of analysis. For this reason the functions unit_scaling.transforms.prune_non_float_tensors() and unit_scaling.transforms.prune_same_scale_tensors() are provided, which in practice tend to limit the graph to only key tensors.

- Parameters:
module (M) – the input module to be tracked.
- Returns:
a new version of the input module which tracks tensor metrics when used.
- Return type:
M