3.1.3. unit_scaling.visualiser

unit_scaling.visualiser(model: Module, tokenizer: PreTrainedTokenizerBase, batch_size: int, seq_len: int, backward: bool = True, dataset_path: str = 'wikitext', dataset_name: str = 'wikitext-103-v1', **plot_kwargs: Any) → Axes

[Experimental] Generate a plot visualising the scales in the forward (and optionally backward) pass of all tensors in an arbitrary torch.nn.Module.

This is a convenience method which combines unit_scaling.analysis.example_batch(), unit_scaling.transforms.track_scales() and unit_scaling.analysis.plot().

Warning: this method is experimental and may not work for a wide range of models. It currently only supports models that use the following interface:

output, loss = model(inputs, labels)
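The sketch below shows a minimal model that satisfies this interface. TinyLM is a hypothetical toy model, and its layer sizes and cross-entropy loss are illustrative assumptions that are not part of unit_scaling. It returns its logits together with a scalar loss:

```python
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


class TinyLM(nn.Module):
    """Hypothetical toy language model with the interface
    output, loss = model(inputs, labels)."""

    def __init__(self, vocab_size: int, hidden_size: int = 64) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(
        self, inputs: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.embed(inputs)  # (batch, seq) token ids -> (batch, seq, hidden)
        x = x + self.mlp(x)  # residual MLP block
        logits = self.head(x)  # (batch, seq, vocab)
        # Flatten batch and sequence dims so each position is one classification.
        # Labels equal to -100 (cross_entropy's default ignore_index) are skipped.
        loss = F.cross_entropy(logits.flatten(end_dim=-2), labels.flatten())
        return logits, loss
```

Assuming unit_scaling and transformers are installed, an instance such as TinyLM(vocab_size=len(tokenizer)) could then be passed to unit_scaling.visualiser together with a typical Hugging Face tokenizer.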

Future versions will support standard Hugging Face interfaces. For now, we recommend that users whose models provide a different interface re-implement this method for their use case, based on the following template:

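from unit_scaling.analysis import example_batch, plot
from unit_scaling.transforms import track_scales

def my_visualiser(model, tokenizer, batch_size, seq_len, backward=True,
                  dataset_path="wikitext", dataset_name="wikitext-103-v1", **plot_kwargs):
    inputs, attn_mask, labels = example_batch(
        tokenizer, batch_size, seq_len, dataset_path, dataset_name
    )
    tracked_model = track_scales(model)
    # call tracked_model (not model) with (inputs, attn_mask, labels) so scales are recorded
    loss = ...  # replace with code that returns a scalar loss tensor
    if backward:
        loss.backward()  # also record scales in the backward pass
    graph = tracked_model.scales_graph()
    return plot(graph, **plot_kwargs)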
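An alternative to re-implementing the function is to wrap the model in a thin adapter that exposes the output, loss = model(inputs, labels) interface, so that the unmodified visualiser might be used directly. This sketch assumes the wrapped model's forward takes only token ids and returns logits, and that the labels are already aligned with the logits' positions. LogitsOnlyLM and OutputLossAdapter are hypothetical names. Note that this interface passes no attention mask to the model:

```python
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


class LogitsOnlyLM(nn.Module):
    """Hypothetical model with a different interface: logits = model(input_ids)."""

    def __init__(self, vocab_size: int, hidden_size: int = 32) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(input_ids))


class OutputLossAdapter(nn.Module):
    """Hypothetical adapter exposing output, loss = adapter(inputs, labels)."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model  # registered as a submodule, so its parameters belong to the adapter

    def forward(
        self, inputs: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.model(inputs)
        # Assumes labels are already aligned with the logits' positions.
        loss = F.cross_entropy(logits.flatten(end_dim=-2), labels.flatten())
        return logits, loss
```

With such a wrapper, unit_scaling.visualiser(OutputLossAdapter(model), tokenizer, batch_size, seq_len) could be called directly. The plotted graph would then also include the adapter's loss computation.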

Operations that don’t output floating-point tensors are automatically pruned from the visualised graph, as they are deemed unlikely to be relevant from the perspective of model numerics.
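The sketch below illustrates which tensors this criterion excludes by applying a simple floating-point test to a few typical intermediate results. The helper is_numerically_relevant is hypothetical and is not how unit_scaling implements pruning. Integer token ids, boolean masks and argmax indices fail the test, while embedding activations pass it:

```python
import torch
from torch import nn


def is_numerically_relevant(output: object) -> bool:
    """Hypothetical filter: keep an op only if it outputs a floating-point tensor."""
    return isinstance(output, torch.Tensor) and output.is_floating_point()


token_ids = torch.randint(0, 100, (2, 8))  # int64 token indices
embeddings = nn.Embedding(100, 16)(token_ids)  # float32 activations
mask = embeddings > 0  # bool comparison result
predictions = embeddings.argmax(dim=-1)  # int64 indices

intermediates = {
    "token_ids": token_ids,
    "embeddings": embeddings,
    "mask": mask,
    "predictions": predictions,
}
# Maps each intermediate name to whether it would be kept in the graph.
kept = {name: is_numerically_relevant(t) for name, t in intermediates.items()}
```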

On each row, a faint coloured horizontal line acts as an error bar, spanning the minimum and maximum values seen in that tensor during tracking.
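The idea of a per-tensor range that accumulates over tracking can be illustrated with ordinary PyTorch forward hooks. This is a swapped-in technique for illustration only, since unit_scaling's own tracking is performed by track_scales and may compute its statistics differently. The hypothetical track_min_max helper records, for each leaf module, the running minimum and maximum of its floating-point output over every forward call made while the hooks are attached:

```python
from typing import Dict, List, Tuple

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle


def track_min_max(
    model: nn.Module,
) -> Tuple[Dict[str, Tuple[float, float]], List[RemovableHandle]]:
    """Hypothetical helper: record the running (min, max) of each leaf module's
    floating-point output across all forward calls while the hooks are attached."""
    stats: Dict[str, Tuple[float, float]] = {}

    def make_hook(name: str):
        def hook(module: nn.Module, args: tuple, output: object) -> None:
            if isinstance(output, torch.Tensor) and output.is_floating_point():
                lo, hi = output.min().item(), output.max().item()
                if name in stats:  # widen the range seen in earlier calls
                    prev_lo, prev_hi = stats[name]
                    lo, hi = min(lo, prev_lo), max(hi, prev_hi)
                stats[name] = (lo, hi)

        return hook

    handles = [
        module.register_forward_hook(make_hook(name))
        for name, module in model.named_modules()
        if not list(module.children())  # leaf modules only
    ]
    return stats, handles
```

Returning the hook handles lets the caller stop tracking with handle.remove(), after which the recorded ranges stay fixed.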

Parameters:
  • model (nn.Module) – the model to visualise

  • tokenizer (PreTrainedTokenizerBase) – the tokenizer corresponding to the model

  • batch_size (int) – the batch size for the visualisation

  • seq_len (int) – the sequence length for the visualisation

  • backward (bool, optional) – if True, also visualise scales in the backward pass. Defaults to True.

  • dataset_path (str, optional) – Hugging Face path of the dataset to use for visualisation. Defaults to “wikitext”.

  • dataset_name (str, optional) – Hugging Face configuration name of the dataset to use for visualisation. Defaults to “wikitext-103-v1”.

  • plot_kwargs (Any) – keyword args passed to unit_scaling.analysis.plot().

Returns:

the axes representing the generated plot.

Return type:

matplotlib.axes.Axes
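Because the return value is a standard matplotlib Axes, it can be post-processed like any other. The save_axes helper below is a hypothetical sketch that adds a title and writes the parent figure to disk, using the non-interactive Agg backend on the assumption that it runs on a headless machine:

```python
from pathlib import Path

import matplotlib

matplotlib.use("Agg")  # non-interactive backend, e.g. for headless machines
from matplotlib.axes import Axes


def save_axes(ax: Axes, path: Path, title: str) -> Path:
    """Hypothetical helper: title the returned Axes and save its parent figure."""
    ax.set_title(title)
    ax.figure.savefig(path, bbox_inches="tight")
    return path
```

For example, save_axes(unit_scaling.visualiser(model, tokenizer, 4, 128), Path("scales.png"), "Forward and backward scales") could write the plot for a batch size of 4 and a sequence length of 128 to disk.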