Copyright (c) 2023 Graphcore Ltd. All rights reserved.

Usage example

Create a toy model to track:

In [1]:
import torch
from torch import nn, Tensor

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.project = nn.Linear(4, 4)
        self.unembed = nn.Linear(4, 10)

    def forward(self, tokens: Tensor) -> Tensor:
        logits = self.unembed(self.project(self.embed(tokens)))
        return nn.functional.cross_entropy(logits, tokens)

# Fix the RNG seed so parameter initialisation and the token draw below
# are reproducible across runs.
torch.manual_seed(100)
module = Model()
# A batch of 3 random token ids in [0, 10) — used as both input and target.
inputs = torch.randint(0, 10, (3,))

Use `tensor_tracker` to capture forward-pass activations and backward-pass gradients from our toy model. By default, the tracker saves full tensors as a list of `tensor_tracker.Stash` objects.

In [2]:
import tensor_tracker

# Only calls made inside the `with` block are captured: the forward call
# records activations and `.backward()` records gradients.
with tensor_tracker.track(module) as tracker:
    module(inputs).backward()

# `tracking=0` below indicates the tracker detached its hooks on exit.
print(tracker)
Tracker(stashes=8, tracking=0)

Note that calls are only tracked within the `with` context. Afterwards, the tracker behaves like a list of `Stash` objects, each carrying attributes such as `name` and `value`.

In [3]:
# The tracker is iterable; each Stash records the module's name, its type,
# whether it was captured on the forward (grad=False) or backward
# (grad=True) pass, and the captured value.
display(list(tracker))
# => [Stash(name="embed", type=nn.Embedding, grad=False, value=tensor(...)),
#     ...]
[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698,  1.2426,  0.5403, -1.1454],
         [-0.8425, -0.6475, -0.2189, -1.1326],
         [ 0.1268,  1.3564,  0.5632, -0.1039]])),
 Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652,  0.3782, -0.8841],
         [-0.9278, -0.2848, -0.8688, -0.4719],
         [-0.3449,  0.3643,  0.3935, -0.6302]])),
 Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458,  1.0003, -0.8231, -0.1405, -0.2964,  0.5837,  0.2889,  0.2059,
          -0.6114, -0.5916],
         [-0.6345,  1.0882, -0.4304, -0.2196, -0.0426,  0.9428,  0.2051,  0.5897,
          -0.2217, -0.9132],
         [-0.0822,  0.9985, -0.7097, -0.3139, -0.4805,  0.6878,  0.2560,  0.3254,
          -0.4447, -0.3332]])),
 Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)),
 Stash(name='', type=<class '__main__.Model'>, grad=True, value=(tensor(1.),)),
 Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[ 0.0237,  0.0824, -0.3200,  0.0263,  0.0225,  0.0543,  0.0404,  0.0372,
           0.0164,  0.0168],
         [ 0.0139,  0.0779,  0.0171,  0.0211,  0.0251,  0.0673,  0.0322, -0.2860,
           0.0210,  0.0105],
         [-0.3066,  0.0787,  0.0143,  0.0212,  0.0179,  0.0577,  0.0374,  0.0401,
           0.0186,  0.0208]]),)),
 Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[-0.1755,  0.1306,  0.0443, -0.1823],
         [ 0.1202, -0.0728,  0.0066, -0.0839],
         [-0.1863,  0.0470, -0.1055, -0.0353]]),)),
 Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=True, value=(tensor([[-0.0108,  0.1086, -0.1304, -0.0370],
         [ 0.0534, -0.0029,  0.0078, -0.0074],
         [-0.0829,  0.0152, -0.1170, -0.0625]]),))]

As a higher-level API, to_frame computes summary statistics, defaulting to torch.std.

In [4]:
# Summarise each stash as one DataFrame row (name, type, grad, statistic).
display(tracker.to_frame())
name type grad std
0 embed torch.nn.modules.sparse.Embedding False 0.853265
1 project torch.nn.modules.linear.Linear False 0.494231
2 unembed torch.nn.modules.linear.Linear False 0.581503
3 __main__.Model False NaN
4 __main__.Model True NaN
5 unembed torch.nn.modules.linear.Linear True 0.105266
6 project torch.nn.modules.linear.Linear True 0.112392
7 embed torch.nn.modules.sparse.Embedding True 0.068816