3.1.3. unit_scaling.visualiser
- unit_scaling.visualiser(model: Module, tokenizer: PreTrainedTokenizerBase, batch_size: int, seq_len: int, backward: bool = True, dataset_path: str = 'wikitext', dataset_name: str = 'wikitext-103-v1', **plot_kwargs: Any) → Axes
[Experimental] Generate a plot visualising the scales of all tensors in the forward (and optionally backward) pass of an arbitrary torch.nn.Module. This is a convenience method which combines unit_scaling.analysis.example_batch(), unit_scaling.transforms.track_scales() and unit_scaling.analysis.plot().
Warning: this method is experimental and may not work for many models. It currently only supports models that use the following interface:
output, loss = model(inputs, labels)
Future versions will support standard Hugging Face interfaces. For now, we recommend that users whose models provide a different interface re-implement this method for their use case, based on the following template:
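from unit_scaling.analysis import example_batch, plot
from unit_scaling.transforms import track_scales

# "my_visualiser" is a hypothetical name for your own re-implementation
def my_visualiser(model, tokenizer, batch_size, seq_len, backward=True,
                  dataset_path="wikitext", dataset_name="wikitext-103-v1", **plot_kwargs):
    inputs, attn_mask, labels = example_batch(
        tokenizer, batch_size, seq_len, dataset_path, dataset_name
    )
    tracked_model = track_scales(model)
    loss = ...  # code to call tracked_model with (inputs, attn_mask, labels), returning loss
    if backward:
        loss.backward()
    graph = tracked_model.scales_graph()
    return plot(graph, **plot_kwargs)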
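As an illustration of what could go in the `loss = ...` slot, the sketch below assumes a model whose forward takes `(inputs, attn_mask)` and returns only logits of shape (batch_size, seq_len, vocab_size), and assumes that `labels` is already aligned position-by-position with those logits. `TinyLM` and `compute_loss` are hypothetical names, and the random tensors only stand in for the output of example_batch(), whose exact shapes are an assumption here. In a real re-implementation you would pass `tracked_model` as `model`.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyLM(nn.Module):
    """Hypothetical stand-in for a model whose forward returns only logits."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, vocab_size)

    def forward(self, inputs: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        # Zero the hidden states at masked-out positions, then project to the vocabulary.
        hidden = self.embed(inputs) * attn_mask.unsqueeze(-1)
        return self.proj(hidden)  # (batch_size, seq_len, vocab_size)


def compute_loss(
    model: nn.Module,
    inputs: torch.Tensor,
    attn_mask: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical helper for the `loss = ...` line of the template.

    Assumes `labels` is already aligned with the logits; labels equal to -100
    are skipped (the default `ignore_index` of F.cross_entropy).
    """
    logits = model(inputs, attn_mask)
    # cross_entropy expects (N, C) logits and (N,) targets, so merge batch and sequence dims.
    return F.cross_entropy(logits.flatten(0, 1), labels.flatten())


# Random stand-in data, assuming (batch_size, seq_len) integer tensors.
torch.manual_seed(0)
inputs = torch.randint(0, 100, (2, 8))
attn_mask = torch.ones(2, 8, dtype=torch.long)
labels = torch.randint(0, 100, (2, 8))

model = TinyLM()
loss = compute_loss(model, inputs, attn_mask, labels)
loss.backward()  # corresponds to the template's `if backward: loss.backward()`
```

Keeping the loss computation in a separate function means the rest of the template stays unchanged when switching between models that differ only in how they are called.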
Operations that don’t output floating-point tensors are automatically pruned from the visualised graph, as they are deemed unlikely to be relevant from the perspective of model numerics.
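The pruning rule amounts to a dtype check on each operation's output. The sketch below is only an illustration of that rule, not how unit_scaling prunes its graph internally: it assumes the outputs are collected in a plain dict keyed by a hypothetical operation name, and `keep_float_outputs` is a hypothetical helper built on the real `torch.Tensor.is_floating_point()` method.

```python
import torch


def keep_float_outputs(outputs: dict) -> dict:
    """Hypothetical helper: keep only entries whose value is a floating-point tensor."""
    return {
        name: value
        for name, value in outputs.items()
        if isinstance(value, torch.Tensor) and value.is_floating_point()
    }


# Stand-in outputs of a few typical language-model operations.
outputs = {
    "token_ids": torch.randint(0, 100, (2, 8)),  # int64 -> pruned
    "padding_mask": torch.ones(2, 8, dtype=torch.bool),  # bool -> pruned
    "embedding": torch.randn(2, 8, 16),  # float32 -> kept
    "predicted_ids": torch.randn(2, 8, 100).argmax(dim=-1),  # int64 -> pruned
    "logits_fp16": torch.randn(2, 8, 100).half(),  # float16 -> kept
}
kept = keep_float_outputs(outputs)
print(sorted(kept))  # ['embedding', 'logits_fp16']
```

Token indices, boolean masks and argmax results are dropped, while both float32 and float16 activations remain, since all of them are floating-point dtypes.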
Faint coloured horizontal lines for each row represent error bars indicating the maximum and minimum values seen in each tensor during tracking.
- Parameters:
model (nn.Module) – the model to visualise
tokenizer (PreTrainedTokenizerBase) – the tokenizer corresponding to the model.
batch_size (int) – the batch size for the visualisation
seq_len (int) – the sequence length for the visualisation
backward (bool, optional) – visualise scales in the backward pass. Defaults to True.
dataset_path (str, optional) – Hugging Face path of the dataset to use for visualisation. Defaults to “wikitext”.
dataset_name (str, optional) – Hugging Face name of the dataset to use for visualisation. Defaults to “wikitext-103-v1”.
plot_kwargs (Any) – keyword args passed to unit_scaling.analysis.plot().
- Returns:
the axes representing the generated plot.
- Return type:
matplotlib.axes.Axes
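Because the return value is an ordinary matplotlib Axes, it can be post-processed with standard matplotlib calls. The sketch below shows one way to title the plot and write it to disk; `save_axes` is a hypothetical helper, and a stand-in Axes replaces the real visualiser() output so the snippet runs without a model, tokenizer or dataset.

```python
from pathlib import Path
from typing import Optional

import matplotlib

matplotlib.use("Agg")  # non-interactive backend: no display needed
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def save_axes(ax: Axes, path: Path, title: Optional[str] = None) -> Path:
    """Hypothetical helper: optionally set a title, then write the Axes' figure to `path`."""
    if title is not None:
        ax.set_title(title)
    fig = ax.get_figure()
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)  # free the figure once it has been written
    return path


# With unit_scaling installed, `ax` would instead come from a call such as
#   ax = unit_scaling.visualiser(model, tokenizer, batch_size=4, seq_len=128)
# A stand-in Axes is used here so the snippet is self-contained.
_, ax = plt.subplots()
ax.plot([0, 1, 2], [1.0, 0.5, 2.0])
out = save_axes(ax, Path("scales.png"), title="Tensor scales")
```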