3.6.1. unit_scaling.transforms.compile

unit_scaling.transforms.compile(module: M) → M

A transform that applies torch.compile to a module.

Note that this is slightly different to calling torch.compile(module).

The current version of torch.compile() doesn’t allow for nested transforms, so the following is not supported:

import torch

from unit_scaling.transforms import unit_scale

# Not supported: torch.compile() cannot be nested around the unit_scale transform.
module = torch.compile(unit_scale(module))

unit_scaling.transforms addresses this by introducing a range of composable transforms. These work by moving the call to torch._dynamo.optimize() inside the forward() method of the module. The call is executed only the first time the module is called, or again if a new transform has been applied since; the resulting optimised call is cached thereafter.
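The caching behaviour described above can be illustrated with a small torch-free sketch. This is not the library's implementation: LazyTransformed, add_transform and build_count are hypothetical names invented for the illustration, and plain Python functions stand in for modules and for the call to torch._dynamo.optimize(). It only shows the general pattern of building the optimised call lazily on the first call and rebuilding it after a new transform is added.

```python
from typing import Callable, List, Optional

Fn = Callable[[float], float]
Transform = Callable[[Fn], Fn]  # hypothetical: rewrites one function into another


class LazyTransformed:
    """Toy stand-in for a module whose forward() applies its transforms lazily."""

    def __init__(self, fn: Fn) -> None:
        self._fn = fn
        self._transforms: List[Transform] = []
        self._cached: Optional[Fn] = None  # the "optimised call", built on demand
        self.build_count = 0  # how many times the (expensive) build step has run

    def add_transform(self, transform: Transform) -> "LazyTransformed":
        self._transforms.append(transform)
        self._cached = None  # a new transform invalidates the cached call
        return self

    def __call__(self, x: float) -> float:
        if self._cached is None:
            # Build step: runs on the first call, or after a new transform.
            fn = self._fn
            for transform in self._transforms:
                fn = transform(fn)
            self._cached = fn
            self.build_count += 1
        return self._cached(x)


def double_output(fn: Fn) -> Fn:
    return lambda x: 2 * fn(x)


def add_one(fn: Fn) -> Fn:
    return lambda x: fn(x) + 1


square = LazyTransformed(lambda x: x * x).add_transform(double_output)
first = square(3.0)   # triggers the build step
second = square(4.0)  # reuses the cached call, no rebuild
square.add_transform(add_one)
third = square(3.0)   # the new transform forces one rebuild
```

Because nothing is built until the first call, transforms can be stacked freely beforehand, which is the property that makes them composable in this sketch.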

The unit_scaling.transforms.compile() function is one such composable transform. This means that the following can be written:

from unit_scaling.transforms import compile, unit_scale

module = compile(unit_scale(module))

This will successfully combine the two transforms in a single module. Note that the call to compile must still come last, as its underlying backend returns a standard torch.nn.Module rather than a torch.fx.GraphModule.

Currently unit_scaling.transforms.compile() does not support the ops needed for the unit_scaling.transforms.simulate_fp8() transform, though this may change in future PyTorch releases.

Modules implemented manually with unit-scaled layers (i.e. without the global unit_scale(module) transform) can still use torch.compile() in the standard way.

Parameters:

module (M) – the module to be compiled.

Returns:

the compiled module.

Return type:

M