Package tensor_tracker
Utility for tracking activations and gradients at nn.Module
outputs.
Use track()
to start tracking a module & submodules. Then use the original module
as usual. Your Tracker
will be filled with a list of Stashes, containing
copies of fwd/bwd tensors at (sub)module outputs. (Beware, this can consume
a lot of memory.)
Usage (notebook):
with tensor_tracker.track(model) as tracker:
model(inputs).backward()
print(list(tracker))
# => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
# ...]
display(tracker.to_frame()) # requires 'pandas'
Advanced usage:
-
Filter modules based on name:
track(include="<regex>", exclude="<regex>")
-
Pre-transform tracked tensors to save memory:
track(stash_value=lambda t: t.std().detach().cpu())
-
Customise tracked state:
track(stash=lambda event: ...)
-
Manually register/unregister hooks:
tracker = Tracker(); tracker.register(...); tracker.unregister()
See also: example of visualising transformer activations & gradients using UMAP.
Expand source code
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from .core import * # NOQA: F401 F403
from .core import __all__, __doc__ # NOQA: F401
Sub-modules
tensor_tracker.core
-
Utility for tracking activations and gradients at
nn.Module
outputs …
Functions
def default_stash(event: Event, stash_value: Callable[[torch.Tensor], Any]) ‑> Stash
-
Expand source code
def default_stash(event: Event, stash_value: StashValueFn) -> Stash: return Stash( event.name, event.type, event.grad, rmap_tensor(event.value, stash_value) )
def default_stash_value(tensor: torch.Tensor) ‑> torch.Tensor
-
Expand source code
def default_stash_value(tensor: Tensor) -> Tensor: return tensor.detach().cpu().clone()
def get_stash_fn(stash_value: Optional[Callable[[torch.Tensor], Any]] = None, stash: Optional[Callable[[Event], Stash]] = None) ‑> Callable[[Event], Stash]
-
Expand source code
def get_stash_fn( stash_value: Optional[StashValueFn] = None, stash: Optional[StashFn] = None ) -> StashFn: if stash_value and stash: raise ValueError("Cannot provide StashValueFn and StashFn to get_stash_fn()") if stash: return stash return partial(default_stash, stash_value=stash_value or default_stash_value)
def rmap_tensor(value: Any, fn: Callable[[torch.Tensor], Any]) ‑> Any
-
Expand source code
def rmap_tensor(value: Any, fn: Callable[[Tensor], Any]) -> Any: if isinstance(value, (tuple, list)): return type(value)(rmap_tensor(a, fn) for a in value) if isinstance(value, dict): return {rmap_tensor(k, fn): rmap_tensor(a, fn) for k, a in value.items()} if dataclasses.is_dataclass(value): return type(value)(**{k: rmap_tensor(v, fn) for k, v in value.__dict__.items()}) if isinstance(value, Tensor): return fn(value) return value
def track(module: torch.nn.modules.module.Module, grad: bool = True, include: Union[ForwardRef(None), Pattern[str], str] = None, exclude: Union[ForwardRef(None), Pattern[str], str] = None, stash_value: Optional[Callable[[torch.Tensor], Any]] = None, stash: Optional[Callable[[Event], Stash]] = None) ‑> Tracker
-
Utility for tracking activations and gradients at
nn.Module
outputs. Use
track()
to start tracking a module & submodules. Then use the original module as usual. Your Tracker
will be filled with a list of Stashes, containing copies of fwd/bwd tensors at (sub)module outputs. (Beware, this can consume a lot of memory.) Usage (notebook):
with tensor_tracker.track(model) as tracker: model(inputs).backward() print(list(tracker)) # => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)), # ...] display(tracker.to_frame()) # requires 'pandas'
Advanced usage:
-
Filter modules based on name:
track(include="<regex>", exclude="<regex>")
-
Pre-transform tracked tensors to save memory:
track(stash_value=lambda t: t.std().detach().cpu())
-
Customise tracked state:
track(stash=lambda event: ...)
-
Manually register/unregister hooks:
tracker = Tracker(); tracker.register(...); tracker.unregister()
See also: example of visualising transformer activations & gradients using UMAP.
Expand source code
def track( module: nn.Module, grad: bool = True, include: NamePattern = None, exclude: NamePattern = None, stash_value: Optional[StashValueFn] = None, stash: Optional[StashFn] = None, ) -> Tracker: tracker = Tracker(get_stash_fn(stash_value=stash_value, stash=stash)) tracker.register_all(module, grad=grad, include=include, exclude=exclude) return tracker
-
Classes
class Event (name: str, type: Type[torch.nn.modules.module.Module], grad: bool, value: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any])
-
Event(name: str, type: Type[torch.nn.modules.module.Module], grad: bool, value: Any, args: Tuple[Any, …], kwargs: Dict[str, Any])
Expand source code
@dataclass class Event: name: str type: Type[nn.Module] grad: bool value: Any args: Tuple[Any, ...] kwargs: Dict[str, Any]
Class variables
var args : Tuple[Any, ...]
var grad : bool
var kwargs : Dict[str, Any]
var name : str
var type : Type[torch.nn.modules.module.Module]
var value : Any
class Stash (name: str, type: Type[torch.nn.modules.module.Module], grad: bool, value: Any)
-
Stash(name: str, type: Type[torch.nn.modules.module.Module], grad: bool, value: Any)
Expand source code
@dataclass class Stash: name: str type: Type[nn.Module] grad: bool value: Any # output(s) or grad_output(s) @property def first_value(self) -> Any: def _value(v: Any) -> Any: if isinstance(v, (tuple, list)) and len(v) >= 1: return _value(v[0]) return v return _value(self.value)
Class variables
var grad : bool
var name : str
var type : Type[torch.nn.modules.module.Module]
var value : Any
Instance variables
var first_value : Any
-
Expand source code
@property def first_value(self) -> Any: def _value(v: Any) -> Any: if isinstance(v, (tuple, list)) and len(v) >= 1: return _value(v[0]) return v return _value(self.value)
class Tracker (stash: Callable[[Event], Stash])
-
Expand source code
class Tracker:
    """Collects Stashes from forward/backward hooks on registered modules.

    Acts as a context manager (unregisters hooks on exit) and as a sequence
    of the stashes recorded so far.
    """

    def __init__(self, stash: StashFn):
        self.stashes: List[Stash] = []
        # Handles for every hook installed, so unregister() can remove them all.
        self._handles: List[torch.utils.hooks.RemovableHandle] = []
        self._stash = stash

    # Registration/tracking
    def __enter__(self) -> "Tracker":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Leaving the `with` block stops tracking but keeps recorded stashes.
        self.unregister()

    def clear(self) -> None:
        """Drop all recorded stashes; hooks stay registered."""
        self.stashes.clear()

    def register(self, module: nn.Module, name: str = "", grad: bool = True) -> None:
        """Install a forward hook (and, if *grad*, a backward-pre hook) on *module*."""
        self._handles.append(
            module.register_forward_hook(
                partial(self._forward_hook, name=name), with_kwargs=True
            )
        )
        if grad:
            self._handles.append(
                module.register_full_backward_pre_hook(
                    partial(self._backward_hook, name=name)
                )
            )

    def register_all(
        self,
        module: nn.Module,
        grad: bool = True,
        include: NamePattern = None,
        exclude: NamePattern = None,
    ) -> None:
        """Register hooks on every named submodule passing the include/exclude filters."""
        # String patterns are compiled; precompiled patterns pass through as-is.
        include = re.compile(include) if isinstance(include, str) else include
        exclude = re.compile(exclude) if isinstance(exclude, str) else exclude
        for name, child in module.named_modules():
            # Keep when `include` is absent or matches, and `exclude` does not match.
            if ((not include) or include.search(name)) and not (
                exclude and exclude.search(name)
            ):
                self.register(child, name, grad=grad)

    def unregister(self) -> None:
        """Remove every installed hook and forget the handles."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def _forward_hook(
        self,
        module: nn.Module,
        args: Tuple[Any],
        kwargs: Dict[str, Any],
        output: Any,
        *,
        name: str,
    ) -> None:
        # Forward event: grad=False, value is the module output.
        self.stashes.append(
            self._stash(Event(name, type(module), False, output, args, kwargs))
        )

    def _backward_hook(self, module: nn.Module, grad_output: Any, *, name: str) -> None:
        # Backward event: grad=True, value is grad_output; args/kwargs unavailable.
        self.stashes.append(
            self._stash(Event(name, type(module), True, grad_output, (), {}))
        )

    # Read results
    def __str__(self) -> str:
        return f"Tracker(stashes={len(self)}, tracking={len(self._handles)})"

    def __iter__(self) -> Iterator[Stash]:
        return iter(self.stashes)

    def __getitem__(self, index: int) -> Stash:
        return self.stashes[index]

    def __len__(self) -> int:
        return len(self.stashes)

    def to_frame(
        self,
        stat: Callable[[Tensor], Tensor] = torch.std,
        stat_name: Optional[str] = None,
    ) -> "pandas.DataFrame":  # type:ignore[name-defined] # NOQA: F821
        """Summarise stashes as a pandas DataFrame (requires 'pandas').

        One row per stash; `value` is replaced by a scalar column named after
        *stat* (or *stat_name*), computed on the stash's first tensor value.
        """
        # Imported lazily so pandas is only required when to_frame() is used.
        import pandas

        column_name = (
            getattr(stat, "__name__", "value") if stat_name is None else stat_name
        )

        def to_item(stash: Stash) -> Dict[str, Any]:
            d = stash.__dict__.copy()
            d.pop("value")
            v = stash.first_value
            # Non-tensor first values yield None in the stat column.
            d[column_name] = stat(v).item() if isinstance(v, Tensor) else None
            d["type"] = f"{stash.type.__module__}.{stash.type.__name__}"
            return d

        return pandas.DataFrame.from_dict(map(to_item, self))  # type:ignore[arg-type]
Methods
def clear(self) ‑> None
-
Expand source code
def clear(self) -> None: self.stashes.clear()
def register(self, module: torch.nn.modules.module.Module, name: str = '', grad: bool = True) ‑> None
-
Expand source code
def register(self, module: nn.Module, name: str = "", grad: bool = True) -> None: self._handles.append( module.register_forward_hook( partial(self._forward_hook, name=name), with_kwargs=True ) ) if grad: self._handles.append( module.register_full_backward_pre_hook( partial(self._backward_hook, name=name) ) )
def register_all(self, module: torch.nn.modules.module.Module, grad: bool = True, include: Union[ForwardRef(None), Pattern[str], str] = None, exclude: Union[ForwardRef(None), Pattern[str], str] = None) ‑> None
-
Expand source code
def register_all( self, module: nn.Module, grad: bool = True, include: NamePattern = None, exclude: NamePattern = None, ) -> None: include = re.compile(include) if isinstance(include, str) else include exclude = re.compile(exclude) if isinstance(exclude, str) else exclude for name, child in module.named_modules(): if ((not include) or include.search(name)) and not ( exclude and exclude.search(name) ): self.register(child, name, grad=grad)
def to_frame(self, stat: Callable[[torch.Tensor], torch.Tensor] = <built-in method std of type object>, stat_name: Optional[str] = None) ‑> pandas.DataFrame
-
Expand source code
def to_frame( self, stat: Callable[[Tensor], Tensor] = torch.std, stat_name: Optional[str] = None, ) -> "pandas.DataFrame": # type:ignore[name-defined] # NOQA: F821 import pandas column_name = ( getattr(stat, "__name__", "value") if stat_name is None else stat_name ) def to_item(stash: Stash) -> Dict[str, Any]: d = stash.__dict__.copy() d.pop("value") v = stash.first_value d[column_name] = stat(v).item() if isinstance(v, Tensor) else None d["type"] = f"{stash.type.__module__}.{stash.type.__name__}" return d return pandas.DataFrame.from_dict(map(to_item, self)) # type:ignore[arg-type]
def unregister(self) ‑> None
-
Expand source code
def unregister(self) -> None: for handle in self._handles: handle.remove() self._handles.clear()