In [1]:
import torch
from torch import nn, Tensor


class Model(nn.Module):
    """Toy model: embed tokens, project, then unembed to logits."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.project = nn.Linear(4, 4)
        self.unembed = nn.Linear(4, 10)

    def forward(self, tokens: Tensor) -> Tensor:
        # Loss for reconstructing the input tokens from their embeddings
        logits = self.unembed(self.project(self.embed(tokens)))
        return nn.functional.cross_entropy(logits, tokens)


torch.manual_seed(100)
module = Model()
inputs = torch.randint(0, 10, (3,))
Use `tensor_tracker` to capture forward-pass activations and backward-pass gradients from our toy model. By default, the tracker saves full tensors as a list of `tensor_tracker.Stash` objects.
In [2]:
import tensor_tracker

with tensor_tracker.track(module) as tracker:
    module(inputs).backward()

print(tracker)
Tracker(stashes=8, tracking=0)
Note that calls are only tracked within the `with` context; a quick check of this follows below. The tracker then behaves like a list of `Stash` objects, each with an attached `name`, `value`, etc.
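For example, calling the module again after the context has exited should leave the tracker unchanged (a minimal sketch; the expected repr assumes no further stashes are recorded once the context is closed):

# Outside the context, the tracker's hooks are no longer active
module(inputs).backward()
print(tracker)  # expected to remain: Tracker(stashes=8, tracking=0)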
In [3]:
display(list(tracker))
# => [Stash(name="embed", type=nn.Embedding, grad=False, value=tensor(...)),
# ...]
[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698, 1.2426, 0.5403, -1.1454],
[-0.8425, -0.6475, -0.2189, -1.1326],
[ 0.1268, 1.3564, 0.5632, -0.1039]])),
Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652, 0.3782, -0.8841],
[-0.9278, -0.2848, -0.8688, -0.4719],
[-0.3449, 0.3643, 0.3935, -0.6302]])),
Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458, 1.0003, -0.8231, -0.1405, -0.2964, 0.5837, 0.2889, 0.2059,
-0.6114, -0.5916],
[-0.6345, 1.0882, -0.4304, -0.2196, -0.0426, 0.9428, 0.2051, 0.5897,
-0.2217, -0.9132],
[-0.0822, 0.9985, -0.7097, -0.3139, -0.4805, 0.6878, 0.2560, 0.3254,
-0.4447, -0.3332]])),
Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)),
Stash(name='', type=<class '__main__.Model'>, grad=True, value=(tensor(1.),)),
Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[ 0.0237, 0.0824, -0.3200, 0.0263, 0.0225, 0.0543, 0.0404, 0.0372,
0.0164, 0.0168],
[ 0.0139, 0.0779, 0.0171, 0.0211, 0.0251, 0.0673, 0.0322, -0.2860,
0.0210, 0.0105],
[-0.3066, 0.0787, 0.0143, 0.0212, 0.0179, 0.0577, 0.0374, 0.0401,
0.0186, 0.0208]]),)),
Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[-0.1755, 0.1306, 0.0443, -0.1823],
[ 0.1202, -0.0728, 0.0066, -0.0839],
[-0.1863, 0.0470, -0.1055, -0.0353]]),)),
Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=True, value=(tensor([[-0.0108, 0.1086, -0.1304, -0.0370],
[ 0.0534, -0.0029, 0.0078, -0.0074],
[-0.0829, 0.0152, -0.1170, -0.0625]]),))]
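Since the tracker behaves like a list, ordinary indexing and iteration apply; a minimal sketch using only the attributes shown above (and assuming list-style indexing is supported):

print(tracker[0].name)  # "embed"
# Names of the backward-pass (grad=True) stashes, in capture order
print([stash.name for stash in tracker if stash.grad])
# expected: ['', 'unembed', 'project', 'embed']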
As a higher-level API, `to_frame` computes a summary statistic for each stash, defaulting to `torch.std`.
In [4]:
display(tracker.to_frame())
| | name | type | grad | std |
|---|---|---|---|---|
| 0 | embed | torch.nn.modules.sparse.Embedding | False | 0.853265 |
| 1 | project | torch.nn.modules.linear.Linear | False | 0.494231 |
| 2 | unembed | torch.nn.modules.linear.Linear | False | 0.581503 |
| 3 | | __main__.Model | False | NaN |
| 4 | | __main__.Model | True | NaN |
| 5 | unembed | torch.nn.modules.linear.Linear | True | 0.105266 |
| 6 | project | torch.nn.modules.linear.Linear | True | 0.112392 |
| 7 | embed | torch.nn.modules.sparse.Embedding | True | 0.068816 |
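The default statistic can presumably be overridden; a sketch, assuming `to_frame` exposes it as a `stat` argument (the parameter name is an assumption, inferred from the documented `torch.std` default):

# Assumption: the statistic is configurable; "stat" is a guessed keyword
display(tracker.to_frame(stat=torch.mean))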