3.2. unit_scaling.analysis

Tools for analysing scale (and other metrics) within PyTorch models.

Functions

example_batch(tokenizer, batch_size, seq_len)

Generates a batch of token IDs from a given dataset, along with an attention mask and labels (which are simply the shifted token IDs).
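The three returned pieces can be sketched in plain Python. The helper make_example_batch below is hypothetical and is not the library's implementation. It assumes the token IDs have already been produced by a tokenizer. It splits them into batch_size rows of seq_len tokens and marks padding in the attention mask. It builds labels by shifting each row left by one token. Padded label positions use -100, the default ignore_index of torch.nn.CrossEntropyLoss. Whether example_batch pads in the same way is an assumption.

```python
from typing import List, Tuple

Rows = List[List[int]]


def make_example_batch(
    token_ids: List[int], batch_size: int, seq_len: int, pad_id: int = 0
) -> Tuple[Rows, Rows, Rows]:
    """Hypothetical sketch: split a token stream into (input_ids, attention_mask, labels)."""
    ignore_label = -100  # default ignore_index of torch.nn.CrossEntropyLoss (assumed convention)
    input_ids, attention_mask, labels = [], [], []
    for row in range(batch_size):
        start = row * seq_len
        # Take one extra token so the last input position has a next-token label.
        chunk = token_ids[start : start + seq_len + 1]
        inputs, targets = chunk[:seq_len], chunk[1:]
        n_pad = seq_len - len(inputs)
        input_ids.append(inputs + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding.
        attention_mask.append([1] * len(inputs) + [0] * n_pad)
        # Labels are the inputs shifted left by one; positions with no target are ignored.
        labels.append(targets + [ignore_label] * (seq_len - len(targets)))
    return input_ids, attention_mask, labels
```

For example, ten tokens 0..9 with batch_size=2 and seq_len=4 give input rows [0, 1, 2, 3] and [4, 5, 6, 7]. The matching labels are [1, 2, 3, 4] and [5, 6, 7, 8].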

graph_to_dataframe(g)

Converts a torch.fx.Graph whose nodes are annotated with unit_scaling.transforms.Metrics into a pandas.DataFrame.
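Conceptually, this conversion flattens per-node metric annotations into one table row per tensor. The sketch below does this in plain Python under several assumptions. ToyMetrics, toy_metrics and annotations_to_records are hypothetical names. The metric fields are illustrative and may not match the fields of unit_scaling.transforms.Metrics: mean absolute value, absolute maximum, RMS and element count. Each node is also assumed to carry separate forward ("fwd") and backward ("bwd") tensor values.

```python
import math
from dataclasses import asdict, dataclass
from typing import Dict, List, Sequence


@dataclass
class ToyMetrics:
    # Hypothetical fields; the real unit_scaling.transforms.Metrics may differ.
    mean_abs: float
    abs_max: float
    rms: float
    numel: int


def toy_metrics(values: Sequence[float]) -> ToyMetrics:
    """Summarise one tensor's values (given as a flat sequence of floats)."""
    n = len(values)
    return ToyMetrics(
        mean_abs=sum(abs(v) for v in values) / n,
        abs_max=max(abs(v) for v in values),
        rms=math.sqrt(sum(v * v for v in values) / n),
        numel=n,
    )


def annotations_to_records(
    annotated: Dict[str, Dict[str, Sequence[float]]]
) -> List[dict]:
    """Flatten {node: {"fwd": values, "bwd": values}} into one row per (node, direction)."""
    records = []
    for node_name, per_direction in annotated.items():
        for direction, values in per_direction.items():
            row = {"node": node_name, "direction": direction}
            row.update(asdict(toy_metrics(values)))
            records.append(row)
    return records
```

Passing the resulting list of dicts to pandas.DataFrame or pandas.DataFrame.from_records gives one row per (node, direction) pair, with one column per metric.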

plot(g[, title, metric, prune_same_scale, ...])

Generates a matplotlib plot visualising the scales in the forward (and optionally backward) pass of all tensors in an FX graph.
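Tensor scales can span many orders of magnitude, so one natural way to visualise them is on a log2 axis. The sketch below assumes that approach. It uses a text rendering in place of matplotlib so it runs with the standard library alone. log2_scales and text_plot are hypothetical helpers, and the metric argument only mirrors the name of plot's metric parameter. Records are assumed to be dicts with "node" and "direction" keys plus one key per metric.

```python
import math
from typing import Dict, List, Tuple, Union

Record = Dict[str, Union[str, float]]


def log2_scales(records: List[Record], metric: str = "rms") -> List[Tuple[str, float]]:
    """Return (label, log2 of the chosen metric) per record, skipping all-zero tensors."""
    points = []
    for row in records:
        value = float(row[metric])
        if value > 0:  # log2(0) is undefined, so zero tensors are left out
            points.append((f"{row['node']} ({row['direction']})", math.log2(value)))
    return points


def text_plot(points: List[Tuple[str, float]], max_bar: int = 20) -> str:
    """Crude horizontal bars: '>' for scales above 1, '<' below, 2 chars per factor of 2."""
    lines = []
    for label, x in points:
        bar = (">" if x > 0 else "<") * min(max_bar, round(abs(x) * 2))
        lines.append(f"{label:>12} {x:+6.2f} {bar}")
    return "\n".join(lines)
```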

visualiser(model, tokenizer, batch_size, seq_len)

[Experimental] Generates a plot visualising the scales in the forward (and optionally backward) pass of all tensors in an arbitrary torch.nn.Module.
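To make "scales in the forward and backward pass" concrete, the sketch below measures them by hand for a single linear layer y = x @ w. It uses plain Python lists instead of torch autograd, and linear_scales is a hypothetical helper. It assumes RMS (root mean square) as the scale metric, which may not be visualiser's default. It also assumes unit-variance Gaussian inputs, weights and incoming gradient. Under these assumptions, the output scale grows roughly as sqrt(fan_in), the input gradient as sqrt(fan_out) and the weight gradient as sqrt(batch). A scale plot is meant to expose this kind of drift.

```python
import math
import random
from typing import Dict, List

Matrix = List[List[float]]


def rms(m: Matrix) -> float:
    """Root mean square over all elements of a matrix."""
    values = [v for row in m for v in row]
    return math.sqrt(sum(v * v for v in values) / len(values))


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    b_cols = list(zip(*b))
    return [[sum(p * q for p, q in zip(row, col)) for col in b_cols] for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def linear_scales(batch: int, fan_in: int, fan_out: int, seed: int = 0) -> Dict[str, float]:
    """RMS of every tensor in the forward and backward pass of y = x @ w."""
    rng = random.Random(seed)

    def gauss(rows: int, cols: int) -> Matrix:
        return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

    x, w = gauss(batch, fan_in), gauss(fan_in, fan_out)
    y = matmul(x, w)                       # forward: (batch, fan_out)
    grad_y = gauss(batch, fan_out)         # assumed unit-scale gradient from downstream
    grad_x = matmul(grad_y, transpose(w))  # backward: (batch, fan_in)
    grad_w = matmul(transpose(x), grad_y)  # backward: (fan_in, fan_out)
    tensors = {"x": x, "w": w, "y": y, "grad_y": grad_y, "grad_x": grad_x, "grad_w": grad_w}
    return {name: rms(t) for name, t in tensors.items()}
```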