3. API reference
unit-scaling is implemented using thin wrappers around existing torch.nn classes and functions. Documentation is also inherited from the standard PyTorch docs, with modifications for scaling. Note that some inherited docs may no longer be relevant.
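The thin-wrapper pattern can be sketched in plain PyTorch. This is a rough illustration, not the library's actual code. It subclasses torch.nn.Linear, re-initialises the weight with unit variance, and scales the weight by 1/sqrt(in_features) so that unit-variance inputs give roughly unit-variance outputs. The class name SketchScaledLinear is hypothetical. Unlike unit-scaling, the sketch applies no separate scaling in the backward pass.

```python
import math

import torch
from torch import nn
from torch.nn import functional as F


class SketchScaledLinear(nn.Linear):
    # Hypothetical thin wrapper around nn.Linear; not unit_scaling's implementation.

    def reset_parameters(self) -> None:
        # nn.Linear.__init__ calls this: use unit-variance weights and a zero bias.
        nn.init.normal_(self.weight, mean=0.0, std=1.0)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Scale the weight by 1/sqrt(fan_in) so that unit-variance inputs give
        # roughly unit-variance outputs. No separate backward scaling is applied.
        scale = 1.0 / math.sqrt(self.in_features)
        return F.linear(input, self.weight * scale, self.bias)


torch.manual_seed(0)
layer = SketchScaledLinear(256, 64)  # constructor signature inherited from nn.Linear
x = torch.randn(4096, 256)
y = layer(x)
print(y.shape, y.std().item())  # std is expected to be close to 1.0
```

Everything else (constructor arguments, .weight, .bias, state_dict) comes from nn.Linear unchanged. This reuse is what keeps such wrappers small.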
The API is built to mirror torch.nn as closely as possible, so that PyTorch classes and functions can easily be swapped out for their unit-scaled equivalents.
For PyTorch code which uses the following imports:
from torch import nn
from torch.nn import functional as F
Unit scaling can be applied by first adding:
import unit_scaling as uu
from unit_scaling import functional as U
and then replacing nn with uu and F with U for the classes/functions to be unit-scaled (assuming they are supported).
Full documentation is available for each of the following submodules:
Unit-scaled versions of common torch.nn modules.
Tools for analysing scale (and other metrics) within PyTorch models.
Common scale-constraints used in unit-scaled operations.
Classes for simulating (non-standard) number formats.
Unit-scaled versions of common torch.nn.functional functions.
Optimizer wrappers that apply scaling rules for u-μP.
Operations to enable different scaling factors in the forward and backward passes.
Useful torch dynamo transforms of modules for the sake of numerics and unit scaling.
Utilities for working with transforms.
Utility functions for developing unit-scaled models.
Core functionality for implementing unit_scaling.functional.
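Using different scaling factors in the forward and backward passes can be sketched with a custom torch.autograd.Function. The helpers scale_bwd_sketch, scale_fwd_sketch and scaled are hypothetical names for illustration, not the library's API. In this sketch, the forward pass multiplies values by one factor and the backward pass multiplies gradients by another.

```python
import torch


class _ScaleGrad(torch.autograd.Function):
    # Identity in the forward pass; multiplies the incoming gradient by `scale`
    # in the backward pass.
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient per forward input: x gets the scaled grad, scale gets None.
        return grad_output * ctx.scale, None


def scale_bwd_sketch(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Hypothetical helper: values unchanged, gradients multiplied by `scale`.
    return _ScaleGrad.apply(x, scale)


def scale_fwd_sketch(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Hypothetical helper: values multiplied by `scale`, gradients unchanged.
    # The multiply contributes `scale` to the gradient; the wrapper undoes it.
    return scale_bwd_sketch(x * scale, 1.0 / scale)


def scaled(x: torch.Tensor, fwd_scale: float, bwd_scale: float) -> torch.Tensor:
    # Forward: x * fwd_scale. Backward: grad * bwd_scale.
    return scale_bwd_sketch(scale_fwd_sketch(x, fwd_scale), bwd_scale)


x = torch.ones(3, requires_grad=True)
y = scaled(x, fwd_scale=2.0, bwd_scale=0.5)
y.sum().backward()
print(y.detach(), x.grad)
```

With unequal factors, the gradient reaching x is no longer the exact derivative of the forward computation. The sketch makes no attempt to account for that.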