3.6.1. unit_scaling.transforms.compile
- unit_scaling.transforms.compile(module: M) → M
A transform that applies torch.compile to a module.
Note that this is slightly different to calling torch.compile(module).

The current version of torch.compile() doesn’t allow for nested transforms, so the following is not supported:

    import torch
    from unit_scaling.transforms import unit_scale

    module = torch.compile(unit_scale(module))

unit_scaling.transforms addresses this by introducing a range of composable transforms. This works by moving the call to torch._dynamo.optimize() within the forward() method of the module and only executing it on the first call to the module, or if a new transform is applied, the optimised call being cached thereafter.

The unit_scaling.transforms.compile() function is one such composable transform. This means that the following can be written:

    from unit_scaling.transforms import compile, unit_scale

    module = compile(unit_scale(module))

This will successfully combine the two transforms in a single module. Note that the call to compile must still come last, as its underlying backend returns a standard torch.nn.Module rather than a torch.fx.GraphModule.

Currently unit_scaling.transforms.compile() does not support the ops needed for the unit_scaling.transforms.simulate_fp8() transform, though this may change in future PyTorch releases.

Modules implemented manually with unit-scaled layers (i.e. without the global unit_scale(module) transform) can still use torch.compile() in the standard way.

- Parameters:
module (M) – the module to be compiled.
- Returns:
the compiled module.
- Return type:
M
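
The description above says the optimising call is deferred to the first forward call, cached afterwards, and re-run when a new transform is applied. A minimal pure-Python sketch of that caching pattern follows. It is an illustration only, not the unit_scaling implementation: the class LazyTransformed, its apply() method, the build_count attribute, and the toy transforms double_output and add_one are all hypothetical names, and plain Python callables stand in for modules and compiler backends.

```python
from typing import Callable, List, Optional

# Hypothetical stand-ins: a "module" is a float -> float callable, and a
# "transform" is any function that rewrites such a callable.
Fn = Callable[[float], float]
Transform = Callable[[Fn], Fn]


class LazyTransformed:
    """Hypothetical wrapper that stacks transforms and builds the final call lazily."""

    def __init__(self, fn: Fn) -> None:
        self._fn = fn
        self._transforms: List[Transform] = []
        self._cached: Optional[Fn] = None
        self.build_count = 0  # number of times the optimised call was (re)built

    def apply(self, transform: Transform) -> "LazyTransformed":
        # Adding a transform drops the cache, so the next call rebuilds.
        self._transforms.append(transform)
        self._cached = None
        return self

    def __call__(self, x: float) -> float:
        if self._cached is None:
            # First call, or first call after a new transform: build once, then cache.
            fn = self._fn
            for transform in self._transforms:
                fn = transform(fn)
            self._cached = fn
            self.build_count += 1
        return self._cached(x)


def double_output(fn: Fn) -> Fn:
    # Toy transform: scales the wrapped function's output by 2.
    return lambda x: 2.0 * fn(x)


def add_one(fn: Fn) -> Fn:
    # Toy transform: adds 1 to the wrapped function's output.
    return lambda x: fn(x) + 1.0


module = LazyTransformed(lambda x: x + 3.0)
module.apply(double_output)
first = module(1.0)    # builds, then runs: 2 * (1 + 3)
second = module(2.0)   # reuses the cached call: 2 * (2 + 3)
builds_before = module.build_count
module.apply(add_one)  # new transform, so the cache is dropped
third = module(1.0)    # rebuilds, then runs: 2 * (1 + 3) + 1
```

Because the build happens inside the call rather than at wrap time, transforms can be stacked in this sketch before any expensive work is done, which is the property that makes them composable.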