3.6.1. unit_scaling.transforms.compile
- unit_scaling.transforms.compile(module: M) → M
A transform that applies torch.compile to a module.
Note that this is slightly different to calling torch.compile(module).

The current version of torch.compile() doesn’t allow for nested transforms, so the following is not supported:

    import torch
    from unit_scaling.transforms import unit_scale

    module = torch.compile(unit_scale(module))

unit_scaling.transforms addresses this by introducing a range of composable transforms. This works by moving the call to torch._dynamo.optimize() into the forward() method of the module and executing it only on the first call to the module, or when a new transform is applied; the optimised call is cached thereafter.

The unit_scaling.transforms.compile() function is one such composable transform. This means that the following can be written:

    from unit_scaling.transforms import compile, unit_scale

    module = compile(unit_scale(module))

This will successfully combine the two transforms in a single module. Note that the call to compile must still come last, as its underlying backend returns a standard torch.nn.Module rather than a torch.fx.GraphModule.

Currently unit_scaling.transforms.compile() does not support the ops needed for the unit_scaling.transforms.simulate_fp8() transform, though this may change in future PyTorch releases.

Modules implemented manually with unit-scaled layers (i.e. without the global unit_scale(module) transform) can still use torch.compile() in the standard way.
- Parameters:
module (M) – the module to be compiled.
- Returns:
the compiled module.
- Return type:
M
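The deferred compilation described above (optimise on the first call, cache the result, and recompile when a new transform is applied) can be illustrated with a small pure-Python sketch. This is not the library's implementation and does not use PyTorch: `LazyModule`, `apply_transform`, `fake_optimize` and `compile_log` are hypothetical names invented for illustration, with `fake_optimize` standing in for a backend call such as torch._dynamo.optimize().

```python
from typing import Any, Callable, List, Optional, Tuple

# Records every (pretend) compilation so the caching behaviour is observable.
compile_log: List[str] = []


def fake_optimize(fn: Callable[..., Any], transforms: Tuple[str, ...]) -> Callable[..., Any]:
    """Hypothetical stand-in for a compiler backend: logs the call, returns a wrapper."""
    compile_log.append("+".join(transforms))

    def optimised(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)

    return optimised


class LazyModule:
    """Toy module whose call is optimised lazily and then cached."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self._fn = fn
        self._transforms: List[str] = []
        self._compiled: Optional[Callable[..., Any]] = None

    def apply_transform(self, name: str) -> "LazyModule":
        # Applying a new transform invalidates the cached optimised call.
        self._transforms.append(name)
        self._compiled = None
        return self

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Optimise only on the first call after the transform stack changed.
        if self._compiled is None:
            self._compiled = fake_optimize(self._fn, tuple(self._transforms))
        return self._compiled(*args, **kwargs)


module = LazyModule(lambda x: 2 * x).apply_transform("unit_scale")
results = [module(1), module(2)]   # first call optimises, second reuses the cache
module.apply_transform("compile")  # new transform: cache is dropped
results.append(module(3))          # triggers one further optimisation
```

In this sketch the optimisation step runs twice in total: once on the first call and once after the second transform is applied. Deferring the work to call time is what lets transforms be stacked before anything is compiled.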