3.1. unit_scaling
Unit-scaled versions of common torch.nn modules.
Functions
Parameter
    Construct a u-μP parameter object, an annotated torch.nn.Parameter.
transformer_residual_scaling_rule
    Compute the residual tau ratios for the default transformer rule.
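For weights created by hand rather than by the classes below, the Parameter function attaches the u-μP annotation that the optimizer wrappers rely on. The following is a minimal sketch, assuming the annotation is passed via a mup_type keyword with a "weight" value; both are assumptions about the signature rather than a verbatim excerpt from the docs.

```python
import torch

import unit_scaling as uu

# Hypothetical custom weight tagged for u-muP. The "mup_type" keyword and
# the "weight" value are assumed names; check the Parameter docs for the
# exact argument and its allowed values.
dim = 256
w = uu.Parameter(torch.randn(dim, dim), mup_type="weight")

# The annotation travels with the tensor, so the optimizer wrappers in
# unit_scaling.optim can apply the matching u-muP learning-rate rule.
print(isinstance(w, torch.nn.Parameter))  # True: still a torch.nn.Parameter
```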
Classes
Conv1d
    Applies a unit-scaled 1D convolution to the incoming data.
CrossEntropyLoss
    Computes a unit-scaled cross entropy loss between input logits and target.
DepthModuleList
    A …
DepthSequential
    A …
Dropout
    A unit-scaled implementation of Dropout.
Embedding
    A unit-scaled lookup table that looks up embeddings in a fixed dictionary and size.
GELU
    Applies a unit-scaled Gaussian Error Linear Units function.
LayerNorm
    Applies a unit-scaled Layer Normalization over a mini-batch of inputs.
Linear
    Applies a unit-scaled linear transformation to the incoming data.
LinearReadout
    Applies a unit-scaled linear transformation to the incoming data, scaled appropriately for the final network output.
MHSA
    A unit-scaled implementation of a multi-head self-attention layer.
MLP
    A unit-scaled implementation of an MLP layer using SwiGLU.
RMSNorm
    Applies a unit-scaled RMS normalisation over trailing dimensions.
SiLU
    Applies a unit-scaled Sigmoid Linear Unit function.
Softmax
    Applies a unit-scaled Softmax function to an n-dimensional input Tensor.
TransformerDecoder
    A unit-scaled implementation of a decoder-type transformer.
TransformerLayer
    A unit-scaled implementation of a PreNorm (see https://arxiv.org/abs/2002.04745) transformer layer.
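Since the classes above are intended as drop-in replacements for their torch.nn counterparts, a small unit-scaled model can be assembled with the usual PyTorch idioms. The sketch below assumes the constructors mirror torch.nn signatures (Linear(in_features, out_features), Dropout(p), and so on); defaults such as bias handling may differ, so treat it as illustrative rather than canonical.

```python
import torch
import torch.nn as nn

import unit_scaling as uu

# Constructor arguments are assumed to mirror torch.nn; check each class's
# documentation for exact defaults (e.g. whether a bias is used).
model = nn.Sequential(
    uu.Linear(128, 512),
    uu.GELU(),
    uu.Dropout(0.1),
    uu.LinearReadout(512, 10),  # final projection, scaled for network output
)
loss_fn = uu.CrossEntropyLoss()

x = torch.randn(32, 128)                     # unit-variance input
logits = model(x)
loss = loss_fn(logits, torch.randint(0, 10, (32,)))
loss.backward()

# The defining property: intermediate activations stay near unit variance.
h = model[1](model[0](x))
print(h.std())                               # approximately 1.0
```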
Modules
unit_scaling.core
    Core components for advanced library usage.
unit_scaling.functional
    Unit-scaled versions of common torch.nn.functional functions.
unit_scaling.optim
    Optimizer wrappers that apply scaling rules for u-μP.
…
    Extends …
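The functional module makes it easy to check the property that unit-scaled ops aim for: roughly unit-variance outputs and gradients given unit-variance inputs. Below is a minimal sketch, assuming unit_scaling.functional exposes a gelu that mirrors the basic torch.nn.functional.gelu(input) call signature.

```python
import torch

import unit_scaling.functional as U

x = torch.randn(4096, requires_grad=True)        # unit-variance input
y = U.gelu(x)                                     # assumed to mirror F.gelu(input)
y.backward(torch.randn_like(y))                   # unit-variance upstream gradient

print(f"output std {y.std().item():.2f}")         # close to 1.0
print(f"grad std   {x.grad.std().item():.2f}")    # also close to 1.0
```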