3.7.2. unit_scaling.transforms.utils.patch_to_expand_modules
- unit_scaling.transforms.utils.patch_to_expand_modules(fn: Callable[..., T]) → Callable[..., T]
By default, TorchDynamo doesn’t recurse into torch.nn modules or torch.nn.functional functions when capturing the FX graph. Any function wrapped in torch._dynamo.optimize() (or torch.compile()) can be passed to this function as fn; the patched version it returns will then automatically recurse into torch.nn modules and torch.nn.functional functions.
In practice, the typical way to use this with a torch.nn.Module is to call module = torch._dynamo.optimize(backend)(module), followed by module.forward = patch_to_expand_modules(module.forward).
This should be used in conjunction with torch_nn_modules_to_user_modules().
- Parameters:
fn (Callable[..., T]) – the function to be patched.
- Returns:
the new version of fn with patching applied.
- Return type:
Callable[..., T]
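The internals of patch_to_expand_modules are not documented on this page. As a rough, self-contained illustration of the contract above (take fn, return a new version of it with patching applied), the sketch below shows one common Python pattern: a functools.wraps wrapper that activates a unittest.mock.patch.object patch only while the wrapped call runs. Tracer, should_inline, patch_to_inline_everything and ToyModule are hypothetical stand-ins, not part of unit_scaling or TorchDynamo, and the sketch does not claim to mirror the library's implementation. The last lines mimic the module.forward = patch_to_expand_modules(module.forward) usage.

```python
import functools
from typing import Any, Callable, TypeVar
from unittest.mock import patch

T = TypeVar("T")


class Tracer:
    """Hypothetical stand-in for a tracer that skips some calls by default."""

    @staticmethod
    def should_inline(name: str) -> bool:
        # Default behaviour: refuse to inline "library" functions.
        return not name.startswith("lib.")


def patch_to_inline_everything(fn: Callable[..., T]) -> Callable[..., T]:
    """Hypothetical: return a wrapper that runs fn with inlining forced on."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        # The patch is active only for the duration of this call and is
        # undone automatically afterwards, even if fn raises.
        with patch.object(Tracer, "should_inline", staticmethod(lambda name: True)):
            return fn(*args, **kwargs)

    return wrapper


class ToyModule:
    """Hypothetical object with a forward method, standing in for an nn.Module."""

    def forward(self, name: str) -> bool:
        return Tracer.should_inline(name)


m = ToyModule()
print(m.forward("lib.linear"))  # False
# Reassign the instance attribute, as in the documented usage pattern.
m.forward = patch_to_inline_everything(m.forward)
print(m.forward("lib.linear"))  # True
print(Tracer.should_inline("lib.linear"))  # False: patch is not active here
```

Because the patch is scoped to each call, code outside the patched forward (here, the direct Tracer.should_inline call) keeps the default behaviour.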