3.6.2. unit_scaling.transforms.prune_non_float_tensors

unit_scaling.transforms.prune_non_float_tensors(graph: Graph) → Graph

Given an FX Graph, prunes all nodes which don’t output floating-point tensors.

The supplied graph must have been generated by calling the module.scales_graph() method on a module to which unit_scaling.transforms.track_scales() has been applied. This is necessary because the scale-tracking process is what identifies which tensors hold floating-point values. For example:

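from unit_scaling.transforms import track_scales, prune_non_float_tensors

inpt = ...   # placeholder: the input to your model
model = ...  # placeholder: your torch.nn.Module

# Wrap the model so that tensor scales are tracked while it executes
model = track_scales(model)

# Run a forward and backward pass so that scales are tracked
loss = model(inpt)
loss.backward()

# Retrieve the tracked FX graph, then drop nodes without floating-point outputs
graph = model.scales_graph()
pruned_graph = prune_non_float_tensors(graph)
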
Parameters:

graph (Graph) – the FX graph to be pruned.

Returns:

the pruned graph containing only nodes outputting floating-point tensors.

Return type:

Graph
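To make the idea of pruning concrete, the following is a minimal, framework-free sketch. It does not use torch.fx or the library's internals: ToyNode, FLOAT_DTYPES and prune_non_float are hypothetical names introduced for illustration, and the rewiring policy (connecting each kept node to its nearest floating-point ancestors so the pruned graph stays connected) is an assumption, not a description of how prune_non_float_tensors is actually implemented.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    """Hypothetical stand-in for an FX node: a name, its output dtype and its input node names."""
    name: str
    dtype: str  # e.g. "float32", "int64", "bool"
    inputs: list[str] = field(default_factory=list)


# Assumption: these dtype names count as floating-point outputs
FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


def prune_non_float(nodes: list[ToyNode]) -> list[ToyNode]:
    """Keep only float-output nodes, rewiring each one to its nearest float ancestors."""
    by_name = {n.name: n for n in nodes}

    def float_ancestors(name: str) -> list[str]:
        # Walk upstream through non-float nodes until float-producing nodes are reached
        node = by_name[name]
        if node.dtype in FLOAT_DTYPES:
            return [name]
        result: list[str] = []
        for parent in node.inputs:
            for anc in float_ancestors(parent):
                if anc not in result:
                    result.append(anc)
        return result

    pruned: list[ToyNode] = []
    for node in nodes:  # assumed to be in topological order, as nodes in an FX graph are
        if node.dtype not in FLOAT_DTYPES:
            continue  # drop nodes that do not output floating-point values
        new_inputs: list[str] = []
        for parent in node.inputs:
            for anc in float_ancestors(parent):
                if anc not in new_inputs:
                    new_inputs.append(anc)
        pruned.append(ToyNode(node.name, node.dtype, new_inputs))
    return pruned


# Usage: integer and boolean intermediates are removed, float nodes are reconnected
graph = [
    ToyNode("x", "float32"),
    ToyNode("idx", "int64"),
    ToyNode("mask", "bool", ["idx"]),
    ToyNode("emb", "float32", ["x", "mask"]),
    ToyNode("size", "int64", ["emb"]),
    ToyNode("out", "float32", ["emb", "size"]),
]
pruned = prune_non_float(graph)
for n in pruned:
    print(n.name, n.inputs)
```

Here "mask" has no floating-point ancestors, so "emb" keeps only its dependency on "x", while "out" reaches "emb" both directly and through "size", and that ancestor is recorded once.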