3.8. unit_scaling.utils
Utility functions for developing unit-scaled models.
Functions
Given an nn.Module and dummy forward and backward tensors, generates code representing the module, annotated with the scale (standard deviation) of each tensor in both the forward and backward passes.
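The idea of measuring per-tensor scales can be illustrated with plain PyTorch hooks. The following is a minimal sketch, not this module's implementation: `measure_scales` is a hypothetical helper name, and it only reports the output of each leaf submodule rather than generating annotated code.

```python
import torch
from torch import nn


def measure_scales(module: nn.Module, x: torch.Tensor, grad_out: torch.Tensor):
    """Hypothetical helper: returns {submodule_name: (forward_std, backward_std)}."""
    fwd, bwd, handles = {}, {}, []

    for name, sub in module.named_modules():
        if len(list(sub.children())) > 0:
            continue  # only instrument leaf submodules

        def fwd_hook(mod, inputs, output, name=name):
            # Standard deviation of the activation in the forward pass.
            fwd[name] = output.detach().std().item()

            # Tensor hook: fires when the gradient w.r.t. this output is computed.
            def grad_hook(grad, name=name):
                bwd[name] = grad.std().item()

            output.register_hook(grad_hook)

        handles.append(sub.register_forward_hook(fwd_hook))

    out = module(x)          # dummy forward tensor
    out.backward(grad_out)   # dummy backward tensor

    for h in handles:
        h.remove()
    return {n: (fwd[n], bwd.get(n)) for n in fwd}


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
scales = measure_scales(model, torch.randn(64, 16), torch.randn(64, 8))
for name, (f, b) in scales.items():
    print(f"{name}: forward std={f:.3f}, backward std={b:.3f}")
```

In a well-scaled model both numbers would stay close to 1 for every tensor; values that drift far from 1 point at the operation that needs rescaling.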
Classes
Dataclass containing a pair of scalars, intended to represent the standard deviation of an arbitrary tensor in the forward and backward passes.
Given a torch.Tensor, records its standard deviation in the forward and backward passes in the supplied ScalePair.
Wraps an fx.GraphModule such that, when executed, it records the standard deviation of every intermediate torch.Tensor in the forward and backward passes.
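How these three pieces could fit together is sketched below using only standard PyTorch APIs (`dataclasses`, `torch.autograd.Function`, `torch.fx.Interpreter`). The names `StdPair`, `RecordStd` and `StdRecordingInterpreter` are hypothetical stand-ins chosen for illustration; they are not the classes listed above, and their signatures are assumptions.

```python
from dataclasses import dataclass
from typing import Dict, Optional

import torch
from torch import fx, nn


@dataclass
class StdPair:
    """Hypothetical stand-in: forward/backward standard deviation of one tensor."""
    forward: Optional[float] = None
    backward: Optional[float] = None


class RecordStd(torch.autograd.Function):
    """Identity op that writes the std of a tensor and of its gradient into a StdPair."""

    @staticmethod
    def forward(ctx, t, pair):
        ctx.pair = pair
        pair.forward = float(t.std())
        return t.view_as(t)

    @staticmethod
    def backward(ctx, grad):
        ctx.pair.backward = float(grad.std())
        return grad, None  # no gradient for the non-tensor `pair` argument


class StdRecordingInterpreter(fx.Interpreter):
    """Runs a GraphModule node by node, recording scales of intermediate tensors."""

    def __init__(self, gm: fx.GraphModule):
        super().__init__(gm)
        self.scales: Dict[str, StdPair] = {}

    def run_node(self, n: fx.Node):
        out = super().run_node(n)
        # Only floating-point tensors that take part in autograd are tracked.
        if isinstance(out, torch.Tensor) and out.is_floating_point() and out.requires_grad:
            pair = self.scales[n.name] = StdPair()
            out = RecordStd.apply(out, pair)
        return out


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
interp = StdRecordingInterpreter(fx.symbolic_trace(model))
out = interp.run(torch.randn(64, 16))
out.backward(torch.randn(64, 8))
for name, pair in interp.scales.items():
    print(name, pair)
```

Routing each intermediate value through an identity autograd function keeps the model's numerics unchanged while giving a single place to observe both the activation and its gradient.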