3.2.3. unit_scaling.analysis.plot
- unit_scaling.analysis.plot(g: Graph, title: str = '', metric: str = 'mean_abs', prune_same_scale: bool = True, show_arrows: bool = True, show_error_bars: bool = True, show_zero_tensors: bool = False, xmin: float | None = None, xmax: float | None = None) → Axes
Generate a matplotlib plot visualising the scales in the forward (and optionally backward) pass of all tensors in an FX graph.
The input graph is intended to have been generated by applying unit_scaling.transforms.track_scales() to an arbitrary torch.nn.Module, running a forward (and possibly backward) pass, then calling the module.scales_graph() function:

```python
from unit_scaling.transforms import track_scales
from unit_scaling.analysis import plot

inpt = ...
model = ...

model = track_scales(model)
loss = model(inpt)
loss.backward()

graph = model.scales_graph()
plot(graph)
```

Operations that don’t output floating-point tensors are automatically pruned from the visualised graph, as they are deemed unlikely to be relevant from the perspective of model numerics.
Faint coloured horizontal lines for each row represent error bars indicating the maximum and minimum values seen in each tensor during tracking.
- Parameters:
g (Graph) – the graph to visualise.
title (str, optional) – title for the generated plot. Defaults to “”.
metric (str, optional) – the metric to show on the x-axis. Can be any of: (“mean_abs”, “abs_mean”, “std”, “abs_max”, “abs_min”, “numel”). Defaults to “mean_abs”.
prune_same_scale (bool, optional) – prune operations that don’t change the scale of their input tensors. In practice this means that views / reshapes are not shown, making the resulting visualisation clearer. Defaults to True.
show_arrows (bool, optional) – show arrows between operations, denoting dependencies. Defaults to True.
show_error_bars (bool, optional) – show max/min error bars. Defaults to True.
xmin (Optional[float], optional) – the minimum x-value to display. Defaults to None.
xmax (Optional[float], optional) – the maximum x-value to display. Defaults to None.
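
The metric names accepted above are not defined on this page. The following sketch shows one plausible reading of each name, computed over a flat list of floats using only the standard library. These definitions are assumptions inferred from the names (in particular, whether "std" is a population or sample standard deviation is not stated here), and scale_metrics is a hypothetical helper, not part of unit_scaling.

```python
import statistics


def scale_metrics(values):
    """Illustrative per-tensor statistics over a flat list of floats.

    Assumed meanings (inferred from the metric names, not confirmed):
      mean_abs: mean of the absolute values
      abs_mean: absolute value of the mean
      std:      population standard deviation (an assumption)
      abs_max:  largest absolute value
      abs_min:  smallest absolute value
      numel:    number of elements
    """
    abs_values = [abs(v) for v in values]
    return {
        "mean_abs": statistics.fmean(abs_values),
        "abs_mean": abs(statistics.fmean(values)),
        "std": statistics.pstdev(values),
        "abs_max": max(abs_values),
        "abs_min": min(abs_values),
        "numel": len(values),
    }


metrics = scale_metrics([-2.0, -1.0, 1.0, 2.0])
```

Under this reading, "mean_abs" and "abs_mean" differ sharply for zero-centred data: the example values have a mean absolute value of 1.5 but a mean of 0.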
- Returns:
the axes representing the generated plot.
- Return type:
matplotlib.axes.Axes
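
Because the function returns a matplotlib.axes.Axes, the plot can be post-processed or saved through the standard matplotlib API. A minimal sketch, assuming the returned axes belongs to an ordinary matplotlib figure; save_scale_plot and the file name are hypothetical, not part of unit_scaling.

```python
def save_scale_plot(ax, path, dpi=150):
    """Save the figure that owns `ax` to `path` and return the path.

    Uses only Axes.get_figure() and Figure.savefig(), both standard
    matplotlib methods.
    """
    fig = ax.get_figure()
    fig.savefig(path, dpi=dpi, bbox_inches="tight")
    return path


# Intended usage (requires unit_scaling, torch and matplotlib):
#   ax = plot(graph, title="Scales", metric="std")
#   save_scale_plot(ax, "scales.png")
```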