3.6. unit_scaling.transforms

Useful TorchDynamo transforms of modules, for numerics analysis and unit scaling.

Functions

compile(module)

A transform that applies torch.compile to a module.

prune_non_float_tensors(graph)

Given an FX Graph, prunes all nodes which don't output floating-point tensors.

prune_same_scale_tensors(graph[, rtol])

Given an FX Graph, prunes all nodes with the same scale as the previous node.
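The role of a relative tolerance such as rtol can be illustrated without an FX graph. The following is a minimal sketch, not the library's implementation: it works on a plain list of (name, scale) pairs, and the function name `prune_same_scale`, the default tolerance, and the choice to compare each entry with the immediately preceding one are all assumptions made for illustration.

```python
import math


def prune_same_scale(entries, rtol=1e-2):
    """Drop entries whose scale matches the preceding entry's scale.

    `entries` is a list of (name, scale) pairs standing in for graph nodes.
    Hypothetical helper: the default rtol and the comparison rule are
    assumptions, not taken from unit_scaling.
    """
    kept = []
    previous_scale = None
    for name, scale in entries:
        # Keep the first entry, and any entry whose scale differs from the
        # preceding one by more than the relative tolerance.
        if previous_scale is None or not math.isclose(
            scale, previous_scale, rel_tol=rtol
        ):
            kept.append((name, scale))
        previous_scale = scale
    return kept


nodes = [("x", 1.0), ("relu", 1.0), ("linear", 0.5), ("dropout", 0.5), ("out", 2.0)]
pruned = prune_same_scale(nodes)
```

Here "relu" and "dropout" are dropped because each has the same scale as the entry before it, which leaves only the points where the scale changes.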

prune_selected_nodes(graph, targets)

Given an FX Graph, prunes all nodes whose function is in the set of targets.

simulate_format(module, fwd_format, bwd_format)

[Experimental] Given a module, uses TorchDynamo to return a new module which simulates the effect of using the supplied formats for matmuls.
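To make "simulating a format" concrete, here is a pure-Python sketch of the general idea: values are rounded to the nearest number representable with a given count of exponent and mantissa bits, while the arithmetic itself stays in higher precision. This is not the library's code. The helpers `round_to_format` and `simulated_dot` are hypothetical, and the sketch assumes IEEE-style conventions (top exponent code reserved, saturation instead of overflow to infinity); real FP8 variants differ in such details.

```python
import math


def round_to_format(x: float, exponent_bits: int, mantissa_bits: int) -> float:
    """Round x to the nearest value in a small IEEE-style float format.

    Hypothetical helper: assumes an IEEE-style bias, subnormals, and
    saturation at the largest finite value.
    """
    if x == 0.0 or math.isnan(x):
        return x
    bias = 2 ** (exponent_bits - 1) - 1
    max_exp = (2 ** exponent_bits - 2) - bias  # largest normal exponent
    min_exp = 1 - bias                         # smallest normal exponent
    max_value = (2.0 - 2.0 ** -mantissa_bits) * 2.0 ** max_exp
    if math.isinf(x):
        return math.copysign(max_value, x)
    # frexp gives abs(x) = m * 2**e with m in [0.5, 1), so the exponent of
    # the leading bit is e - 1; clamp it to min_exp for subnormal values.
    _, e = math.frexp(abs(x))
    exp = max(e - 1, min_exp)
    step = 2.0 ** (exp - mantissa_bits)  # spacing of representable values
    rounded = min(round(abs(x) / step) * step, max_value)
    return math.copysign(rounded, x)


def simulated_dot(a, b, exponent_bits, mantissa_bits):
    """Dot product with inputs rounded to the format, accumulated in float64."""
    qa = [round_to_format(v, exponent_bits, mantissa_bits) for v in a]
    qb = [round_to_format(v, exponent_bits, mantissa_bits) for v in b]
    return sum(x * y for x, y in zip(qa, qb))


# 5 exponent bits and 2 mantissa bits: values near 1.0 are spaced 0.25 apart.
example = simulated_dot([1.1, 1.2], [1.0, 2.0], exponent_bits=5, mantissa_bits=2)
```

Rounding only the inputs while accumulating in higher precision is one common way to study the numerical effect of a low-precision matmul on hardware that does not support the format natively.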

simulate_fp8(module)

[Experimental] Given a module, uses TorchDynamo to return a new module which simulates the effect of running matmuls in FP8.

track_scales(module)

Returns a version of the input module which tracks internal tensor metrics.

unit_scale(module[, replace])

[Experimental] Returns a unit-scaled version of the input module.
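As background, the idea behind unit scaling can be shown with a small numerical experiment. This is a sketch of the general principle only, not of what this transform does internally: a dot product of unit-variance inputs and unit-variance weights is multiplied by 1/sqrt(fan_in) so that the output also has roughly unit variance. The name `scaled_dot` and the sizes used are assumptions chosen for illustration.

```python
import math
import random
import statistics


def scaled_dot(x, w):
    """Dot product rescaled by 1/sqrt(fan_in) to keep the output near unit scale."""
    fan_in = len(x)
    return sum(xi * wi for xi, wi in zip(x, w)) / math.sqrt(fan_in)


rng = random.Random(0)  # fixed seed so the experiment is repeatable
fan_in = 256
outputs = []
for _ in range(2000):
    x = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]
    w = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]
    outputs.append(scaled_dot(x, w))

# Without the 1/sqrt(fan_in) factor the standard deviation would be near 16.
output_std = statistics.pstdev(outputs)
```

Keeping values close to unit scale is what makes them easier to represent in low-precision formats, which have a narrow range.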

Classes

Metrics(fwd_tensor)

A set of metrics representing useful information about a tensor, in the forward and backward pass.
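The kind of summary such a class might hold can be sketched in plain Python. Everything here is hypothetical: the class name `TensorStats`, its fields, and the `summarise` helper are illustrative assumptions rather than the attributes of Metrics, and a flat list of floats stands in for a tensor.

```python
import statistics
from dataclasses import dataclass
from typing import Sequence


@dataclass
class TensorStats:
    """Hypothetical summary of one tensor's values (not the Metrics API)."""

    mean_abs: float  # mean of absolute values
    abs_max: float   # largest magnitude
    abs_min: float   # smallest magnitude
    std: float       # population standard deviation
    numel: int       # number of elements


def summarise(values: Sequence[float]) -> TensorStats:
    """Compute scale-related statistics for a non-empty sequence of floats."""
    magnitudes = [abs(v) for v in values]
    return TensorStats(
        mean_abs=statistics.fmean(magnitudes),
        abs_max=max(magnitudes),
        abs_min=min(magnitudes),
        std=statistics.pstdev(values),
        numel=len(values),
    )


stats = summarise([1.0, -2.0, 3.0, -4.0])
```

Applying the same summary to a tensor's gradient values would give the matching backward-pass statistics.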