3.7.1. unit_scaling.transforms.utils.apply_transform
- unit_scaling.transforms.utils.apply_transform(module: M, backend: Callable[[GraphModule, List[Tensor]], Callable[..., Any]], non_recurse_functions: List[Callable[..., Any]] = []) → M [source]
Applies a graph transformation to a module.
The user-supplied backend represents a transformation of a torch.fx.graph_module.GraphModule. apply_transform() uses torch._dynamo.optimize() to apply this transformation to the module, returning a new transformed module.
Note that the current version of TorchDynamo (or torch.compile(), which is a wrapper around TorchDynamo) doesn't support nested transforms, so we implement our own system here. This makes it easy to nest transforms:
module = apply_transform(apply_transform(module, backend_1), backend_2)
However, it should be noted that these transforms are not interoperable with the standard torch.compile() interface.
This nesting system is implemented by moving the call to torch._dynamo.optimize() inside the forward() method of the module. The call is only executed on the first call to the module, or when a new transform is applied; the optimised call is cached thereafter. This differs from the standard usage of torch._dynamo.optimize(), but enables this convenient nesting functionality.
- Parameters:
module (nn.Module) – the module to be transformed.
backend (Backend) – the graph transformation to be applied.
non_recurse_functions (Iterable[Callable[..., Any]], optional) – functions which the user does not wish to be recursed into. Defaults to list().
- Returns:
the transformed module.
- Return type:
nn.Module
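The lazy-compile-and-cache behaviour described above can be sketched in plain Python. This is a hypothetical illustration, not the library's implementation: the Module class, the backend signature (a callable wrapping a callable), and the _backends/_compiled attributes are all illustrative assumptions, whereas the real backends transform a torch.fx GraphModule via torch._dynamo.optimize().

```python
from typing import Any, Callable, List


class Module:
    """Stand-in for nn.Module (hypothetical, for illustration only)."""

    def forward(self, x: int) -> int:
        return x + 1


def apply_transform(module: Module, backend: Callable[..., Any]) -> Module:
    """Sketch of the nesting mechanism: accumulate backends on the module
    and defer "compilation" to the first forward() call."""
    backends: List[Callable[..., Any]] = getattr(module, "_backends", [])
    module._backends = backends + [backend]
    module._compiled = None  # adding a transform invalidates the cached call

    original_forward = type(module).forward  # the untransformed forward

    def forward(x: int) -> int:
        # Compile lazily on the first call; reuse the cached result after.
        if module._compiled is None:
            fn = lambda x: original_forward(module, x)
            for b in module._backends:
                fn = b(fn)  # apply each accumulated transform in order
            module._compiled = fn
        return module._compiled(x)

    module.forward = forward
    return module


# Two toy "backends", each wrapping the callable it is given.
double = lambda fn: (lambda x: 2 * fn(x))
negate = lambda fn: (lambda x: -fn(x))

# Nested application, mirroring the usage shown above.
m = apply_transform(apply_transform(Module(), double), negate)
print(m.forward(3))  # -(2 * (3 + 1)) = -8
```

The key design point this illustrates: because the expensive compile step lives inside forward() and is cached, each apply_transform() call is cheap and simply extends the list of pending transforms, which is what makes nesting composable.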