3.6.2. unit_scaling.transforms.prune_non_float_tensors
- unit_scaling.transforms.prune_non_float_tensors(graph: Graph) → Graph
Given an FX Graph, prunes all nodes which don’t output floating-point tensors.
The supplied graph must have been generated via the module.scales_graph() method, called on a module with unit_scaling.transforms.track_scales() applied. This is necessary because the scale-tracking process is what identifies which tensors have floating-point values. For example:

    from unit_scaling.transforms import track_scales, prune_non_float_tensors

    inpt = ...
    model = ...

    model = track_scales(model)   # enable scale tracking on the module
    loss = model(inpt)            # forward pass
    loss.backward()               # backward pass
    graph = model.scales_graph()  # FX graph annotated with the tracked scales
    pruned_graph = prune_non_float_tensors(graph)
- Parameters:
graph (Graph) – the FX graph to be pruned.
- Returns:
the pruned graph containing only nodes outputting floating-point tensors.
- Return type:
Graph
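To make the pruning idea concrete, here is a small pure-Python sketch. It does not use torch.fx or unit_scaling: ToyNode and prune_non_float are hypothetical names, the outputs_float flag stands in for the information that track_scales() records, and reconnecting each surviving node to its nearest floating-point ancestors (so the pruned graph stays connected) is an assumption of this sketch, not necessarily what the library does.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyNode:
    # Hypothetical stand-in for an FX node: its name, the names of the nodes
    # it reads from, and whether it outputs a floating-point tensor.
    name: str
    inputs: List[str] = field(default_factory=list)
    outputs_float: bool = True


def prune_non_float(nodes: List[ToyNode]) -> List[ToyNode]:
    """Keep only float-producing nodes (hypothetical sketch).

    Each kept node is rewired to its nearest float-producing ancestors,
    looking through any pruned nodes in between.
    """
    by_name: Dict[str, ToyNode] = {n.name: n for n in nodes}

    def float_ancestors(name: str) -> List[str]:
        node = by_name[name]
        if node.outputs_float:
            return [name]
        # Non-float node: look through it to its own inputs.
        found: List[str] = []
        for inp in node.inputs:
            for anc in float_ancestors(inp):
                if anc not in found:
                    found.append(anc)
        return found

    pruned: List[ToyNode] = []
    for n in nodes:
        if not n.outputs_float:
            continue  # drop nodes that don't output floating-point tensors
        new_inputs: List[str] = []
        for inp in n.inputs:
            for anc in float_ancestors(inp):
                if anc not in new_inputs:
                    new_inputs.append(anc)
        pruned.append(ToyNode(n.name, new_inputs, True))
    return pruned


nodes = [
    ToyNode("x"),
    ToyNode("size", ["x"], outputs_float=False),  # e.g. x.size(): integers
    ToyNode("view", ["x", "size"]),
    ToyNode("out", ["view"]),
]
for node in prune_non_float(nodes):
    print(node.name, node.inputs)
# x []
# view ['x']
# out ['view']
```

Here size models a call such as x.size(), which yields integers rather than a floating-point tensor, so it is dropped and view is wired directly to x.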