3.8.1. unit_scaling.utils.analyse_module

unit_scaling.utils.analyse_module(module: torch.nn.Module, inputs: torch.Tensor | Tuple[torch.Tensor, ...], backward: torch.Tensor | None = None, recurse_modules: bool = True, syntax_highlight: bool = True, autowrap_modules: Tuple[types.ModuleType, ...] = (math, einops, unit_scaling.functional), autowrap_functions: Tuple[Callable[..., Any], ...] = ()) → str

Given an nn.Module and dummy forward and backward tensors, generates code representing the module, annotated with the scale (standard deviation) of each tensor in both the forward and backward passes. Implemented using torch.fx.
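The core idea can be sketched with torch.fx's Interpreter, which executes a traced graph one node at a time. This is an illustration under stated assumptions, not the library's actual implementation: `ScaleRecorder` and `TinyMLP` are hypothetical names, the forward scale is taken as `Tensor.std()` of each node's output, and the backward scale is captured with a tensor hook while `.backward()` runs. Unlike analyse_module in the example below, plain `fx.symbolic_trace` keeps built-in layers such as nn.Linear as single `call_module` nodes, so weights and biases do not appear as separate entries.

```python
import torch
import torch.fx as fx
from torch import nn


class ScaleRecorder(fx.Interpreter):
    """Hypothetical helper: runs a traced module node by node, recording the
    std of every floating-point tensor in the forward and backward passes."""

    def __init__(self, gm: fx.GraphModule):
        super().__init__(gm)
        self.fwd_std = {}
        self.bwd_std = {}

    def _make_hook(self, name):
        def hook(grad):
            # Called during .backward() with the gradient w.r.t. this tensor.
            self.bwd_std[name] = grad.std().item()
        return hook

    def run_node(self, n: fx.Node):
        out = super().run_node(n)
        # Skip the final "output" node, which just returns an existing tensor.
        if n.op != "output" and isinstance(out, torch.Tensor) and out.is_floating_point():
            self.fwd_std[n.name] = out.std().item()
            if out.requires_grad:
                out.register_hook(self._make_hook(n.name))
        return out


class TinyMLP(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.fc1 = nn.Linear(d, 4 * d)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(4 * d, d)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


torch.manual_seed(0)
gm = fx.symbolic_trace(TinyMLP(64))
recorder = ScaleRecorder(gm)
x = torch.randn(128, 64, requires_grad=True)
y = recorder.run(x)
y.backward(torch.randn_like(y))  # plays the role of the `backward` tensor
for name, fwd in recorder.fwd_std.items():
    bwd = recorder.bwd_std.get(name, float("nan"))
    print(f"{name}: (-> {fwd:.3g}, <- {bwd:.3g})")
```

Hooking tensors rather than modules keeps the sketch close to the per-tensor annotations of the real output, since every intermediate value gets its own forward and backward scale.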

Parameters:
  • module (nn.Module) – the module to analyse.

  • inputs (Union[Tensor, Tuple[Tensor, ...]]) – fed into the forward pass for analysis.

  • backward (Tensor, optional) – fed into the output’s .backward() method for analysis. Defaults to None, equivalent to calling plain .backward().

  • recurse_modules (bool, optional) – toggles recursive behaviour, i.e. whether to trace recursively into submodules. Defaults to True.

  • syntax_highlight (bool, optional) – whether to apply syntax highlighting to the returned code string. Defaults to True.

  • autowrap_modules (Tuple[ModuleType, ...], optional) – Python modules whose functions should be wrapped automatically, without needing to use fx.wrap(). Defaults to (math, einops, unit_scaling.functional).

  • autowrap_functions (Tuple[Callable, ...], optional) – Python functions that should be wrapped automatically, without needing to use fx.wrap(). Defaults to ().
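torch.fx.Tracer accepts constructor arguments with the same names, autowrap_modules and autowrap_functions. It is an assumption here that analyse_module forwards its arguments to such a tracer. The sketch below shows what wrapping changes: a wrapped function is recorded as one opaque `call_function` node instead of being traced through. `scaled_gelu` and `Block` are hypothetical names used only for illustration.

```python
import torch
import torch.fx as fx
from torch import nn


def scaled_gelu(x, alpha=1.0):
    # Hypothetical helper, standing in for a function from a module such as
    # unit_scaling.functional.
    return torch.nn.functional.gelu(x) * alpha


class Block(nn.Module):
    def forward(self, x):
        return scaled_gelu(x, 2.0)


def call_targets(tracer: fx.Tracer):
    # Trace Block and list the targets of its call_function nodes.
    graph = tracer.trace(Block())
    return [n.target for n in graph.nodes if n.op == "call_function"]


plain = call_targets(fx.Tracer())  # traced through: gelu, then a multiply
wrapped = call_targets(fx.Tracer(autowrap_functions=(scaled_gelu,)))  # one node
print(plain)
print(wrapped)
```

Wrapping a function hides its internals from the graph, so only its inputs and output would receive scale annotations; leaving it unwrapped exposes every intermediate tensor.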

Returns:

a code string representing the operations in the module with scale annotations for each tensor, reflecting their standard deviations in the forward and backward passes.

Return type:

str
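Because the result is a plain string, downstream checks may want the numbers back. The following is a minimal sketch of a hypothetical `parse_scales` helper for lines in the `(-> forward_std, <- backward_std)` format shown in the example below. It assumes `syntax_highlight=False` (so the string contains no terminal escape codes) and that both annotated values are numeric; it ignores the `def` line, whose annotation describes the inputs.

```python
import re
from typing import Dict, Tuple

# Matches e.g. "    relu = torch.nn.functional.relu(...);  (-> 0.337, <- 0.288)"
_LINE = re.compile(r"^\s*(\w+)\s*=.*\(->\s*([^,\s]+),\s*<-\s*([^)\s]+)\)\s*$")


def parse_scales(code: str) -> Dict[str, Tuple[float, float]]:
    """Map each assigned variable name to its (forward std, backward std)."""
    scales = {}
    for line in code.splitlines():
        m = _LINE.match(line)
        if m:
            scales[m.group(1)] = (float(m.group(2)), float(m.group(3)))
    return scales
```

Typical use would be `parse_scales(analyse_module(model, x, bwd, syntax_highlight=False))`, e.g. to assert that no tensor's scale drifts far from 1.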

Examples:

>>> import torch
>>> from torch import nn
>>> from unit_scaling.utils import analyse_module

>>> class MLP(nn.Module):
...     def __init__(self, d):
...         super().__init__()
...         self.fc1 = nn.Linear(d, d * 4)
...         self.relu = nn.ReLU()
...         self.fc2 = nn.Linear(d * 4, d)
...
...     def forward(self, x):
...         x = self.fc1(x)
...         x = self.relu(x)
...         x = self.fc2(x)
...         return x

>>> hidden_size = 2**10
>>> x = torch.randn(hidden_size, hidden_size).requires_grad_()
>>> bwd = torch.randn(hidden_size, hidden_size)

>>> code = analyse_module(MLP(hidden_size), x, bwd)
>>> print(code)
def forward(self, x):  (-> 1.0, <- 0.236)
    fc1_weight = self.fc1.weight;  (-> 0.018, <- 6.54)
    fc1_bias = self.fc1.bias;  (-> 0.0182, <- 6.51)
    linear = torch._C._nn.linear(x, fc1_weight, fc1_bias);  (-> 0.578, <- 0.204)
    relu = torch.nn.functional.relu(linear, inplace = False);  (-> 0.337, <- 0.288)
    fc2_weight = self.fc2.weight;  (-> 0.00902, <- 13.0)
    fc2_bias = self.fc2.bias;  (-> 0.00904, <- 31.6)
    linear_1 = torch._C._nn.linear(relu, fc2_weight, fc2_bias);  (-> 0.235, <- 0.999)
    return linear_1
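The printed scales can be cross-checked analytically. This sketch assumes PyTorch's default nn.Linear initialisation (weights and biases drawn from U(-1/√fan_in, 1/√fan_in)), independence between weights and activations, zero-mean Gaussian pre-activations entering the ReLU, and the 1024-row batch used above. The example output comes from a single random draw, so agreement to within a few percent is the most to expect.

```python
import math


def init_std(fan_in):
    # Assumed default nn.Linear init: U(-b, b) with b = 1/sqrt(fan_in);
    # the std of U(-b, b) is b / sqrt(3).
    return 1.0 / math.sqrt(3 * fan_in)


d, batch = 1024, 1024                  # hidden_size, and the rows of x
w1, w2 = init_std(d), init_std(4 * d)  # fc1 / fc2 weight (and bias) std

# Forward: Var[y] = fan_in * Var[w] * E[x^2] + Var[b] for zero-mean weights.
h_sq = d * w1**2 * 1.0 + w1**2         # E[linear^2], since x ~ N(0, 1)
relu_sq = h_sq / 2                     # E[relu^2] for zero-mean Gaussian input
relu_std = math.sqrt(h_sq * (0.5 - 1 / (2 * math.pi)))
out_sq = 4 * d * w2**2 * relu_sq + w2**2

# Backward: grad_in = grad_out @ W sums over out_features,
# so E[g_in^2] = out_features * Var[w] * E[g_out^2].
g_out_sq = 1.0                         # bwd ~ N(0, 1)
g_relu_sq = d * w2**2 * g_out_sq
g_lin_sq = g_relu_sq / 2               # ReLU passes gradient for ~half the units
g_x_sq = 4 * d * w1**2 * g_lin_sq

# Parameter gradients sum over the batch: E[dW^2] = batch * E[g^2] * E[x^2].
gw2 = math.sqrt(batch * g_out_sq * relu_sq)
gb2 = math.sqrt(batch * g_out_sq)
gw1 = math.sqrt(batch * g_lin_sq * 1.0)
gb1 = math.sqrt(batch * g_lin_sq)

predicted = {
    "x": (1.0, math.sqrt(g_x_sq)),
    "fc1_weight": (w1, gw1),
    "fc1_bias": (w1, gb1),
    "linear": (math.sqrt(h_sq), math.sqrt(g_lin_sq)),
    "relu": (relu_std, math.sqrt(g_relu_sq)),
    "fc2_weight": (w2, gw2),
    "fc2_bias": (w2, gb2),
    "linear_1": (math.sqrt(out_sq), math.sqrt(g_out_sq)),
}
for name, (fwd, bwd) in predicted.items():
    print(f"{name}: (-> {fwd:.3g}, <- {bwd:.3g})")
```

The large parameter-gradient scales (for example, about √1024 = 32 for fc2_bias) follow from summing per-example gradients over the batch, while the small weight scales come from the 1/√fan_in initialisation.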