4. API reference

unit-scaling is implemented as thin wrappers around existing torch.nn classes and functions. Its documentation likewise inherits from the standard PyTorch docs, with modifications for scaling; as a result, some inherited passages may no longer be relevant.

The API is built to mirror torch.nn as closely as possible, so that PyTorch classes and functions can easily be swapped out for their unit-scaled equivalents.

For PyTorch code which uses the following imports:

from torch import nn
from torch.nn import functional as F

Unit scaling can be applied by first adding:

import unit_scaling as uu
from unit_scaling import functional as U

and then replacing nn with uu and F with U, so that the corresponding classes and functions (where supported) are unit-scaled.
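
For example, a small feed-forward block can be converted simply by switching prefixes. The following is a minimal sketch which assumes uu.Linear and U.gelu are among the supported drop-in replacements; the module listings below point to the full set of available classes and functions.

import torch
from torch import nn
from torch.nn import functional as F

import unit_scaling as uu
from unit_scaling import functional as U

class FFN(nn.Module):
    """A standard PyTorch feed-forward block."""

    def __init__(self, d: int):
        super().__init__()
        self.up = nn.Linear(d, 4 * d)
        self.down = nn.Linear(4 * d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.gelu(self.up(x)))

class UnitScaledFFN(nn.Module):
    """The same block with nn replaced by uu and F replaced by U."""

    def __init__(self, d: int):
        super().__init__()
        self.up = uu.Linear(d, 4 * d)  # assumed drop-in for nn.Linear
        self.down = uu.Linear(4 * d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(U.gelu(self.up(x)))  # assumed drop-in for F.gelu

Note that the containing module still subclasses nn.Module; only the individual layers and function calls are swapped for their unit-scaled counterparts.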

The full documentation covers the following modules:

unit_scaling: Unit-scaled versions of common torch.nn modules.
unit_scaling.analysis: Tools for analysing scale (and other metrics) within PyTorch models.
unit_scaling.constraints: Common scale constraints used in unit-scaled operations.
unit_scaling.formats: Classes for simulating (non-standard) number formats.
unit_scaling.functional: Unit-scaled versions of common torch.nn.functional functions.
unit_scaling.optim: Optimizer wrappers that apply scaling rules for u-muP (see the sketch after this list).
unit_scaling.scale: Operations to enable different scaling factors in the forward and backward passes.
unit_scaling.transforms: Torch dynamo transforms of modules for numerics and unit scaling.
unit_scaling.transforms.utils: Utilities for working with transforms.
unit_scaling.utils: Utility functions for developing unit-scaled models.
unit_scaling.core.functional: Core functionality for implementing unit_scaling.functional.
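
As an illustration of how the optimizer wrappers fit into the same drop-in workflow, the sketch below trains the UnitScaledFFN defined earlier. It assumes that unit_scaling.optim.Adam mirrors the torch.optim.Adam signature; consult the unit_scaling.optim documentation for the exact interface and the u-muP scaling rules it applies.

import torch
import unit_scaling.optim

model = UnitScaledFFN(d=256)

# Assumption: mirrors torch.optim.Adam(params, lr=...), applying u-muP
# learning-rate scaling internally.
opt = unit_scaling.optim.Adam(model.parameters(), lr=1e-2)

x = torch.randn(8, 256)
loss = model(x).pow(2).mean()  # placeholder loss, for illustration only
loss.backward()
opt.step()
opt.zero_grad()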