In [1]:
import torch
from torch import nn, Tensor
class Model(nn.Module):
def __init__(self):
super().__init__()
self.embed = nn.Embedding(10, 4)
self.project = nn.Linear(4, 4)
self.unembed = nn.Linear(4, 10)
def forward(self, tokens: Tensor) -> Tensor:
logits = self.unembed(self.project(self.embed(tokens)))
return nn.functional.cross_entropy(logits, tokens)
# Fix the RNG so the random weights and the token draw below are reproducible.
torch.manual_seed(100)
module = Model()
# A batch of 3 random token ids in [0, 10) — within the Embedding(10, 4) vocab.
inputs = torch.randint(0, 10, (3,))
Use `tensor_tracker` to capture forward-pass activations and backward-pass gradients from our toy model. By default, the tracker saves full tensors, as a list of `tensor_tracker.Stash` objects.
In [2]:
import tensor_tracker

# Run one forward + backward pass under the tracker; stashes are captured
# only while the `with` context is active.
with tensor_tracker.track(module) as tracker:
    module(inputs).backward()

print(tracker)
Tracker(stashes=8, tracking=0)
Note that calls are only tracked within the `with` context. Afterwards, the tracker behaves like a list of `Stash` objects, with attached `name`, `value`, etc.
In [3]:
# Show the captured stashes: forward activations (grad=False) first,
# then backward gradients (grad=True), as seen in the output below.
display(list(tracker))
# => [Stash(name="embed", type=nn.Embedding, grad=False, value=tensor(...)),
#     ...]
[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698, 1.2426, 0.5403, -1.1454], [-0.8425, -0.6475, -0.2189, -1.1326], [ 0.1268, 1.3564, 0.5632, -0.1039]])), Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652, 0.3782, -0.8841], [-0.9278, -0.2848, -0.8688, -0.4719], [-0.3449, 0.3643, 0.3935, -0.6302]])), Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458, 1.0003, -0.8231, -0.1405, -0.2964, 0.5837, 0.2889, 0.2059, -0.6114, -0.5916], [-0.6345, 1.0882, -0.4304, -0.2196, -0.0426, 0.9428, 0.2051, 0.5897, -0.2217, -0.9132], [-0.0822, 0.9985, -0.7097, -0.3139, -0.4805, 0.6878, 0.2560, 0.3254, -0.4447, -0.3332]])), Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)), Stash(name='', type=<class '__main__.Model'>, grad=True, value=(tensor(1.),)), Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[ 0.0237, 0.0824, -0.3200, 0.0263, 0.0225, 0.0543, 0.0404, 0.0372, 0.0164, 0.0168], [ 0.0139, 0.0779, 0.0171, 0.0211, 0.0251, 0.0673, 0.0322, -0.2860, 0.0210, 0.0105], [-0.3066, 0.0787, 0.0143, 0.0212, 0.0179, 0.0577, 0.0374, 0.0401, 0.0186, 0.0208]]),)), Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[-0.1755, 0.1306, 0.0443, -0.1823], [ 0.1202, -0.0728, 0.0066, -0.0839], [-0.1863, 0.0470, -0.1055, -0.0353]]),)), Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=True, value=(tensor([[-0.0108, 0.1086, -0.1304, -0.0370], [ 0.0534, -0.0029, 0.0078, -0.0074], [-0.0829, 0.0152, -0.1170, -0.0625]]),))]
As a higher-level API, `to_frame` computes summary statistics, defaulting to `torch.std`.
In [4]:
display(tracker.to_frame())
name | type | grad | std | |
---|---|---|---|---|
0 | embed | torch.nn.modules.sparse.Embedding | False | 0.853265 |
1 | project | torch.nn.modules.linear.Linear | False | 0.494231 |
2 | unembed | torch.nn.modules.linear.Linear | False | 0.581503 |
3 | __main__.Model | False | NaN | |
4 | __main__.Model | True | NaN | |
5 | unembed | torch.nn.modules.linear.Linear | True | 0.105266 |
6 | project | torch.nn.modules.linear.Linear | True | 0.112392 |
7 | embed | torch.nn.modules.sparse.Embedding | True | 0.068816 |