3.7.2. unit_scaling.transforms.utils.patch_to_expand_modules

unit_scaling.transforms.utils.patch_to_expand_modules(fn: Callable[..., T]) → Callable[..., T]

By default, TorchDynamo doesn’t recurse into torch.nn modules or torch.nn.functional functions when capturing the FX graph. Any function that is wrapped in torch._dynamo.optimize() (or torch.compile()) and then passed to this function as fn will automatically recurse into torch.nn modules and torch.nn.functional functions.

In practice, the typical way to use this with a torch.nn.Module is to call module = torch._dynamo.optimize(backend)(module), followed by module.forward = patch_to_expand_modules(module.forward).

This should be used in conjunction with torch_nn_modules_to_user_modules(), as shown in the sketch below.
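
A minimal sketch of this pattern, assuming a simple Sequential model and a hypothetical custom Dynamo backend (the backend function and the assumption that torch_nn_modules_to_user_modules() converts the module in place are illustrative, not part of this API's documentation):

import torch
import torch._dynamo

from unit_scaling.transforms.utils import (
    patch_to_expand_modules,
    torch_nn_modules_to_user_modules,
)


def backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical backend: inspect the captured graph, then run it unchanged.
    # With the patching below, the graph also contains the internals of
    # torch.nn modules and torch.nn.functional calls.
    print(gm.graph)
    return gm.forward


module = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())

# Assumed to convert torch.nn submodules in place so Dynamo will trace into them.
torch_nn_modules_to_user_modules(module)

# The typical use case described above: optimise the module, then patch forward.
module = torch._dynamo.optimize(backend)(module)
module.forward = patch_to_expand_modules(module.forward)

out = module(torch.randn(8, 16))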

Parameters:

fn (Callable[..., T]) – the function to be patched.

Returns:

the new version of fn with patching applied.

Return type:

Callable[..., T]