Module tensor_tracker.core

Utility for tracking activations and gradients at nn.Module outputs.

Use track() to start tracking a module & submodules. Then use the original
module as usual. Your Tracker will be filled with a list of Stashes, containing
copies of fwd/bwd tensors at (sub)module outputs. (Beware, this can consume
a lot of memory.)

Usage (notebook):

    with tensor_tracker.track(model) as tracker:
        model(inputs).backward()

    print(list(tracker))
    # => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
    #     ...]

    display(tracker.to_frame())  # requires 'pandas'

Advanced usage (a combined sketch follows below):

- Filter modules based on name: track(include="<regex>", exclude="<regex>")
- Pre-transform tracked tensors to save memory: track(stash_value=lambda t: t.std().detach().cpu())
- Customise tracked state: track(stash=lambda event: ...)
- Manually register/unregister hooks: tracker = Tracker(); tracker.register(...); tracker.unregister()

See also: example of visualising transformer activations & gradients using UMAP (example.html).
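A minimal sketch combining the filtering and stash_value options (the toy model and regex are assumptions for this example; track is re-exported at the package root, as in the usage above):

```python
import torch
from torch import nn

import tensor_tracker

# Toy model for illustration: submodule names are "0", "1", "2".
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))

# Track only the two Linear layers, stashing a scalar summary per tensor
# instead of a full copy, to bound memory use.
with tensor_tracker.track(
    model,
    include=r"^(0|2)$",
    stash_value=lambda t: t.std().detach().cpu(),
) as tracker:
    model(torch.randn(8, 10)).sum().backward()

for stash in tracker:
    print(stash.name, stash.grad, stash.first_value)
# Forward stashes for "0" then "2" come first, then backward stashes
# in reverse order ("2" then "0").
```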
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
"""Utility for tracking activations and gradients at `nn.Module` outputs.
Use `track` to start tracking a module & submodules. Then use the original module
as usual. Your `Tracker` will be filled with a list of `Stash`es, containing
copies of fwd/bwd tensors at (sub)module outputs. (Beware, this can consume
a lot of memory.)
Usage ([notebook](usage.html)):
```
with tensor_tracker.track(model) as tracker:
model(inputs).backward()
print(list(tracker))
# => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
# ...]
display(tracker.to_frame()) # requires 'pandas'
```
Advanced usage:
- Filter modules based on name:
`track(include="<regex>", exclude="<regex>")`
- Pre-transform tracked tensors to save memory:
`track(stash_value=lambda t: t.std().detach().cpu())`
- Customise tracked state:
`track(stash=lambda event: ...)`
- Manually register/unregister hooks:
`tracker = Tracker(); tracker.register(...); tracker.unregister()`
See also: [example of
visualising transformer activations & gradients using UMAP](example.html).
"""
import dataclasses
import re
from dataclasses import dataclass
from functools import partial
from types import TracebackType
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Pattern,
Tuple,
Type,
Union,
)
import torch.utils.hooks
from torch import Tensor, nn
@dataclass
class Event:
name: str
type: Type[nn.Module]
grad: bool
value: Any
args: Tuple[Any, ...]
kwargs: Dict[str, Any]
@dataclass
class Stash:
name: str
type: Type[nn.Module]
grad: bool
value: Any # output(s) or grad_output(s)
@property
def first_value(self) -> Any:
def _value(v: Any) -> Any:
if isinstance(v, (tuple, list)) and len(v) >= 1:
return _value(v[0])
return v
return _value(self.value)
StashFn = Callable[[Event], Stash]
StashValueFn = Callable[[Tensor], Any]
def rmap_tensor(value: Any, fn: Callable[[Tensor], Any]) -> Any:
if isinstance(value, (tuple, list)):
return type(value)(rmap_tensor(a, fn) for a in value)
if isinstance(value, dict):
return {rmap_tensor(k, fn): rmap_tensor(a, fn) for k, a in value.items()}
if dataclasses.is_dataclass(value):
return type(value)(**{k: rmap_tensor(v, fn) for k, v in value.__dict__.items()})
if isinstance(value, Tensor):
return fn(value)
return value
def default_stash_value(tensor: Tensor) -> Tensor:
return tensor.detach().cpu().clone()
def default_stash(event: Event, stash_value: StashValueFn) -> Stash:
return Stash(
event.name, event.type, event.grad, rmap_tensor(event.value, stash_value)
)
def get_stash_fn(
stash_value: Optional[StashValueFn] = None, stash: Optional[StashFn] = None
) -> StashFn:
if stash_value and stash:
raise ValueError("Cannot provide StashValueFn and StashFn to get_stash_fn()")
if stash:
return stash
return partial(default_stash, stash_value=stash_value or default_stash_value)
NamePattern = Union[None, Pattern[str], str]
class Tracker:
def __init__(self, stash: StashFn):
self.stashes: List[Stash] = []
self._handles: List[torch.utils.hooks.RemovableHandle] = []
self._stash = stash
# Registration/tracking
def __enter__(self) -> "Tracker":
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.unregister()
def clear(self) -> None:
self.stashes.clear()
def register(self, module: nn.Module, name: str = "", grad: bool = True) -> None:
self._handles.append(
module.register_forward_hook(
partial(self._forward_hook, name=name), with_kwargs=True
)
)
if grad:
self._handles.append(
module.register_full_backward_pre_hook(
partial(self._backward_hook, name=name)
)
)
def register_all(
self,
module: nn.Module,
grad: bool = True,
include: NamePattern = None,
exclude: NamePattern = None,
) -> None:
include = re.compile(include) if isinstance(include, str) else include
exclude = re.compile(exclude) if isinstance(exclude, str) else exclude
for name, child in module.named_modules():
if ((not include) or include.search(name)) and not (
exclude and exclude.search(name)
):
self.register(child, name, grad=grad)
def unregister(self) -> None:
for handle in self._handles:
handle.remove()
self._handles.clear()
def _forward_hook(
self,
module: nn.Module,
args: Tuple[Any],
kwargs: Dict[str, Any],
output: Any,
*,
name: str,
) -> None:
self.stashes.append(
self._stash(Event(name, type(module), False, output, args, kwargs))
)
def _backward_hook(self, module: nn.Module, grad_output: Any, *, name: str) -> None:
self.stashes.append(
self._stash(Event(name, type(module), True, grad_output, (), {}))
)
# Read results
def __str__(self) -> str:
return f"Tracker(stashes={len(self)}, tracking={len(self._handles)})"
def __iter__(self) -> Iterator[Stash]:
return iter(self.stashes)
def __getitem__(self, index: int) -> Stash:
return self.stashes[index]
def __len__(self) -> int:
return len(self.stashes)
def to_frame(
self,
stat: Callable[[Tensor], Tensor] = torch.std,
stat_name: Optional[str] = None,
) -> "pandas.DataFrame": # type:ignore[name-defined] # NOQA: F821
import pandas
column_name = (
getattr(stat, "__name__", "value") if stat_name is None else stat_name
)
def to_item(stash: Stash) -> Dict[str, Any]:
d = stash.__dict__.copy()
d.pop("value")
v = stash.first_value
d[column_name] = stat(v).item() if isinstance(v, Tensor) else None
d["type"] = f"{stash.type.__module__}.{stash.type.__name__}"
return d
return pandas.DataFrame.from_dict(map(to_item, self)) # type:ignore[arg-type]
def track(
module: nn.Module,
grad: bool = True,
include: NamePattern = None,
exclude: NamePattern = None,
stash_value: Optional[StashValueFn] = None,
stash: Optional[StashFn] = None,
) -> Tracker:
tracker = Tracker(get_stash_fn(stash_value=stash_value, stash=stash))
tracker.register_all(module, grad=grad, include=include, exclude=exclude)
return tracker
track.__doc__ = __doc__
__all__ = [
"Event",
"Stash",
"StashFn",
"StashValueFn",
"rmap_tensor",
"default_stash_value",
"default_stash",
"get_stash_fn",
"Tracker",
"track",
]
Functions
def default_stash(event: Event, stash_value: Callable[[torch.Tensor], Any]) -> Stash

    def default_stash(event: Event, stash_value: StashValueFn) -> Stash:
        return Stash(
            event.name, event.type, event.grad, rmap_tensor(event.value, stash_value)
        )
def default_stash_value(tensor: torch.Tensor) -> torch.Tensor

    def default_stash_value(tensor: Tensor) -> Tensor:
        return tensor.detach().cpu().clone()
def get_stash_fn(stash_value: Optional[Callable[[torch.Tensor], Any]] = None, stash: Optional[Callable[[Event], Stash]] = None) -> Callable[[Event], Stash]

    def get_stash_fn(
        stash_value: Optional[StashValueFn] = None, stash: Optional[StashFn] = None
    ) -> StashFn:
        if stash_value and stash:
            raise ValueError("Cannot provide StashValueFn and StashFn to get_stash_fn()")
        if stash:
            return stash
        return partial(default_stash, stash_value=stash_value or default_stash_value)
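A quick illustration of the dispatch (constructing an Event by hand; values are illustrative):

```python
import torch
from torch import nn

from tensor_tracker.core import Event, get_stash_fn

fn = get_stash_fn()  # no arguments: stashes detached CPU clones
stash = fn(Event("fc", nn.Linear, False, torch.ones(2), (), {}))
print(stash.value)  # tensor([1., 1.])

# Passing both stash_value= and stash= is ambiguous and raises ValueError.
```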
def rmap_tensor(value: Any, fn: Callable[[torch.Tensor], Any]) -> Any

    def rmap_tensor(value: Any, fn: Callable[[Tensor], Any]) -> Any:
        if isinstance(value, (tuple, list)):
            return type(value)(rmap_tensor(a, fn) for a in value)
        if isinstance(value, dict):
            return {rmap_tensor(k, fn): rmap_tensor(a, fn) for k, a in value.items()}
        if dataclasses.is_dataclass(value):
            return type(value)(**{k: rmap_tensor(v, fn) for k, v in value.__dict__.items()})
        if isinstance(value, Tensor):
            return fn(value)
        return value
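For example, mapping every tensor leaf in an arbitrary nested structure to its shape:

```python
import torch

from tensor_tracker.core import rmap_tensor

nested = {"hidden": (torch.ones(2, 3), torch.zeros(4)), "label": "batch-0"}

# Tensor leaves are transformed; non-tensor leaves pass through unchanged.
print(rmap_tensor(nested, lambda t: tuple(t.shape)))
# {'hidden': ((2, 3), (4,)), 'label': 'batch-0'}
```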
def track(module: torch.nn.modules.module.Module, grad: bool = True, include: Union[None, Pattern[str], str] = None, exclude: Union[None, Pattern[str], str] = None, stash_value: Optional[Callable[[torch.Tensor], Any]] = None, stash: Optional[Callable[[Event], Stash]] = None) -> Tracker
Utility for tracking activations and gradients at nn.Module outputs.

Use track() to start tracking a module & submodules. Then use the original
module as usual. Your Tracker will be filled with a list of Stashes, containing
copies of fwd/bwd tensors at (sub)module outputs. (Beware, this can consume
a lot of memory.)

Usage (notebook):

    with tensor_tracker.track(model) as tracker:
        model(inputs).backward()

    print(list(tracker))
    # => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
    #     ...]

    display(tracker.to_frame())  # requires 'pandas'

Advanced usage:

- Filter modules based on name: track(include="<regex>", exclude="<regex>")
- Pre-transform tracked tensors to save memory: track(stash_value=lambda t: t.std().detach().cpu())
- Customise tracked state: track(stash=lambda event: ...) (see the sketch after the source below)
- Manually register/unregister hooks: tracker = Tracker(); tracker.register(...); tracker.unregister()

See also: example of visualising transformer activations & gradients using UMAP (example.html).
    def track(
        module: nn.Module,
        grad: bool = True,
        include: NamePattern = None,
        exclude: NamePattern = None,
        stash_value: Optional[StashValueFn] = None,
        stash: Optional[StashFn] = None,
    ) -> Tracker:
        tracker = Tracker(get_stash_fn(stash_value=stash_value, stash=stash))
        tracker.register_all(module, grad=grad, include=include, exclude=exclude)
        return tracker
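A sketch of the stash= customisation (tiny_stash is a hypothetical helper), keeping only metadata and discarding tensor payloads:

```python
import torch
from torch import nn

import tensor_tracker
from tensor_tracker.core import Event, Stash

def tiny_stash(event: Event) -> Stash:
    # Keep name/type/grad; drop the tensor value entirely.
    return Stash(event.name, event.type, event.grad, None)

model = nn.Linear(4, 2)  # illustrative module
with tensor_tracker.track(model, stash=tiny_stash) as tracker:
    model(torch.randn(3, 4)).sum().backward()

print(tracker)  # Tracker(stashes=2, tracking=0) once the context exits
```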
Classes
class Event (name: str, type: Type[torch.nn.modules.module.Module], grad: bool, value: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any])
    @dataclass
    class Event:
        name: str
        type: Type[nn.Module]
        grad: bool
        value: Any
        args: Tuple[Any, ...]
        kwargs: Dict[str, Any]
Class variables
var args : Tuple[Any, ...]
var grad : bool
var kwargs : Dict[str, Any]
var name : str
var type : Type[torch.nn.modules.module.Module]
var value : Any
class Stash (name: str, type: Type[torch.nn.modules.module.Module], grad: bool, value: Any)
    @dataclass
    class Stash:
        name: str
        type: Type[nn.Module]
        grad: bool
        value: Any  # output(s) or grad_output(s)

        @property
        def first_value(self) -> Any:
            def _value(v: Any) -> Any:
                if isinstance(v, (tuple, list)) and len(v) >= 1:
                    return _value(v[0])
                return v

            return _value(self.value)
Class variables
var grad : bool
var name : str
var type : Type[torch.nn.modules.module.Module]
var value : Any
Instance variables
var first_value : Any
    @property
    def first_value(self) -> Any:
        def _value(v: Any) -> Any:
            if isinstance(v, (tuple, list)) and len(v) >= 1:
                return _value(v[0])
            return v

        return _value(self.value)
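For example, a backward stash holds grad_output as a tuple; first_value unwraps nested tuples/lists down to the first leaf (hand-constructed Stash, values illustrative):

```python
import torch
from torch import nn

from tensor_tracker.core import Stash

s = Stash("linear", nn.Linear, True, (torch.ones(2),))
print(s.first_value)  # tensor([1., 1.])
```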
class Tracker (stash: Callable[[Event], Stash])
    class Tracker:
        def __init__(self, stash: StashFn):
            self.stashes: List[Stash] = []
            self._handles: List[torch.utils.hooks.RemovableHandle] = []
            self._stash = stash

        # Registration/tracking
        def __enter__(self) -> "Tracker":
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc: Optional[BaseException],
            traceback: Optional[TracebackType],
        ) -> None:
            self.unregister()

        def clear(self) -> None:
            self.stashes.clear()

        def register(self, module: nn.Module, name: str = "", grad: bool = True) -> None:
            self._handles.append(
                module.register_forward_hook(
                    partial(self._forward_hook, name=name), with_kwargs=True
                )
            )
            if grad:
                self._handles.append(
                    module.register_full_backward_pre_hook(
                        partial(self._backward_hook, name=name)
                    )
                )

        def register_all(
            self,
            module: nn.Module,
            grad: bool = True,
            include: NamePattern = None,
            exclude: NamePattern = None,
        ) -> None:
            include = re.compile(include) if isinstance(include, str) else include
            exclude = re.compile(exclude) if isinstance(exclude, str) else exclude
            for name, child in module.named_modules():
                if ((not include) or include.search(name)) and not (
                    exclude and exclude.search(name)
                ):
                    self.register(child, name, grad=grad)

        def unregister(self) -> None:
            for handle in self._handles:
                handle.remove()
            self._handles.clear()

        def _forward_hook(
            self,
            module: nn.Module,
            args: Tuple[Any],
            kwargs: Dict[str, Any],
            output: Any,
            *,
            name: str,
        ) -> None:
            self.stashes.append(
                self._stash(Event(name, type(module), False, output, args, kwargs))
            )

        def _backward_hook(self, module: nn.Module, grad_output: Any, *, name: str) -> None:
            self.stashes.append(
                self._stash(Event(name, type(module), True, grad_output, (), {}))
            )

        # Read results
        def __str__(self) -> str:
            return f"Tracker(stashes={len(self)}, tracking={len(self._handles)})"

        def __iter__(self) -> Iterator[Stash]:
            return iter(self.stashes)

        def __getitem__(self, index: int) -> Stash:
            return self.stashes[index]

        def __len__(self) -> int:
            return len(self.stashes)

        def to_frame(
            self,
            stat: Callable[[Tensor], Tensor] = torch.std,
            stat_name: Optional[str] = None,
        ) -> "pandas.DataFrame":  # type:ignore[name-defined] # NOQA: F821
            import pandas

            column_name = (
                getattr(stat, "__name__", "value") if stat_name is None else stat_name
            )

            def to_item(stash: Stash) -> Dict[str, Any]:
                d = stash.__dict__.copy()
                d.pop("value")
                v = stash.first_value
                d[column_name] = stat(v).item() if isinstance(v, Tensor) else None
                d["type"] = f"{stash.type.__module__}.{stash.type.__name__}"
                return d

            return pandas.DataFrame.from_dict(map(to_item, self))  # type:ignore[arg-type]
Methods
def clear(self) -> None

    def clear(self) -> None:
        self.stashes.clear()
def register(self, module: torch.nn.modules.module.Module, name: str = '', grad: bool = True) -> None

    def register(self, module: nn.Module, name: str = "", grad: bool = True) -> None:
        self._handles.append(
            module.register_forward_hook(
                partial(self._forward_hook, name=name), with_kwargs=True
            )
        )
        if grad:
            self._handles.append(
                module.register_full_backward_pre_hook(
                    partial(self._backward_hook, name=name)
                )
            )
def register_all(self, module: torch.nn.modules.module.Module, grad: bool = True, include: Union[None, Pattern[str], str] = None, exclude: Union[None, Pattern[str], str] = None) -> None

    def register_all(
        self,
        module: nn.Module,
        grad: bool = True,
        include: NamePattern = None,
        exclude: NamePattern = None,
    ) -> None:
        include = re.compile(include) if isinstance(include, str) else include
        exclude = re.compile(exclude) if isinstance(exclude, str) else exclude
        for name, child in module.named_modules():
            if ((not include) or include.search(name)) and not (
                exclude and exclude.search(name)
            ):
                self.register(child, name, grad=grad)
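For example, hooking submodules by name (the model and regexes are assumptions for this example):

```python
import torch
from torch import nn

from tensor_tracker.core import Tracker, get_stash_fn

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 2))

tracker = Tracker(get_stash_fn())
# Hook submodules whose name contains a digit ("0", "1", "2"),
# except the LayerNorm at name "1".
tracker.register_all(model, include=r"\d", exclude=r"^1$")
model(torch.randn(2, 4))
print([s.name for s in tracker])  # ['0', '2']
tracker.unregister()
```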
def to_frame(self, stat: Callable[[torch.Tensor], torch.Tensor] = torch.std, stat_name: Optional[str] = None) -> pandas.DataFrame

    def to_frame(
        self,
        stat: Callable[[Tensor], Tensor] = torch.std,
        stat_name: Optional[str] = None,
    ) -> "pandas.DataFrame":  # type:ignore[name-defined] # NOQA: F821
        import pandas

        column_name = (
            getattr(stat, "__name__", "value") if stat_name is None else stat_name
        )

        def to_item(stash: Stash) -> Dict[str, Any]:
            d = stash.__dict__.copy()
            d.pop("value")
            v = stash.first_value
            d[column_name] = stat(v).item() if isinstance(v, Tensor) else None
            d["type"] = f"{stash.type.__module__}.{stash.type.__name__}"
            return d

        return pandas.DataFrame.from_dict(map(to_item, self))  # type:ignore[arg-type]
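A sketch of typical usage (the resulting statistic values depend on the run; pandas must be installed):

```python
import torch
from torch import nn

import tensor_tracker

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
with tensor_tracker.track(model) as tracker:
    model(torch.randn(4, 8)).sum().backward()

df = tracker.to_frame()  # one row per stash
print(df.columns.tolist())  # ['name', 'type', 'grad', 'std']
```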
def unregister(self) -> None

    def unregister(self) -> None:
        for handle in self._handles:
            handle.remove()
        self._handles.clear()