3.7.2. unit_scaling.transforms.utils.patch_to_expand_modules
- unit_scaling.transforms.utils.patch_to_expand_modules(fn: Callable[..., T]) -> Callable[..., T]
By default, TorchDynamo doesn't recurse into torch.nn modules or torch.nn.functional functions when capturing the FX graph. If a function wrapped in torch._dynamo.optimize() (or torch.compile()) is passed to this function as fn, the returned function will automatically recurse into torch.nn modules and torch.nn.functional functions.

In practice, the typical use case with a torch.nn.Module is to call module = torch._dynamo.optimize(backend)(module), followed by module.forward = patch_to_expand_modules(module.forward).

This should be used in conjunction with torch_nn_modules_to_user_modules().
- Parameters:
fn (Callable[..., T]) – the function to be patched.
- Returns:
the new version of fn with patching applied.
- Return type:
Callable[..., T]