Source code for unit_scaling.transforms._track_scales

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

import logging
from copy import deepcopy
from dataclasses import dataclass
from math import isclose
from types import MethodType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar

import torch
from tabulate import tabulate
from torch import Tensor, nn
from torch.fx import Interpreter
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, Target, map_arg

from .utils import apply_transform

logger = logging.getLogger(__name__)
M = TypeVar("M", bound=nn.Module)


class Metrics:
    """A set of metrics representing useful information about a tensor,
    in the forward and backward pass.

    :code:`metrics.fwd`, and optionally :code:`metrics.bwd`, are objects of type
    :code:`Data`, containing metrics about the tensor in the forward pass and its
    gradient in the backward pass. :code:`metrics.bwd` stays :code:`None` until
    :meth:`set_bwd` is called.
    """

    @dataclass
    class Data:
        """Data representing key metrics for a tensor."""

        mean_abs: float
        abs_mean: float
        std: float
        abs_max: float
        abs_min: float
        numel: int

    def __init__(self, fwd_tensor: Tensor) -> None:
        """Compute forward-pass metrics from `fwd_tensor`; no backward metrics yet."""
        self.fwd = self.from_tensor(fwd_tensor)
        self.bwd: Optional[Metrics.Data] = None

    def set_bwd(self, bwd_tensor: Tensor) -> None:
        """Compute and store backward-pass metrics from `bwd_tensor`."""
        self.bwd = self.from_tensor(bwd_tensor)

    def __str__(self) -> str:
        return f"fwd: {self.fwd}, bwd: {self.bwd}"  # pragma: no cover

    @staticmethod
    def from_tensor(t: Tensor) -> "Metrics.Data":
        """Summarise `t` as a :class:`Metrics.Data` of plain Python numbers.

        Degenerate sizes are handled explicitly:

        * an empty tensor yields NaN for every statistic (reductions such as
          ``max`` are undefined on zero elements and would otherwise raise);
        * a single-element tensor (e.g. a scalar loss) reports ``std == 0.0``
          rather than the NaN-plus-warning produced by the default
          sample standard deviation.

        Args:
            t (Tensor): the tensor to summarise.

        Returns:
            Metrics.Data: the computed metrics.
        """
        numel = t.numel()
        if numel == 0:
            nan = float("nan")
            return Metrics.Data(
                mean_abs=nan,
                abs_mean=nan,
                std=nan,
                abs_max=nan,
                abs_min=nan,
                numel=0,
            )
        abs_t = t.abs()
        # The default (sample) std divides by numel - 1, which is undefined
        # for a single value; a lone value has no spread.
        std = t.std().item() if numel > 1 else 0.0
        return Metrics.Data(
            mean_abs=abs_t.mean().item(),
            abs_mean=t.mean().abs().item(),
            std=std,
            abs_max=abs_t.max().item(),
            abs_min=abs_t.min().item(),
            numel=numel,
        )

    @staticmethod
    def get_full_name(short_name: str) -> str:
        """Return the human-readable name of a metric field.

        Raises:
            KeyError: if `short_name` is not a field of :class:`Metrics.Data`.
        """
        return {
            "mean_abs": "mean absolute value",
            "abs_mean": "absolute mean value",
            "std": "standard deviation",
            "abs_max": "absolute maximum",
            "abs_min": "absolute minimum",
            "numel": "number of elements",
        }[short_name]

    @staticmethod
    def names() -> List[str]:
        """Return the metric field names, in declaration order."""
        return list(Metrics.Data.__dataclass_fields__)

    @staticmethod
    def full_names() -> List[str]:
        """Return the human-readable metric names, in declaration order."""
        return [Metrics.get_full_name(m) for m in Metrics.names()]
def _directions_same_scale(
    a: "Metrics.Data", b: "Metrics.Data", rtol: float = 2**-16
) -> bool:
    """Return True if `a` and `b` have the same scale within `rtol`."""
    # We're just going to look at the mean_abs here
    return isclose(a.mean_abs, b.mean_abs, rel_tol=rtol)


def _metrics_same_scale(a: "Metrics", b: "Metrics", rtol: float = 2**-16) -> bool:
    """Return True if `a` and `b` have the same scale in every recorded direction.

    If neither has backward metrics only the forward pass is compared; if just
    one of them has backward metrics they are treated as different.
    """
    # Both are None
    if a.bwd is None and b.bwd is None:
        return _directions_same_scale(a.fwd, b.fwd, rtol)
    # Only one is None
    if a.bwd is None or b.bwd is None:  # pragma: no cover
        return False
    # Neither is None
    return _directions_same_scale(a.fwd, b.fwd, rtol) and _directions_same_scale(
        a.bwd, b.bwd, rtol
    )


