Source code for unit_scaling.transforms._unit_scale

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import logging
from inspect import signature
from operator import getitem
from types import BuiltinFunctionType
from typing import Any, Callable, Dict, List, Set, TypeVar

import torch
import torch._dynamo
import torch.nn.functional as F
from torch import Tensor, nn
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node

from .. import functional as U
from .._internal_utils import generate__all__
from .utils import Backend, apply_transform, replace_node_with_function

T = TypeVar("T")
logger = logging.getLogger(__name__)


def _add_dependency_meta(graph: Graph, recalculate: bool = False) -> None:
    def recurse(n: Node) -> Set[Node]:
        if "dependencies" in n.meta:
            return n.meta["dependencies"]  # type: ignore[no-any-return]
        deps = set(n.all_input_nodes)
        for parent in n.all_input_nodes:
            deps.update(recurse(parent))
        n.meta["dependencies"] = deps
        return deps

    if recalculate:
        for n in graph.nodes:
            if "dependencies" in n.meta:
                del n.meta["dependencies"]

    for n in graph.nodes:
        if n.op == "output":
            recurse(n)
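

# Illustrative sketch (comments only, not executed): for a traced graph of the form
#
#     x  ->  a = f(x)  ->  b = g(a, x)  ->  output(b)
#
# the recursion above records
#
#     a.meta["dependencies"] == {x}
#     b.meta["dependencies"] == {a, x}
#
# i.e. each node's transitive set of input nodes. The residual-add detection in
# `unit_scaling_backend` uses these sets to decide which argument of an add lies
# on the skip branch.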


def _is_add(n: Node) -> bool:
    return (
        n.op == "call_function"
        and isinstance(n.target, BuiltinFunctionType)
        and n.target.__name__ in ["add", "iadd"]
    )
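

# Note: the name-based check above matches builtin add functions such as
# `operator.add` / `operator.iadd`, which Dynamo typically records for the
# Python `+` / `+=` operators on tensors, so both in-place and out-of-place
# adds are caught without comparing against a specific function object.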


def _is_self_attention(skip_node: Node, residual_node: Node) -> bool:
    # Identify all operations on the residual branch
    residual_fns = {residual_node.target}
    parents = [*residual_node.all_input_nodes]
    while parents:
        p = parents.pop()
        if p == skip_node:
            continue
        residual_fns.add(p.target)
        parents += p.all_input_nodes
    # Check if any of them are self-attention ops (softmax or fused self-attention)
    self_attention_fns = {
        F.scaled_dot_product_attention,
        U.scaled_dot_product_attention,
        F.softmax,
        U.softmax,
    }
    return bool(residual_fns.intersection(self_attention_fns))
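

# Example (illustrative): in a typical transformer block the attention residual
# branch contains a softmax or fused scaled-dot-product-attention call and is
# matched here, whereas the FFN residual branch is not. The two cases receive
# different `tau` values in `_unit_scale_residual` below.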


def _unit_scale_residual(
    graph: Graph,
    add: Node,
    residual_arg_idx: int,
    is_self_attention: bool,
) -> None:
    residual, skip = add.args[residual_arg_idx], add.args[1 - residual_arg_idx]
    tau = 0.01 if is_self_attention else 0.5
    logger.info("unit scaling function: %s (residual-add, tau=%s)", add, tau)
    old_start_residuals = [
        u for u in skip.users if u is not add  # type: ignore[union-attr]
    ]
    with graph.inserting_after(skip):  # type: ignore[arg-type]
        split = graph.call_function(
            U.residual_split,
            args=(skip, tau),
            type_expr=getattr(skip, "type", None),
        )
    with graph.inserting_after(split):
        new_start_residual = graph.call_function(getitem, args=(split, 0))
    for old_start_residual in old_start_residuals:
        old_start_residual.replace_input_with(skip, new_start_residual)  # type: ignore
    with graph.inserting_after(split):
        new_skip = graph.call_function(getitem, args=(split, 1))
    replace_node_with_function(
        graph, add, U.residual_add, args=(residual, new_skip, tau)
    )
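

# Rewrite sketch (illustrative, in pseudo-graph form). Before:
#
#     out = add(residual_branch(skip), skip)
#
# After:
#
#     split     = U.residual_split(skip, tau)
#     branch_in = split[0]   # rewired into the old start of the residual branch
#     new_skip  = split[1]
#     out       = U.residual_add(residual_branch(branch_in), new_skip, tau)
#
# so the skip tensor is split and recombined with a matching `tau`.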


def _unconstrain_node(node: Node) -> None:
    if (
        node.op == "call_function"
        and callable(node.target)
        and not isinstance(node.target, BuiltinFunctionType)
        and "constraint" in signature(node.target).parameters
    ):
        logger.info("unconstraining node: %s", node)
        node.kwargs = dict(node.kwargs, constraint=None)
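

# Example (illustrative; assumes the target op exposes a `constraint` keyword,
# as the signature check above requires): a node calling `U.linear(x, w, b)`
# after the final residual layer would gain `constraint=None`, so its forward
# and backward scaling factors are no longer tied to a shared constraint.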


def unit_scaling_backend(
    replacement_map: Dict[Callable[..., Any], Callable[..., Any]] = dict()
) -> Backend:
    def inner_backend(gm: GraphModule, example_inputs: List[Tensor]) -> GraphModule:
        logger.info("running unit scaling backend")
        graph = gm.graph
        # Replace function nodes with those in `replacement_map` or with their
        # unit scaled equivalents
        for node in graph.nodes:
            if node.op == "call_function":
                if node.target in replacement_map:
                    target_fn = replacement_map[node.target]
                    logger.info(
                        "replacing function: %s with %s", node, target_fn.__name__
                    )
                    replace_node_with_function(graph, node, target_fn)
                elif node.target in U.torch_map:
                    target_fn = U.torch_map[node.target]
                    logger.info("unit scaling function: %s", node)
                    replace_node_with_function(graph, node, target_fn)

        # Add metadata denoting the dependencies of every node in the graph
        _add_dependency_meta(graph)

        # Go through and mark nodes which represent residual-adds
        residual_layer_number = 1
        for node in graph.nodes:
            if _is_add(node):
                is_residual_add = False
                if len(node.args) == 2:
                    l, r = node.args
                    if isinstance(l, Node) and isinstance(r, Node):
                        l_deps = l.meta.get("dependencies", list())
                        r_deps = r.meta.get("dependencies", list())
                        if l in r_deps or r in l_deps:
                            node.meta["residual_add"] = {
                                "residual_arg_idx": 1 if l in r_deps else 0,
                            }
                            residual_layer_number += 1
                            is_residual_add = True
                            skip_node, residual_node = (l, r) if l in r_deps else (r, l)
                            is_sa = _is_self_attention(skip_node, residual_node)
                            node.meta["residual_add"]["is_self_attention"] = is_sa
                # Regular adds are not picked up by the unit scaling sweep above, as
                # the inbuilt + operation is handled differently when traced. They are
                # instead replaced with their unit-scaled equivalent here.
                if not is_residual_add:
                    logger.info("unit scaling function: %s", node)
                    args = (*node.args, None)  # None denotes unconstrained
                    replace_node_with_function(graph, node, U.add, args=args)

        # Replace nodes marked as residual-adds with unit scaled equivalent
        for node in graph.nodes:
            residual_add = node.meta.get("residual_add", None)
            if residual_add is not None:
                _unit_scale_residual(graph, node, **residual_add)

        _add_dependency_meta(graph, recalculate=True)
        for node in graph.nodes:
            if node.target == U.residual_add:
                node.meta["has_residual_successor"] = True
                dependencies = node.meta.get("dependencies", [])
                for d in dependencies:
                    d.meta["has_residual_successor"] = True

        for node in graph.nodes:
            if "has_residual_successor" not in node.meta:
                _unconstrain_node(node)

        graph.lint()  # type: ignore[no-untyped-call]
        return GraphModule(gm, graph)

    return inner_backend
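

# Usage sketch (mirrors `unit_scale` below; `my_gelu` is a hypothetical
# user-defined function). The returned backend is a Dynamo-style callable
# `(GraphModule, example_inputs) -> GraphModule`, applied via this package's
# transform utilities rather than directly through `torch.compile`:
#
#     backend = unit_scaling_backend({my_gelu: U.gelu})
#     scaled = apply_transform(module, backend, non_recurse_functions=[my_gelu])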


def _unit_init_weights(m: nn.Module) -> None:
    for name, mod in m.named_modules():
        if isinstance(mod, (nn.Linear, nn.Embedding)):
            with torch.no_grad():
                if isinstance(mod.weight, Tensor):
                    logger.info("unit scaling weight: %s.weight", name)
                    mod.weight /= mod.weight.std()


def _zero_init_biases(m: nn.Module) -> None:
    for name, mod in m.named_modules():
        if isinstance(mod, (nn.Linear, nn.Embedding)):
            with torch.no_grad():
                if hasattr(mod, "bias") and isinstance(mod.bias, Tensor):
                    logger.info("setting bias to zero: %s.bias", name)
                    mod.bias -= mod.bias


# Unit scaling should come before quantisation
def _order_backends(backends: List[Backend]) -> None:
    unit_scaling_backend_idx = -1
    quantisation_backend_idx = float("inf")
    for i, b in enumerate(backends):
        if "unit_scaling_backend" in b.__qualname__:
            unit_scaling_backend_idx = i
        if "quantisation_backend" in b.__qualname__:
            quantisation_backend_idx = i
    if unit_scaling_backend_idx > quantisation_backend_idx:
        logger.info("moving unit scaling backend to precede quantisation backend")
        u = backends.pop(unit_scaling_backend_idx)
        backends.insert(quantisation_backend_idx, u)  # type: ignore[arg-type]
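

# Worked example (illustrative): given
#
#     backends = [quantisation_backend(...), unit_scaling_backend(...)]
#
# the qualname checks find indices 1 > 0, so the unit scaling backend is popped
# and re-inserted at index 0, restoring the required order. If either backend
# is absent, the comparison fails and the list is left unchanged.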


M = TypeVar("M", bound=nn.Module)


def unit_scale(
    module: M, replace: Dict[Callable[..., Any], Callable[..., Any]] = dict()
) -> M:
    """**[Experimental]** Returns a unit-scaled version of the input model.

    Uses TorchDynamo to trace and transform the user-supplied module.
    This transformation identifies all :class:`torch.nn.functional` uses in the
    input module, and replaces them with their unit-scaled equivalents, should
    they exist.

    The tracing procedure automatically recurses into modules (whether defined
    in libraries, or by the user), identifying inner calls to any
    :class:`torch.nn.functional` operations, to build a graph of fundamental
    operations. Unit scaling is then applied as a transformation on top of this
    graph.

    This transformation proceeds in five stages:

    #. **Replacement of user-defined functions** according to the supplied
       `replace` dictionary.
    #. **Replacement of all functions with unit-scaled equivalents** defined in
       :mod:`unit_scaling.functional`.
    #. Identification & **replacement of all add operations that represent
       residual-adds**. The identification of residual connections is done via a
       dependency analysis on the graph. Residual-adds require special scaling
       compared with regular adds (see paper / User Guide for details).
    #. **Unconstraining of all operations after the final residual layer**. By
       default all unit scaled operations have their scaling factors constrained
       in the forward and backward pass to give valid gradients. This is not
       required in these final layers (see paper for proof), and hence we can
       unconstrain the operations to give better scaling.
    #. **Unit-scaling of all weights** and zero-initialisation of all biases.

    Note that by using TorchDynamo, `unit_scale()` is able to trace a much
    larger set of modules / operations than previous PyTorch tracing approaches.
    This enables the process of unit scaling to be expressed as a generic graph
    transform that can be applied to arbitrary modules.

    Note that the current version of TorchDynamo (or :func:`torch.compile`,
    which is a wrapper around TorchDynamo) doesn't support nested transforms, so
    we implement our own system here. This makes it easy to nest transforms:

    .. code-block:: python

        from unit_scaling.transforms import compile, simulate_fp8, unit_scale

        module = compile(simulate_fp8(unit_scale(module)))

    However, these transforms are not interoperable with the standard
    :func:`torch.compile` interface.

    In some cases users may have a model definition that uses a custom
    implementation of a basic operation. In this case, `unit_scale()` can be
    told explicitly to replace the layer with an equivalent, using the `replace`
    dictionary:

    .. code-block:: python

        import unit_scaling.functional as U
        from unit_scaling.transforms import unit_scale

        def new_gelu(x):
            ...

        class Model(nn.Module):
            def forward(self, x):
                ...
                x = new_gelu(x)
                ...

        model = unit_scale(Model(), replace={new_gelu: U.gelu})

    This can also be used to replace a particular function with a user-defined
    unit-scaled function not provided by :mod:`unit_scaling.functional`.

    **Note:** `unit_scale()` is experimental and has not yet been widely tested
    on a range of models. The standard approach to unit scaling a model is still
    to manually substitute the layers/operations in a model with their
    unit-scaled equivalents. Having said this, `unit_scale()` is implemented in
    a sufficiently generic way that we anticipate many users will ultimately be
    able to rely on this graph transform alone.

    Args:
        module (nn.Module): the input module to be unit scaled.
        replace (Dict[Callable, Callable], optional): a dictionary where keys
            represent functions to be replaced by the corresponding
            value-functions. Note that these substitutions take priority over
            the standard unit scaling substitutions. Defaults to dict().

    Returns:
        nn.Module: the unit scaled module (with an independent copy of
        parameters).
    """
    unit_scaled_module = apply_transform(
        module,
        unit_scaling_backend(replace),
        non_recurse_functions=list(replace.keys()),
    )
    _order_backends(unit_scaled_module.backends)
    _unit_init_weights(unit_scaled_module)
    _zero_init_biases(unit_scaled_module)
    return unit_scaled_module  # type: ignore[no-any-return]

__all__ = generate__all__(__name__)