3.6.3. unit_scaling.transforms.prune_same_scale_tensors

unit_scaling.transforms.prune_same_scale_tensors(graph: Graph, rtol: float = 1.52587890625e-05) → Graph

Given an FX Graph, prunes all nodes with the same scale as the previous node.

This is intended to remove non-informative nodes from the graph such as reshapes. Nodes with multiple floating-point tensors as inputs are never pruned.

Certain operations (such as slices) may change the scale slightly but negligibly. For such cases, the rtol parameter specifies the relative change in scale that is deemed significant.
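The pruning rule described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: `ToyNode`, `same_scale` and `prune_same_scale` are hypothetical names, the graph is modelled as a topologically ordered list of nodes rather than a torch.fx Graph, and "same scale" is assumed to mean `math.isclose(..., rel_tol=rtol)`. A node is dropped when it has exactly one floating-point input whose scale matches its own within `rtol`, and any later node that consumed it is rewired to the surviving input.

```python
import math
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ToyNode:
    """Hypothetical stand-in for an FX node annotated with its tensor's scale."""
    name: str
    scale: Optional[float]  # None if the node's output is not a floating-point tensor
    inputs: List[str] = field(default_factory=list)  # names of floating-point input nodes


def same_scale(a: float, b: float, rtol: float) -> bool:
    # Assumed definition of "same scale": relative difference within rtol.
    return math.isclose(a, b, rel_tol=rtol)


def prune_same_scale(nodes: List[ToyNode], rtol: float = 2**-16) -> List[ToyNode]:
    """Drop single-input nodes whose scale matches their input's scale."""
    kept_by_name: Dict[str, ToyNode] = {}
    forward_to: Dict[str, str] = {}  # pruned node name -> surviving node it stands in for
    kept: List[ToyNode] = []
    for node in nodes:
        # Rewire inputs that point at nodes pruned earlier in the walk.
        inputs = [forward_to.get(name, name) for name in node.inputs]
        node = ToyNode(node.name, node.scale, inputs)
        # Nodes with several floating-point inputs (or none) are never pruned.
        if len(inputs) == 1 and node.scale is not None:
            parent = kept_by_name[inputs[0]]
            if parent.scale is not None and same_scale(node.scale, parent.scale, rtol):
                forward_to[node.name] = parent.name
                continue
        kept_by_name[node.name] = node
        kept.append(node)
    return kept


# Usage: a reshape and a slice that barely change the scale are both removed.
nodes = [
    ToyNode("x", 1.0),
    ToyNode("w", 0.5),
    ToyNode("reshape", 1.0, ["x"]),
    ToyNode("slice", 1.000001, ["reshape"]),
    ToyNode("matmul", 0.7, ["slice", "w"]),
    ToyNode("relu", 0.6, ["matmul"]),
]
result = prune_same_scale(nodes)
print([n.name for n in result])  # ['x', 'w', 'matmul', 'relu']
print(result[2].inputs)          # ['x', 'w']
```

Keeping zero-input nodes (graph inputs such as `x` and `w` here) is a design choice of this sketch: they have no preceding scale to compare against, so they always survive.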

The supplied graph must have been generated via the module.scales_graph() method, called on a module to which unit_scaling.transforms.track_scales() has been applied. For example:

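```python
from unit_scaling.transforms import track_scales, prune_same_scale_tensors

inpt = ...   # input batch for your model
model = ...  # a torch.nn.Module whose forward pass returns a scalar loss

# Wrap the model so that tensor scales are tracked, then run forward and backward.
model = track_scales(model)
loss = model(inpt)
loss.backward()

# Extract the tracked scales as an FX Graph, then remove same-scale nodes.
graph = model.scales_graph()
pruned_graph = prune_same_scale_tensors(graph)
```
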
Parameters:
  • graph (Graph) – the FX graph to be pruned.

  • rtol (float, optional) – the relative tolerance for “same scale”. Defaults to 2**-16.

Returns:

the pruned graph with nodes that don’t change their input scale removed.

Return type:

Graph
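The default rtol printed in the signature, 1.52587890625e-05, is exactly 2**-16 (1/65536). The short check below confirms this and shows roughly what the tolerance means in practice. It assumes, for illustration only, that "same scale" is tested as a relative difference via Python's `math.isclose`; the library's exact comparison may differ.

```python
import math

# The default printed in the signature is exactly 2**-16.
assert 2**-16 == 1.52587890625e-05

RTOL = 2**-16

# Assumption: "same scale" means a relative difference of at most RTOL.
print(math.isclose(1.00001, 1.0, rel_tol=RTOL))  # relative change 1e-5: True
print(math.isclose(1.0001, 1.0, rel_tol=RTOL))   # relative change 1e-4: False
```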