3.1. unit_scaling

Unit-scaled versions of common torch.nn modules.

Functions

Parameter(data, mup_type[, mup_scaling_depth])

Construct a u-μP parameter object, an annotated torch.nn.Parameter.

transformer_residual_scaling_rule([...])

Compute the residual tau ratios for the default transformer rule.

visualiser(model, tokenizer, batch_size, seq_len)

[Experimental] Generate a plot visualising the scales in the forward (and optionally backward) pass of all tensors in an arbitrary torch.nn.Module.
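A minimal usage sketch for the functions above, assuming the import alias import unit_scaling as uu. The mup_type value "weight" is an assumed category name, and transformer_residual_scaling_rule is called with its defaults only, since its keyword arguments are elided in this listing:

    import torch
    import unit_scaling as uu

    # Annotate a unit-variance tensor as a u-μP parameter ("weight" is an assumed mup_type).
    w = uu.Parameter(torch.randn(128, 128), mup_type="weight")

    # Residual tau ratios for the default transformer rule, using default arguments only.
    residual_rule = uu.transformer_residual_scaling_rule()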

Classes

Conv1d(in_channels, out_channels, kernel_size)

Applies a unit-scaled 1D convolution to the incoming data.

CrossEntropyLoss([mult, weight, ...])

Computes a unit-scaled cross entropy loss between input logits and target.

DepthModuleList(modules)

A torch.nn.ModuleList that automatically configures the depth for the sake of scaling.

DepthSequential(*args)

A torch.nn.Sequential that automatically configures the depth for the sake of scaling.

Dropout([p, inplace])

A unit-scaled implementation of Dropout.

Embedding(num_embeddings, embedding_dim[, ...])

A unit-scaled lookup table that stores embeddings of a fixed dictionary and size.

GELU([mult, constraint, approximate])

Applies a unit-scaled Gaussian Error Linear Units (GELU) function.

LayerNorm(normalized_shape[, eps, ...])

Applies a unit-scaled Layer Normalization over a mini-batch of inputs.

Linear(in_features, out_features[, bias, ...])

Applies a unit-scaled linear transformation to the incoming data.

LinearReadout(in_features, out_features[, ...])

Applies a unit-scaled linear transformation to the incoming data, scaled appropriately for the final network output.

MHSA(hidden_size, heads, is_causal[, ...])

A unit-scaled implementation of a multi-head self-attention layer.

MLP(hidden_size[, expansion_factor])

A unit-scaled implementation of an MLP layer using SwiGLU.

RMSNorm(normalized_shape[, eps, ...])

Applies a unit-scaled RMS normalisation over trailing dimensions.

SiLU([mult, constraint, inplace])

Applies a unit-scaled Sigmoid Linear Unit (SiLU) function.

Softmax(dim[, mult, constraint])

Applies a unit-scaled Softmax function to an n-dimensional input Tensor.

TransformerDecoder(hidden_size, vocab_size, ...)

A unit-scaled implementation of a decoder-type transformer.

TransformerLayer(hidden_size, heads, ...[, ...])

A unit-scaled implementation of a PreNorm (see https://arxiv.org/abs/2002.04745) transformer layer.
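As a usage sketch, the classes above are intended as drop-in counterparts of their torch.nn equivalents and compose in the same way. The sizes below are illustrative, and only arguments shown in the signatures above are used:

    import torch
    import unit_scaling as uu

    # Illustrative sketch: stack unit-scaled layers inside a plain torch.nn container.
    model = torch.nn.Sequential(
        uu.Embedding(num_embeddings=1024, embedding_dim=256),
        uu.Linear(256, 256),
        uu.GELU(),
        uu.LinearReadout(256, 1024),  # scaled for the final network output
    )

    tokens = torch.randint(0, 1024, (2, 16))  # (batch, sequence)
    logits = model(tokens)                    # activations aim for unit scale at initialisation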

Modules

core

Core components for advanced library usage.

functional

Unit-scaled versions of common torch.nn.functional functions.

optim

Optimizer wrappers that apply scaling rules for u-μP.

parameter

Extends torch.nn.Parameter with attributes for u-μP.
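A rough sketch of how the functional and optim submodules combine with the modules above. It assumes torch-like signatures for unit_scaling.functional.linear and unit_scaling.functional.gelu, and assumes unit_scaling.optim provides an Adam wrapper; treat these names as assumptions rather than a verified API:

    import torch
    import unit_scaling as uu
    import unit_scaling.functional as U
    import unit_scaling.optim

    # Functional interface (assumed torch-like signatures).
    x = torch.randn(8, 64)
    w = uu.Parameter(torch.randn(32, 64), mup_type="weight")  # "weight" is an assumed mup_type
    y = U.gelu(U.linear(x, w))

    # Optimizer wrapper applying u-μP scaling rules (an Adam variant is assumed to exist).
    model = torch.nn.Sequential(uu.Linear(64, 64), uu.GELU())
    opt = unit_scaling.optim.Adam(model.parameters(), lr=1e-2)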