3.7.1. unit_scaling.transforms.utils.apply_transform

unit_scaling.transforms.utils.apply_transform(module: M, backend: Callable[[GraphModule, List[Tensor]], Callable[..., Any]], non_recurse_functions: List[Callable[..., Any]] = []) → M

Applies a graph transformation to a module.

The user-supplied backend represents a transformation of a torch.fx.graph_module.GraphModule. apply_transform() uses torch._dynamo.optimize() to apply this transformation to the module, returning a new transformed module.
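For illustration, a minimal backend might look as follows (a hedged sketch; print_graph_backend is a hypothetical name, not part of the library). It receives the traced GraphModule and a list of example inputs, and must return a callable to run in place of the original code:

from typing import Any, Callable, List

import torch
from torch.fx.graph_module import GraphModule

def print_graph_backend(
    gm: GraphModule, example_inputs: List[torch.Tensor]
) -> Callable[..., Any]:
    # Inspect the captured ops, then return the module unchanged
    # (a GraphModule is itself callable).
    gm.graph.print_tabular()
    return gm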

Note that the current version of TorchDynamo (or torch.compile(), which is a wrapper around TorchDynamo) doesn’t support nested transforms, so we implement our own system here, which makes it easy to nest transforms:

module = apply_transform(apply_transform(module, backend_1), backend_2)

However, it should be noted that these transforms are not interoperable with the standard torch.compile() interface.

This nesting system is implemented by moving the call to torch._dynamo.optimize() inside the forward() method of the module. The call is only executed on the first call to the module, or when a new transform is applied; the optimised callable is cached thereafter. This differs from the standard usage of torch._dynamo.optimize(), but enables the convenient nesting functionality shown above.
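To make this concrete, the following is a simplified sketch of the deferred-compilation pattern described above (an assumption about the approach, not the library’s actual implementation):

from typing import Any

import torch
import torch.nn as nn

def _apply_transform_sketch(module: nn.Module, backend) -> nn.Module:
    # Hypothetical helper illustrating the caching scheme; the real
    # apply_transform also handles re-application of new transforms.
    original_forward = module.forward
    cached = None

    def forward(*args: Any, **kwargs: Any) -> Any:
        nonlocal cached
        if cached is None:
            # First call: compile the original forward with the backend.
            cached = torch._dynamo.optimize(backend)(original_forward)
        return cached(*args, **kwargs)

    module.forward = forward
    return module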

Parameters:
  • module (nn.Module) – the module to be transformed.

  • backend (Backend) – the graph transformation to be applied.

  • non_recurse_functions (Iterable[Callable[..., Any]], optional) – functions which the user does not wish to be recursed into. Defaults to list().

Returns:
  the transformed module.

Return type:
  nn.Module
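
Putting this together, a typical usage sketch (assuming the hypothetical print_graph_backend defined above):

import torch
import torch.nn as nn
from unit_scaling.transforms.utils import apply_transform

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
model = apply_transform(model, print_graph_backend)
out = model(torch.randn(2, 4))  # first call triggers tracing + the backend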