def _make_input_tensors_require_grad(module: nn.Module) -> None:
    """Wrap `module.forward` so floating-point tensor inputs require grad.

    The wrapper is assigned onto the module instance; every positional and
    keyword argument that is a floating-point tensor has `requires_grad_()`
    called on it before the original forward runs.
    """
    old_forward = module.forward

    def new_forward(*args: Any, **kwargs: Any) -> Any:
        for a in args + tuple(kwargs.values()):
            if _is_float_tensor(a):
                a.requires_grad_()
        return old_forward(*args, **kwargs)

    module.forward = new_forward


class ScaleTrackingAutogradFunction(torch.autograd.Function):
    """Identity op that records tensor metrics in the forward and backward pass."""

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        t: Tensor,
        node_meta: Dict[str, Any],
    ) -> Tensor:
        """Store forward metrics for `t` in `node_meta["metrics"]`; return a clone."""
        node_meta["metrics"] = Metrics(fwd_tensor=t)
        # Keep the dict on the context so backward() can add to the same entry.
        ctx.node_meta = node_meta  # type: ignore
        return t.clone()

    @staticmethod
    def backward(  # type:ignore[override]
        ctx: torch.autograd.function.FunctionCtx, t: Tensor
    ) -> Tuple[Tensor, None]:
        """Record metrics for the incoming gradient `t` and pass a clone through."""
        ctx.node_meta["metrics"].set_bwd(bwd_tensor=t)  # type: ignore
        # One gradient per forward() input: the tensor, then None for `node_meta`.
        return t.clone(), None


def _tabulate_graph_data(g: Graph) -> str:  # pragma: no cover
    """Render the nodes of `g`, with their metadata, as an HTML table."""
    meta_ignore_fields = [
        "creation_timestamp",
        "stack_trace",
        "example_value",
        "tensor_dict",
        "grapharg",
    ]
    headers = ["opcode", "name", "target", "args", "kwargs", "users", "input_nodes"]
    # One extra column per metadata key seen on any node, in first-seen order.
    meta_headers: List[str] = []
    for n in g.nodes:
        for k in n.meta:
            if k not in meta_headers and k not in meta_ignore_fields:
                meta_headers.append(k)
    data = []
    for n in g.nodes:
        node_data = [
            n.op,
            n.name,
            n.target,
            n.args,
            n.kwargs,
            list(n.users.keys()),
            list(n._input_nodes.keys()),
        ] + [n.meta.get(h) for h in meta_headers]
        data.append(node_data)
    return tabulate(data, headers=headers + meta_headers, tablefmt="html")


def _add_tabular_html_display(g: Graph) -> None:
    """When `_repr_html_()` is called on `g` (e.g. if it's displayed in a Jupyter
    cell), return a tabulated view of the graph's data, including relevant
    metadata."""
    g._repr_html_ = MethodType(_tabulate_graph_data, g)  # type: ignore[attr-defined]


def _clean_node_name(name: str) -> str:
    """Strip capture-added prefixes from a node name.

    The ``l__self___`` / ``L__self___`` markers are removed wherever they occur.
    The short ``l_`` / ``L_`` prefixes are removed only at the start of the name
    or directly after an underscore, so ordinary names that merely contain
    those letters (``mul_1``, ``model_input``) are left intact. Leading and
    trailing underscores are stripped from the result.

    NOTE(review): these prefixes are presumably those added by TorchDynamo
    graph capture - confirm against `apply_transform`.
    """
    for marker in ("l__self___", "L__self___"):
        name = name.replace(marker, "")
    for prefix in ("l_", "L_"):
        if name.startswith(prefix):
            name = name[len(prefix) :]
        # Only treat the prefix as a token when it follows a separator.
        name = name.replace("_" + prefix, "_")
    return name.strip("_")


def _is_float_tensor(a: Any) -> bool:
    """Return True if `a` is a tensor with a floating-point dtype."""
    return isinstance(a, Tensor) and a.is_floating_point()


def _get_tracking_meta(n: Node, out: Any) -> Dict[str, Any]:
    """Build the metadata recorded on node `n` given its runtime output `out`."""
    return {
        "clean_name": _clean_node_name(n.name),
        "outputs_float_tensor": _is_float_tensor(out),
        "requires_grad": isinstance(out, nn.Parameter) and out.requires_grad,
    }


