# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

"""Tools for analysing scale (and other metrics) within PyTorch models."""

import colorsys
import logging
import re
from math import isnan
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

import matplotlib
import matplotlib.colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns  # type: ignore[import-untyped]
from datasets import load_dataset  # type: ignore[import-untyped]
from torch import Tensor, nn
from torch.fx.graph import Graph
from torch.fx.node import Node

from ._internal_utils import generate__all__
from .transforms import (
    Metrics,
    prune_non_float_tensors,
    prune_same_scale_tensors,
    track_scales,
)

if TYPE_CHECKING:  # pragma: no cover
    from transformers.tokenization_utils_base import (  # type: ignore
        PreTrainedTokenizerBase,
    )

logger = logging.getLogger(__name__)


def _example_seqs(
    batch_size: int,
    min_seq_len: int,
    dataset_path: str = "wikitext",
    dataset_name: str = "wikitext-103-v1",
    shuffle_buffer_size: int = 10_000,
    seed: int = 1472,
) -> List[str]:
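    """Draw ``batch_size`` example strings from a streamed, shuffled test split of
    a Hugging Face dataset, keeping only texts longer than ``min_seq_len``
    characters."""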
    dataset = load_dataset(dataset_path, dataset_name, split="test", streaming=True)
    shuffled_dataset = dataset.shuffle(seed=seed, buffer_size=shuffle_buffer_size)
    filtered_dataset = shuffled_dataset.filter(lambda d: len(d["text"]) > min_seq_len)
    batch = filtered_dataset.take(batch_size)
    return [d["text"] for d in batch]


def _create_batch(
    tokenizer: "PreTrainedTokenizerBase",
    seqs: List[str],
    seq_len: int,
) -> Tuple[Tensor, Tensor, Tensor]:
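    """Tokenize ``seqs`` to at most ``seq_len + 1`` IDs per sequence and return
    ``(input_idxs, attn_mask, labels)``, where the labels are the input IDs
    shifted one position to the left.

    For illustration (hypothetical token IDs): a sequence tokenizing to
    ``[5, 7, 9, 11]`` with ``seq_len == 3`` gives ``input_idxs == [5, 7, 9]``
    and ``labels == [7, 9, 11]``.
    """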
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    out = tokenizer(
        seqs, max_length=seq_len + 1, truncation=True, return_tensors="pt", padding=True
    )
    input_idxs = out["input_ids"][:, :seq_len].clone()
    attn_mask = out["attention_mask"][:, :seq_len].clone()
    labels = out["input_ids"][:, 1 : seq_len + 1].clone()
    return input_idxs, attn_mask, labels


def example_batch(
    tokenizer: "PreTrainedTokenizerBase",
    batch_size: int,
    seq_len: int,
    dataset_path: str = "wikitext",
    dataset_name: str = "wikitext-103-v1",
    shuffle_buffer_size: int = 10_000,
    seed: int = 1472,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Generates a batch of token IDs from a given dataset, along with an attention mask
and labels (just the shifted token IDs).
Args:
tokenizer (PreTrainedTokenizerBase): the tokenizer applied to the text data.
batch_size (int): the batch size of the returned tensor.
seq_len (int): the sequence length (number of IDs) of the returned tensor.
dataset_path (str, optional): huggingface path of the dataset to use for
visualisation. Defaults to "wikitext".
dataset_name (str, optional): huggingface name of the dataset to use for
visualisation. Defaults to "wikitext-103-v1".
shuffle_buffer_size (int, optional): the tokenized data is a random sample from
a chunk of the full dataset. This determines the chunk size.
Defaults to 10_000.
seed (int, optional): shuffle seed. Defaults to 1472.
Returns:
Tuple[Tensor]: a tuple of (input_idxs, attn_mask, labels)
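
    Example (a sketch; assumes the Hugging Face ``transformers`` library and the
    GPT-2 tokenizer):

    .. code-block:: python

        from transformers import AutoTokenizer
        from unit_scaling.analysis import example_batch

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        input_idxs, attn_mask, labels = example_batch(
            tokenizer, batch_size=4, seq_len=128
        )
        # Labels are the input IDs shifted one position to the left:
        assert (input_idxs[:, 1:] == labels[:, :-1]).all()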
"""
    seqs = _example_seqs(
        batch_size, seq_len * 4, dataset_path, dataset_name, shuffle_buffer_size, seed
    )
    return _create_batch(tokenizer, seqs, seq_len)


def graph_to_dataframe(g: Graph) -> pd.DataFrame:
"""Converts a :class:`torch.fx.Graph` with annotated
:class:`unit_scaling.transforms.Metrics` into a :class:`pandas.DataFrame`.
This graph is indended to have been generated by applying
:func:`unit_scaling.transforms.track_scales` to an arbitrary
:class:`torch.nn.Module`, running a forward (and possibly backward) pass,
then calling the `module.scales_graph()` function.
The resulting dataframe contains all the metrics information for the module,
and is used internally by the :func:`unit_scaling.analysis.plot` function.
Args:
g (Graph): the input graph.
Returns:
pd.DataFrame: the metrics dataframe.
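
    Example (a sketch of the intended pipeline; ``inpt`` and ``model`` are
    placeholders):

    .. code-block:: python

        from unit_scaling.transforms import track_scales
        from unit_scaling.analysis import graph_to_dataframe

        inpt = ...
        model = ...

        model = track_scales(model)
        loss = model(inpt)
        loss.backward()

        graph = model.scales_graph()
        df = graph_to_dataframe(graph)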
"""
    columns = [
        "layer",
        "weight tensor",
        "direction",
        "tensor type",
    ] + Metrics.full_names()
    data = []
    for n in g.nodes:
        # 'output' has to be kept from previous stages to keep fx happy. We drop it
        # here.
        if n.name == "output":
            continue
        for direction in ["fwd", "bwd"]:
            tensor_type_prefix = "" if direction == "fwd" else "grad_"
            tensor_type_suffix = "w" if n.meta["requires_grad"] else "x"
            row_data = [
                n.meta["clean_name"],
                n.meta["requires_grad"],
                direction,
                tensor_type_prefix + tensor_type_suffix,
            ]
            directional_metrics = getattr(n.meta["metrics"], direction, None)
            for m in Metrics.names():
                if directional_metrics is not None:
                    v = getattr(directional_metrics, m)
                else:
                    v = None  # pragma: no cover
                row_data.append(v)
            data.append(row_data)
    return pd.DataFrame(data, columns=columns)


def plot(
    g: Graph,
    title: str = "",
    metric: str = "mean_abs",
    prune_same_scale: bool = True,
    show_arrows: bool = True,
    show_error_bars: bool = True,
    show_zero_tensors: bool = False,
    xmin: Optional[float] = None,
    xmax: Optional[float] = None,
) -> matplotlib.axes.Axes:
"""Generate a :mod:`matplotlib` plot visualising the scales in the forward (and
optionally backward) pass of all tensors in an FX graph.
The input graph is intended to have been generated by applying
:func:`unit_scaling.transforms.track_scales` to an arbitrary
:class:`torch.nn.Module`, running a forward (and possibly backward) pass,
then calling the `module.scales_graph()` function:
.. code-block:: python
from unit_scaling.transforms import track_scales
from unit_scaling.analysis import plot
inpt = ...
model = ...
model = track_scales(model)
loss = model(inpt)
loss.backward()
graph = model.scales_graph()
plot(graph)
Operations that don't output floating-point tensors are automatically pruned from
the visualised graph, as they are deemed unlikely to be relevant from the
perspective of model numerics.
Faint coloured horizontal lines for each row represent error bars indicating
the maximum and minimum values seen in each tensor during tracking.
Args:
g (Graph): the graph to visualise.
title (str, optional): title for the generated plot. Defaults to "".
metric (str, optional): the metric to show on the x-axis. Can be any of:
("mean_abs", "abs_mean", "std", "abs_max", "abs_min", "numel").
Defaults to "mean_abs".
prune_same_scale (bool, optional): prune operations that don't change the scale
of their input tensors. In practice this means that views / reshapes are not
shown, making the resulting visualisation clearer. Defaults to True.
show_arrows (bool, optional): show arrows between operations,
denoting dependencies. Defaults to True.
show_error_bars (bool, optional): show max/min error bars. Defaults to True.
xmin (Optional[float], optional): the minimum x-value to display.
Defaults to None.
xmax (Optional[float], optional): the maximum x-value to display.
Defaults to None.
Returns:
matplotlib.axes.Axes: the axes representing the generated plot.
"""
    assert metric in Metrics.names() + Metrics.full_names(), (
        f"metric '{metric}' must be one of {Metrics.names()} (these correspond to"
        f" {Metrics.full_names()})"
    )
    full_metric = Metrics.get_full_name(metric)

    g = prune_non_float_tensors(g)
    if prune_same_scale:
        g = prune_same_scale_tensors(g)
    df = graph_to_dataframe(g)

    plot_height = len(df["layer"].unique())
    plt.figure(figsize=(10, plot_height / 4))
    colors = sns.color_palette("colorblind")
    sns.set_palette(colors)
    sns.set_theme()

    p = sns.lineplot(
        data=df,
        x=full_metric,
        y="layer",
        hue="direction",
        hue_order=["fwd", "bwd"],
        style="weight tensor",
        style_order=[False, True],
        dashes=[(0, 1), (0, 1)],
        markers=[".", "v"],
        markersize=9,
        estimator=None,
        orient="y",
    )
    p.set_ylim(plot_height, -1)
    plt.xscale("log", base=2)
    p.xaxis.set_ticks_position("top")
    p.xaxis.set_label_position("top")
    p.xaxis.grid(False)
    if title:
        p.set_title(title, fontweight="bold")

    label_map = {
        "fwd": "forward pass",
        "bwd": "backward pass",
        "False": "non-weight tensor",
        "True": "weight tensor",
    }
    new_legend_labels = {
        label_map[l]: h
        for h, l in zip(*p.get_legend_handles_labels())
        if l in label_map
    }
    p.legend(
        new_legend_labels.values(), new_legend_labels.keys(), loc="upper right"
    ).set_title("")

    def _rename(s: str) -> str:
        # Shorten y-axis labels: drop numeric suffixes (e.g. "add_1" -> "add") and
        # common module-name prefixes.
        s = re.sub(r"(^|_)\d+", "", s)
        s = s.replace("self_", "")
        s = s.replace("transformer_h_", "")
        s = s.replace("transformer_", "")
        return s

    # Reference lines marking floating-point format boundaries.
    plt.axvline(2**-14, color="grey", dashes=(3, 1))
    plt.axvline(2**-7, color="grey", dashes=(1, 3))
    plt.axvline(240, color="grey", dashes=(1, 3))
    plt.axvline(2**16, color="grey", dashes=(3, 1))
    plt.text(
        2**-14,
        plot_height + 0.2,
        "FP16 min,\nFP8 E5 min\n(normal)",
        ha="center",
        va="top",
        size=9,
    )
    plt.text(
        2**-7,
        plot_height + 0.2,
        "FP8 E4 min\n(normal)",
        ha="center",
        va="top",
        size=9,
    )
    plt.text(
        240,
        plot_height + 0.2,
        "FP8 E4 max",
        ha="center",
        va="top",
        size=9,
    )
    plt.text(
        2**16,
        plot_height + 0.2,
        "FP16 max,\nFP8 E5 max",
        ha="center",
        va="top",
        size=9,
    )

    # Cycle through the graph's nodes and give each an index (for the y-axis)
    i = 0
    node_idxs = {}
    for node in g.nodes:
        if node.name != "output":
            name = node.meta["clean_name"]
            if name not in node_idxs:
                node_idxs[name] = i
                i += 1

    min_scale, max_scale = plt.gca().get_xlim()
    if xmin is not None:
        min_scale = xmin
    if xmax is not None:
        max_scale = xmax

    def lighten_color(
        color: Tuple[float, float, float], l_degree: float, s_degree: float
    ) -> Tuple[float, float, float]:
        # Fade a colour towards white by scaling its HLS lightness and saturation.
        r, g, b = matplotlib.colors.to_rgb(color)
        h, l, s = colorsys.rgb_to_hls(r, g, b)
        new_l = 1 - l_degree * (1 - l)
        new_s = s_degree * s
        return colorsys.hls_to_rgb(h, new_l, new_s)

    light_colors = [lighten_color(c, l_degree=0.35, s_degree=0.45) for c in colors]

    def draw_error_bar(node: Node, direction: str) -> None:
        metrics = node.meta["metrics"]
        if direction == "bwd" and metrics.bwd is None:  # pragma: no cover
            return
        directional_metrics = getattr(metrics, direction)
        x1, x2 = directional_metrics.abs_min, directional_metrics.abs_max
        y = node_idxs[node.meta["clean_name"]] + (-0.1 if direction == "fwd" else 0.1)
        color = light_colors[0 if direction == "fwd" else 1]
        plt.plot(
            [x1, x2],
            [y, y],
            color=color,
            linestyle="-",
            linewidth=1,
            marker="",
            zorder=1,
        )
        for x in [x1, x2]:
            plt.plot(
                [x, x],
                [y - 0.2, y + 0.2],
                color=color,
                linestyle="-",
                linewidth=1,
                marker="",
                zorder=1,
            )
        plt.gca().set_xlim(min_scale, max_scale)

    def draw_arrow(node_a: Node, node_b: Node, direction: str) -> None:
        a_metrics = node_a.meta["metrics"]
        b_metrics = node_b.meta["metrics"]
        if direction == "bwd" and (  # pragma: no cover
            a_metrics.bwd is None or b_metrics.bwd is None
        ):
            return  # pragma: no cover
        a_x = getattr(getattr(a_metrics, direction), metric)
        b_x = getattr(getattr(b_metrics, direction), metric)
        a_y = node_idxs[node_a.meta["clean_name"]]
        b_y = node_idxs[node_b.meta["clean_name"]]
        annotation = ""
        if isnan(a_x):  # pragma: no cover
            logger.warning(f"Node '{node_a.meta['clean_name']}' is NaN. Plotting as 0")
            a_x = min_scale
        elif a_x == 0:  # pragma: no cover
            a_x = min_scale
        if isnan(b_x):  # pragma: no cover
            logger.warning(f"Node '{node_b.meta['clean_name']}' is NaN. Plotting as 0")
            b_x = min_scale
            annotation = "0"
        elif b_x == 0:  # pragma: no cover
            b_x = min_scale
            annotation = "0"
        if direction == "fwd":
            color = colors[0]
        else:
            assert direction == "bwd", direction
            color = colors[1]
            # Backward-pass arrows point the opposite way, so swap ends.
            a_x, a_y, b_x, b_y = b_x, b_y, a_x, a_y
        if annotation == "0" and not show_zero_tensors:
            return
        plt.annotate(
            annotation,
            color=color,
            va="center",
            xy=(a_x, a_y),
            xytext=(b_x, b_y),
            arrowprops=dict(arrowstyle="->", color=color),
        )

    if show_arrows:
        for n in g.nodes:
            if n.name != "output":
                for direction in ["fwd", "bwd"]:
                    for arg in n.args:
                        if isinstance(arg, Node):
                            draw_arrow(n, arg, direction)
    if show_error_bars:
        for n in g.nodes:
            if n.name != "output":
                for direction in ["fwd", "bwd"]:
                    draw_error_bar(n, direction)

    p.set_yticks(p.get_yticks())
    p.set_yticklabels([_rename(item.get_text()) for item in p.get_yticklabels()])
    return p  # type: ignore[no-any-return]
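
# A typical customised call might look like the following (a sketch: `graph` is
# assumed to have been produced via `track_scales` + `scales_graph()`, as in the
# docstring above):
#
#     ax = plot(graph, title="Scales", metric="std", xmin=2**-20, xmax=2**20)
#     ax.figure.savefig("scales.png", bbox_inches="tight")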


def visualiser(
    model: nn.Module,
    tokenizer: "PreTrainedTokenizerBase",
    batch_size: int,
    seq_len: int,
    backward: bool = True,
    dataset_path: str = "wikitext",
    dataset_name: str = "wikitext-103-v1",
    **plot_kwargs: Any,
) -> matplotlib.axes.Axes:
"""[Experimental] Generate a plot visualising the scales in the forward (and
optionally backward) pass of all tensors in an arbitrary :class:`torch.nn.Module`.
This is a convenience method which combines
:func:`unit_scaling.analysis.example_batch`,
:func:`unit_scaling.transforms.track_scales` and
:func:`unit_scaling.analysis.plot`.
Warning: this method is experimental and may not work for a wide range of
models. It currently only supports models that use the following interface:
.. code-block:: python
output, loss = model(inputs, labels)
Future versions will support standard huggingface interfaces. For now we recommend
users with models providing different interfaces to re-implement this method for
their use case, based on the following template:
.. code-block:: python
inputs, attn_mask, labels = example_batch(
tokenizer, batch_size, seq_len, dataset_path, dataset_name
)
tracked_model = track_scales(model)
loss = ... # code to call model with (inputs, attn_mask, labels), returning loss
if backward:
loss.backward()
graph = tracked_model.scales_graph()
return plot(graph, **plot_kwargs)
Operations that don't output floating-point tensors are automatically pruned from
the visualised graph, as they are deemed unlikely to be relevant from the
perspective of model numerics.
Faint coloured horizontal lines for each row represent error bars indicating
the maximum and minimum values seen in each tensor during tracking.
Args:
model (nn.Module): the model to visualise
tokenizer (PreTrainedTokenizerBase): the tokenizer corresponding to the model.
batch_size (int): the batch size for the visualisation
seq_len (int): the sequence length for the visualisation
backward (bool, optional): visualise scales in the backward pass.
Defaults to True.
dataset_path (str, optional): huggingface path of the dataset to use for
visualisation. Defaults to "wikitext".
dataset_name (str, optional): huggingface name of the dataset to use for
visualisation. Defaults to "wikitext-103-v1".
plot_kwargs (Any): keyword args passed to :func:`unit_scaling.analysis.plot`.
Returns:
matplotlib.axes.Axes: the axes representing the generated plot.
"""
    inputs, attn_mask, labels = example_batch(
        tokenizer, batch_size, seq_len, dataset_path, dataset_name
    )
    tracked_model = track_scales(model.to("cpu"))
    _, loss = tracked_model(inputs, labels)
    if backward:
        loss.backward()
    graph = tracked_model.scales_graph()
    return plot(graph, **plot_kwargs)


__all__ = generate__all__(__name__)