3.6. unit_scaling.transforms
TorchDynamo-based transforms of modules, for analysing numerics and applying unit scaling.
Functions
A transform that applies torch.compile to a module.
Given an FX Graph, prunes all nodes which don't output floating-point tensors.
Given an FX Graph, prunes all nodes with the same scale as the previous node.
Given an FX Graph, prunes all nodes with functions in the set of target nodes.
[Experimental] Given a module, uses TorchDynamo to return a new module which simulates the effect of using the supplied formats for matmuls.
[Experimental] Given a module, uses TorchDynamo to return a new module which simulates the effect of running matmuls in FP8.
Returns a version of the input module which tracks internal tensor metrics.
[Experimental] Returns a unit-scaled version of the input model.
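To make the idea of "simulating" a low-precision matmul concrete, here is a minimal sketch in plain Python (no torch), so that the rounding step is explicit. It is not the library's implementation: the names `round_to_float_format` and `simulated_matmul` are hypothetical, the default parameters assume an E4M3-style format (4 exponent bits, 3 mantissa bits, saturating at 448), and accumulating in full precision is also an assumption.

```python
import math


def round_to_float_format(x, exp_bits=4, man_bits=3, max_value=448.0):
    """Round x to the nearest value of a small float format, saturating.

    Hypothetical helper; the defaults assume an E4M3-style format.
    """
    if x == 0.0 or math.isnan(x):
        return x
    if math.isinf(x):
        return math.copysign(max_value, x)
    bias = 2 ** (exp_bits - 1) - 1
    min_exp = 1 - bias                  # exponent of the smallest normal number
    _, e = math.frexp(abs(x))           # abs(x) = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, min_exp)           # clamp so tiny values use subnormal spacing
    step = 2.0 ** (exp - man_bits)      # gap between representable values near x
    q = round(abs(x) / step) * step     # Python's round() is round-half-to-even
    return math.copysign(min(q, max_value), x)


def simulated_matmul(a, b, **fmt):
    """Quantise both operands, then multiply with full-precision accumulation."""
    qa = [[round_to_float_format(v, **fmt) for v in row] for row in a]
    qb = [[round_to_float_format(v, **fmt) for v in row] for row in b]
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*qb)]
            for row in qa]


print(round_to_float_format(1.1))   # 1.125
print(simulated_matmul([[1.1, 0.3]], [[1.0], [2.0]]))
```

The computation stays in ordinary floats and only the operand values are snapped to the coarser grid, which is what lets a simulation of this kind run on hardware without native FP8 support.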
Classes
A set of metrics representing useful information about a tensor, in the forward and backward pass.
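As a rough illustration of what per-tensor metrics for the forward and backward pass might look like, the sketch below computes a few summary statistics over a flat list of values. The class and field names (`TensorStats`, `FwdBwdStats`, `mean_abs`, `std`, `abs_max`, `numel`) are assumptions chosen for illustration and are not taken from the library's class.

```python
import math
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class TensorStats:
    """Hypothetical summary of one tensor's values."""
    mean_abs: float
    std: float
    abs_max: float
    numel: int


@dataclass
class FwdBwdStats:
    """Hypothetical container: one summary per pass, filled in when available."""
    fwd: Optional[TensorStats] = None
    bwd: Optional[TensorStats] = None


def summarize(values: Sequence[float]) -> TensorStats:
    """Compute summary statistics (population std) for a non-empty sequence."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return TensorStats(
        mean_abs=sum(abs(v) for v in values) / n,
        std=math.sqrt(var),
        abs_max=max(abs(v) for v in values),
        numel=n,
    )


stats = FwdBwdStats(fwd=summarize([1.0, -1.0, 1.0, -1.0]))
print(stats.fwd.std)  # 1.0
```

A standard deviation close to 1 in both passes is the condition that unit scaling aims for, which is why a statistic of this kind is worth tracking per tensor.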