class ScaleTrackingInterpreter(Interpreter):
    """FX interpreter that attaches scale tracking to floating-point node outputs."""

    def __init__(self, gm: GraphModule) -> None:
        super().__init__(gm)

    def run_node(self, n: Node) -> Any:
        """Run `n`, record its tracking metadata, and wrap float-tensor outputs."""
        out = super().run_node(n)
        n.meta.update(_get_tracking_meta(n, out))
        if n.meta["outputs_float_tensor"]:
            logger.info("adding tracking to node: %s", n)
            out = ScaleTrackingAutogradFunction.apply(out, n.meta)  # type: ignore
        return out

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Make the interpreter directly callable by delegating to `run()`."""
        return super().run(*args, **kwargs)


class ScaleTrackingBackend:
    """Callable backend that runs a captured graph through the tracking interpreter
    and keeps a reference to that graph."""

    def __init__(self) -> None:
        self.graph = Graph()

    def __call__(
        self, gm: GraphModule, example_inputs: List[Tensor]
    ) -> Callable[..., Any]:
        _add_tabular_html_display(gm.graph)  # displays full info in notebooks
        self.graph = gm.graph  # allows graph to be accessed from outside
        return ScaleTrackingInterpreter(gm)


def _prune(graph: Graph, node: Node, replacement_arg: Optional[Node] = None) -> None:
    """Erase `node` from `graph`, rewiring its users to `replacement_arg`.

    Every occurrence of `node` in a user's args or kwargs is replaced, including
    occurrences nested inside lists, tuples or dicts (e.g. the list passed to
    ``torch.cat`` or the tuple held by the output node). Without this, a nested
    use would survive and `erase_node` would refuse to remove the node.

    Args:
        graph (Graph): the graph containing `node`; modified in place.
        node (Node): the node to remove.
        replacement_arg (Optional[Node]): what users receive in place of `node`.
            Defaults to None.
    """

    def swap(arg: Node) -> Optional[Node]:
        return replacement_arg if arg == node else arg

    # Copy the users: reassigning args below mutates `node.users`.
    for user in list(node.users):
        user.args = map_arg(user.args, swap)
        user.kwargs = map_arg(user.kwargs, swap)
    graph.erase_node(node)


def _filter_float_tensors(args: List[Any]) -> List[Node]:
    """Return the top-level entries of `args` that are nodes marked as outputting
    floating-point tensors."""
    return [
        a
        for a in args
        if isinstance(a, Node) and a.meta.get("outputs_float_tensor", False)
    ]
def track_scales(module: M) -> M:
    """Returns a version of the input module which tracks internal tensor metrics.

    Calling :code:`forward()` and :code:`backward()` on the resulting module
    records various metrics (such as scale) for each intermediate tensor used.
    They are exposed through an extra method, :code:`module.scales_graph()`,
    which is added to the module. It returns an instance of
    :external+pytorch:py:class:`torch.fx.Graph` in which every node representing
    a floating-point tensor carries a :code:`node.meta["metrics"]` object of type
    :class:`unit_scaling.transforms.Metrics`.

    Tensor metrics are not available unless :code:`forward()` or
    :code:`backward()` have been called.

    FX Graphs can be visualised with the :func:`unit_scaling.analysis.plot`
    function provided by the unit scaling library. The intended usage is:

    .. code-block:: python

        from unit_scaling.transforms import track_scales
        from unit_scaling.analysis import plot

        inpt = ...
        model = ...

        model = track_scales(model)
        loss = model(inpt)
        loss.backward()

        graph = model.scales_graph()
        plot(graph)

    Any :code:`inpt` tensor(s) passed to a model transformed by
    :code:`track_scales()` automatically have :code:`inpt.requires_grad_()` set
    (needed for full scale tracking in the backward pass), so the user need not
    do this themselves.

    :code:`track_scales()` may be combined with the other graph transforms
    provided, but must always be the last transform in the chain. E.g.

    .. code-block:: python

        from unit_scaling.transforms import simulate_fp8, track_scales, unit_scale

        model = track_scales(unit_scale(simulate_fp8(model)))

    The full FX graph produced by this transform can hold more information than
    is useful for analysis. The functions
    :func:`unit_scaling.transforms.prune_non_float_tensors` and
    :func:`unit_scaling.transforms.prune_same_scale_tensors` are provided for
    this reason; in practice they tend to reduce the graph to only key tensors.

    Args:
        module (M): the input module to be tracked.

    Returns:
        M: a new version of the input module which tracks tensor metrics when used.
    """
    scale_backend = ScaleTrackingBackend()
    tracked = apply_transform(module, scale_backend)
    # The backend's `graph` attribute is read on each call, not captured now,
    # so the accessor returns whichever graph the backend holds at that time.
    tracked.scales_graph = lambda: scale_backend.graph
    _make_input_tensors_require_grad(tracked)
    return tracked  # type: ignore[no-any-return]
def prune_non_float_tensors(graph: Graph) -> Graph:
    """Given an FX Graph, prunes all nodes which don't output floating-point tensors.

    The supplied graph must have been generated via the
    :code:`module.scales_graph()` method, called on a module with
    :func:`unit_scaling.transforms.track_scales` applied. This is necessary as
    the scale-tracking process is what identifies which tensors have
    floating-point values. E.g.

    .. code-block:: python

        from unit_scaling.transforms import track_scales, prune_non_float_tensors

        inpt = ...
        model = ...

        model = track_scales(model)
        loss = model(inpt)
        loss.backward()

        graph = model.scales_graph()
        pruned_graph = prune_non_float_tensors(graph)

    The input graph is left untouched; pruning is applied to a deep copy.

    Args:
        graph (Graph): the FX graph to be pruned.

    Returns:
        Graph: the pruned graph containing only nodes outputting floating-point
        tensors.

    Raises:
        RuntimeError: if the graph is empty or its first node carries no
            tracking metadata.
    """
    # Inspect only the first node; an empty graph cannot be a tracked graph.
    first_node = next(iter(graph.nodes), None)
    if first_node is None or "clean_name" not in first_node.meta:  # pragma: no cover
        raise RuntimeError(
            "supplied graph must be a result of running"
            " `unit_scaling.transforms.track_scales` (and extracting"
            " the `scales_graph`). Arbitrary fx graphs lack metadata"
            " about the types of node outputs so cannot be pruned."
        )
    graph = deepcopy(graph)
    for n in graph.nodes:
        # Identify the output node by op: its name is not guaranteed to be
        # exactly "output" if another node already claimed that name.
        if n.op == "output":
            continue
        if not n.meta.get("outputs_float_tensor", False):
            float_tensor_args = _filter_float_tensors(n.args)
            # Users are rewired to the sole float-tensor input if there is
            # exactly one; otherwise they receive None.
            a = float_tensor_args[0] if len(float_tensor_args) == 1 else None
            logger.info("pruning non-float node: %s", n)
            _prune(graph, n, replacement_arg=a)
    graph.lint()  # type: ignore[no-untyped-call]
    return graph
def prune_same_scale_tensors(graph: Graph, rtol: float = 2**-16) -> Graph:
    """Given an FX Graph, prunes all nodes with the same scale as the previous node.

    This is intended to remove non-informative nodes from the graph such as
    reshapes. Nodes with multiple floating-point tensors as inputs are never
    pruned.

    Certain operations (such as slices) may change the scale slightly, but
    negligibly. For this case a tolerance parameter is provided, which
    specifies the relative change that is deemed significant.

    The supplied graph must have been generated via the
    :code:`module.scales_graph()` method, called on a module with
    :func:`unit_scaling.transforms.track_scales` applied. E.g.

    .. code-block:: python

        from unit_scaling.transforms import track_scales, prune_same_scale_tensors

        inpt = ...
        model = ...

        model = track_scales(model)
        loss = model(inpt)
        loss.backward()

        graph = model.scales_graph()
        pruned_graph = prune_same_scale_tensors(graph)

    The input graph is left untouched; pruning is applied to a deep copy.

    Args:
        graph (Graph): the FX graph to be pruned.
        rtol (float, optional): the relative tolerance for "same scale".
            Defaults to 2**-16.

    Returns:
        Graph: the pruned graph with nodes that don't change their input scale
        removed.
    """
    graph = deepcopy(graph)
    for n in graph.nodes:
        # Identify the output node by op, not by name, so it is never
        # pruned even when another node already took the name "output".
        if n.op == "output" or not n.meta.get("outputs_float_tensor", False):
            continue
        float_tensor_args = _filter_float_tensors(n.args)
        if len(float_tensor_args) != 1:
            continue
        a = float_tensor_args[0]
        if _metrics_same_scale(n.meta["metrics"], a.meta["metrics"], rtol):
            logger.info("pruning same-scale node: %s", n)
            _prune(graph, n, replacement_arg=a)
    return graph
def prune_selected_nodes(graph: Graph, targets: Iterable[Target]) -> Graph:
    """Given an FX Graph, prunes all nodes with functions in the set of target nodes.

    Unlike the other pruning functions in this module, no copy is made: the
    supplied graph is modified in place and the same object is returned. No
    replacement is supplied when a node is pruned.

    Args:
        graph (Graph): the FX graph to prune.
        targets (Iterable[Target]): the functions which will not be represented
            by nodes once pruning is complete. Any iterable is accepted,
            including one-shot iterators and generators.

    Returns:
        Graph: the pruned graph.
    """
    # Materialise once: a membership test on a generator consumes it, so
    # only the first node examined would otherwise be checked properly.
    selected = tuple(targets)
    for n in graph.nodes:
        if n.target in selected:
            logger.info("pruning node: %s", n)
            _prune(graph, n)
    graph.lint()  # type: ignore[no-untyped-call]
    return graph