3.6.3. unit_scaling.transforms.prune_same_scale_tensors
- unit_scaling.transforms.prune_same_scale_tensors(graph: Graph, rtol: float = 1.52587890625e-05) → Graph
Given an FX Graph, prunes all nodes with the same scale as the previous node.
This is intended to remove non-informative nodes from the graph such as reshapes. Nodes with multiple floating-point tensors as inputs are never pruned.
Certain operations (such as slices) may change the scale slightly but negligibly. For these cases, the rtol parameter specifies the relative change in scale that is deemed significant.
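The pruning rule described above can be sketched in plain Python. This is not the library's implementation. ScaleNode, same_scale and prune_chain are hypothetical names. The sketch assumes the nodes are in topological order and that the relative change is measured against the input node's scale as |new - old| <= rtol * |old|. It shows the two rules from the text: a node whose scale matches its single floating-point input is dropped (and its consumers are rewired), and a node with several floating-point inputs is always kept.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ScaleNode:
    """Hypothetical stand-in for an FX node annotated with its tracked scale."""
    name: str
    scale: float
    # Names of the floating-point tensor inputs feeding this node.
    float_inputs: List[str] = field(default_factory=list)


def same_scale(new: float, old: float, rtol: float) -> bool:
    # Assumption: the change is measured relative to the input (old) scale.
    return abs(new - old) <= rtol * abs(old)


def prune_chain(nodes: List[ScaleNode], rtol: float = 2 ** -16) -> List[ScaleNode]:
    """Drop nodes whose scale matches their single floating-point input.

    Assumes `nodes` is in topological order. Nodes with zero or several
    floating-point inputs are always kept.
    """
    by_name: Dict[str, ScaleNode] = {n.name: n for n in nodes}
    alias: Dict[str, str] = {}  # pruned node name -> surviving node it forwards
    kept: List[ScaleNode] = []
    for node in nodes:
        # Rewire inputs that point at already-pruned nodes.
        inputs = [alias.get(i, i) for i in node.float_inputs]
        if len(node.float_inputs) == 1:
            prev = by_name[node.float_inputs[0]]
            if same_scale(node.scale, prev.scale, rtol):
                alias[node.name] = inputs[0]  # non-informative, e.g. a reshape
                continue
        kept.append(ScaleNode(node.name, node.scale, inputs))
    return kept


nodes = [
    ScaleNode("x", 1.0),
    ScaleNode("reshape", 1.0, ["x"]),
    ScaleNode("slice", 1.0 + 2 ** -20, ["reshape"]),  # tiny change, below rtol
    ScaleNode("gelu", 0.6, ["slice"]),
    ScaleNode("add", 0.9, ["gelu", "x"]),  # two float inputs: never pruned
]
pruned = prune_chain(nodes)
print([n.name for n in pruned])  # ['x', 'gelu', 'add']
print(pruned[1].float_inputs)    # ['x']
```

Pruning reshape and slice rewires gelu to read directly from x, so the remaining nodes are only those that change the scale (or that combine several floating-point inputs).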
The supplied graph must have been generated via the module.scales_graph() method, called on a module with unit_scaling.transforms.track_scales() applied. For example:

    from unit_scaling.transforms import track_scales, prune_same_scale_tensors

    inpt = ...   # input(s) to the model
    model = ...  # the model to analyse

    model = track_scales(model)  # wrap the model so that tensor scales are tracked
    loss = model(inpt)           # run the forward pass...
    loss.backward()              # ...and the backward pass

    graph = model.scales_graph()
    pruned_graph = prune_same_scale_tensors(graph)
- Parameters:
graph (Graph) – the FX graph to be pruned.
rtol (float, optional) – the relative tolerance for “same scale”. Defaults to 2**-16.
- Returns:
the pruned graph with nodes that don’t change their input scale removed.
- Return type:
Graph
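As a quick check on the default tolerance: 2**-16 is exactly the value 1.52587890625e-05 printed in the signature, because powers of two are exactly representable in binary floating point. The last two lines assume, as in the description of rtol, that the change is measured relative to the input scale. That is an assumption about the comparison, not a statement of the library's code.

```python
# Default relative tolerance, written as in the parameter description.
rtol = 2 ** -16

# Identical to the default shown in the signature.
assert rtol == 1.52587890625e-05

# As a percentage: roughly a 0.0015% change in scale.
print(f"{rtol:.4%}")  # 0.0015%

# Assuming a change relative to the input scale, a node whose input scale
# is 0.5 may differ by at most 0.5 * rtol == 2**-17 and still count as "same".
max_abs_change = 0.5 * rtol
assert max_abs_change == 2 ** -17
```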