3.1. unit_scaling
Unit-scaled versions of common torch.nn modules.
Functions
Parameter
    Construct a u-μP parameter object, an annotated torch.nn.Parameter.

transformer_residual_scaling_rule
    Compute the residual tau ratios for the default transformer rule.

visualiser
    [Experimental] Generate a plot visualising the scales in the forward (and optionally backward) pass of all tensors in an arbitrary torch.nn.Module.
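The sketch below illustrates one way these functions might be used. It assumes the package is imported as unit_scaling and that Parameter accepts a tensor plus a mup_type label; the keyword name and its accepted values are assumptions here, so check the per-function pages for exact signatures.

```python
import torch
import torch.nn as nn

import unit_scaling as uu


class ScaledShift(nn.Module):
    """Toy module declaring its own u-μP parameter by hand."""

    def __init__(self, width: int) -> None:
        super().__init__()
        # uu.Parameter annotates the tensor so that the u-μP optimizer
        # wrappers can pick an appropriate learning-rate scale for it.
        # The `mup_type` keyword is an assumption; see the API reference.
        self.shift = uu.Parameter(torch.zeros(width), mup_type="bias")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.shift


# The default residual scaling rule for transformers; typically passed to a
# transformer or residual layer that accepts a residual-scaling argument.
rule = uu.transformer_residual_scaling_rule()
```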
Classes
Conv1d
    Applies a unit-scaled 1D convolution to the incoming data.

CrossEntropyLoss
    Computes a unit-scaled cross entropy loss between input logits and target.

DepthModuleList
    A torch.nn.ModuleList variant with depth-μP scaling of its layers.

DepthSequential
    A torch.nn.Sequential variant with depth-μP scaling of its layers.

Dropout
    A unit-scaled implementation of Dropout.

Embedding
    A unit-scaled lookup table that stores embeddings of a fixed dictionary and size.

GELU
    Applies a unit-scaled Gaussian Error Linear Units function.

LayerNorm
    Applies a unit-scaled Layer Normalization over a mini-batch of inputs.

Linear
    Applies a unit-scaled linear transformation to the incoming data.

LinearReadout
    Applies a unit-scaled linear transformation to the incoming data, scaled appropriately for the final network output.

MHSA
    A unit-scaled implementation of a multi-head self-attention layer.

MLP
    A unit-scaled implementation of an MLP layer using SwiGLU.

RMSNorm
    Applies a unit-scaled RMS normalisation over trailing dimensions.

SiLU
    Applies a unit-scaled Sigmoid Linear Unit function.

Softmax
    Applies a unit-scaled Softmax function to an n-dimensional input Tensor.

TransformerDecoder
    A unit-scaled implementation of a decoder-type transformer.

TransformerLayer
    A unit-scaled implementation of a PreNorm (see https://arxiv.org/abs/2002.04745) transformer layer.
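As a rough illustration of how these classes compose, the following sketch builds a tiny classifier from unit-scaled layers; constructor arguments are assumed to mirror the corresponding torch.nn modules, and exact signatures should be checked against the per-class pages.

```python
import torch
import torch.nn as nn

import unit_scaling as uu


class TinyClassifier(nn.Module):
    """Small model built entirely from unit-scaled layers."""

    def __init__(self, vocab_size: int, width: int, n_classes: int) -> None:
        super().__init__()
        self.embed = uu.Embedding(vocab_size, width)
        self.hidden = uu.Linear(width, width)
        self.act = uu.GELU()
        # LinearReadout applies the output-specific scaling intended for the
        # final layer of a unit-scaled network.
        self.readout = uu.LinearReadout(width, n_classes)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embed(tokens)               # (batch, seq, width)
        x = self.act(self.hidden(x))
        return self.readout(x.mean(dim=1))   # pool over the sequence


model = TinyClassifier(vocab_size=100, width=64, n_classes=10)
loss_fn = uu.CrossEntropyLoss()
tokens = torch.randint(0, 100, (8, 16))
targets = torch.randint(0, 10, (8,))
loss = loss_fn(model(tokens), targets)
loss.backward()
```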
Modules
unit_scaling.core
    Core components for advanced library usage.

unit_scaling.functional
    Unit-scaled versions of common torch.nn.functional functions.

unit_scaling.optim
    Optimizer wrappers that apply scaling rules for u-μP.

unit_scaling.parameter
    Extends torch.nn.Parameter to carry u-μP parameter metadata.
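The following sketch combines the functional and optim modules with a unit-scaled layer. It assumes unit_scaling.functional exposes a gelu op and unit_scaling.optim exposes an Adam wrapper that applies u-μP learning-rate scaling; both should be confirmed against the module pages.

```python
import torch

import unit_scaling as uu
import unit_scaling.functional as U
from unit_scaling.optim import Adam

# Unit-scaled functional ops mirror their torch.nn.functional counterparts.
layer = uu.Linear(64, 64)
opt = Adam(layer.parameters(), lr=1e-2)  # applies u-μP learning-rate scaling

x = torch.randn(8, 64)
loss = U.gelu(layer(x)).pow(2).mean()
loss.backward()
opt.step()
```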