4.8.1. unit_scaling.utils.analyse_module
- unit_scaling.utils.analyse_module(module: nn.Module, inputs: Tensor | Tuple[Tensor, ...], backward: Tensor | None = None, recurse_modules: bool = True, syntax_highlight: bool = True, autowrap_modules: Tuple[ModuleType, ...] = (math, einops, unit_scaling.functional), autowrap_functions: Tuple[Callable[..., Any], ...] = ()) → str [source]
Given an nn.Module and dummy forward and backward tensors, generates code representing the module, annotated with the scale (standard deviation) of each tensor in both the forward and backward passes. Implemented using torch.fx.
- Parameters:
module (nn.Module) – the module to analyse.
inputs (Union[Tensor, Tuple[Tensor, ...]]) – fed into the forward pass for analysis.
backward (Tensor, optional) – fed into the output’s .backward() method for analysis. Defaults to None, equivalent to calling plain .backward().
recurse_modules (bool, optional) – toggles recursive behaviour. Defaults to True.
syntax_highlight (bool, optional) – toggles syntax highlighting of the returned code string. Defaults to True.
autowrap_modules (Tuple[ModuleType, ...]) – Python modules whose functions should be wrapped automatically without needing to use fx.wrap(). Defaults to (math, einops, U.functional).
autowrap_functions (Tuple[Callable, ...]) – Python functions that should be wrapped automatically without needing to use fx.wrap() (see the sketch after the example below). Defaults to ().
- Returns:
a code string representing the operations in the module with scale annotations for each tensor, reflecting their standard deviations in the forward and backward passes.
- Return type:
str
Examples:
>>> class MLP(nn.Module):
>>>     def __init__(self, d):
>>>         super().__init__()
>>>         self.fc1 = nn.Linear(d, d * 4)
>>>         self.relu = nn.ReLU()
>>>         self.fc2 = nn.Linear(d * 4, d)
>>>     def forward(self, x):
>>>         x = self.fc1(x)
>>>         x = self.relu(x)
>>>         x = self.fc2(x)
>>>         return x

>>> hidden_size = 2**10
>>> x = torch.randn(hidden_size, hidden_size).requires_grad_()
>>> bwd = torch.randn(hidden_size, hidden_size)

>>> code = analyse_module(MLP(hidden_size), x, bwd)
>>> print(code)
def forward(self, x):  (-> 1.0, <- 0.236)
    fc1_weight = self.fc1.weight;  (-> 0.018, <- 6.54)
    fc1_bias = self.fc1.bias;  (-> 0.0182, <- 6.51)
    linear = torch._C._nn.linear(x, fc1_weight, fc1_bias);  (-> 0.578, <- 0.204)
    relu = torch.nn.functional.relu(linear, inplace = False);  (-> 0.337, <- 0.288)
    fc2_weight = self.fc2.weight;  (-> 0.00902, <- 13.0)
    fc2_bias = self.fc2.bias;  (-> 0.00904, <- 31.6)
    linear_1 = torch._C._nn.linear(relu, fc2_weight, fc2_bias);  (-> 0.235, <- 0.999)
    return linear_1
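
The sketch below illustrates autowrap_functions. The names GatedFFN and gated_silu are hypothetical, invented for this illustration only; the sketch assumes that functions passed via autowrap_functions are left un-traced by the underlying torch.fx tracer, so each one appears as a single annotated operation in the returned code rather than being decomposed into its constituent ops.

# Hypothetical example: `gated_silu` and `GatedFFN` are illustrative names,
# not part of unit_scaling.
import torch
import torch.nn as nn
import torch.nn.functional as F

from unit_scaling.utils import analyse_module

def gated_silu(a, b):
    # A free Python function called inside forward(); listed in
    # autowrap_functions so torch.fx does not trace into it.
    return F.silu(a) * b

class GatedFFN(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.gate = nn.Linear(d, d)
        self.up = nn.Linear(d, d)

    def forward(self, x):
        return gated_silu(self.gate(x), self.up(x))

d = 256
x = torch.randn(64, d).requires_grad_()
bwd = torch.randn(64, d)
code = analyse_module(
    GatedFFN(d), x, bwd,
    syntax_highlight=False,
    autowrap_functions=(gated_silu,),
)
print(code)  # gated_silu appears as one call with its own scale annotation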