3.1.22. unit_scaling.functional

Unit-scaled versions of common torch.nn.functional functions.

Functions

add(input, other[, constraint, alpha, out])

Applies a unit-scaled addition.
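The idea behind a unit-scaled addition can be sketched with the standard library alone. This sketch assumes that both inputs are independent, zero-mean and unit-variance, and that `alpha` multiplies `other` as it does in `torch.add`. Under those assumptions `Var(input + alpha * other) = 1 + alpha**2`, so dividing by `sqrt(1 + alpha**2)` restores unit variance. The helper `scaled_add` is hypothetical and covers only the forward pass. It is not the library's implementation, which may scale differently, for example depending on `constraint` or in the backward pass.

```python
import math
import random
import statistics


def scaled_add(a, b, alpha=1.0):
    """Hypothetical sketch: a + alpha * b, rescaled toward unit variance.

    Assumes a and b are independent, zero-mean, unit-variance samples, so
    Var(a + alpha * b) = 1 + alpha**2.
    """
    scale = 1.0 / math.sqrt(1.0 + alpha**2)
    return [scale * (x + alpha * y) for x, y in zip(a, b)]


random.seed(0)
n = 20_000
a = [random.gauss(0.0, 1.0) for _ in range(n)]
b = [random.gauss(0.0, 1.0) for _ in range(n)]

naive = [x + y for x, y in zip(a, b)]  # std grows to about sqrt(2)
scaled = scaled_add(a, b)              # std stays close to 1
print(round(statistics.pstdev(naive), 2), round(statistics.pstdev(scaled), 2))
```

A plain sum of two unit-variance tensors drifts upward in scale. Folding a fixed factor into the operation keeps the output close to unit variance without any runtime statistics.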

conv1d(input, weight[, bias, stride, ...])

Applies a unit-scaled 1D convolution.

cross_entropy(input, target[, weight, ...])

Computes the unit-scaled cross entropy loss between input logits and target.

dropout(input[, p, training, inplace])

Applies a unit-scaled dropout function.

embedding(input, weight[, padding_idx, ...])

A unit-scaled lookup table that looks up embeddings in a fixed dictionary and size.

gelu(input[, mult, constraint, approximate])

Applies a unit-scaled GELU function.

layer_norm(input, normalized_shape[, ...])

Applies a unit-scaled Layer Normalization over the trailing dimensions given by normalized_shape.

linear(input, weight, bias[, constraint, ...])

Applies a unit-scaled linear transformation.

linear_readout(input, weight, bias[, constraint])

Applies a unit-scaled linear transformation intended for the final network output.

matmul(left, right[, constraint])

A unit-scaled matrix product of two tensors.
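The usual reasoning for scaling a matrix product runs as follows. Suppose the entries of both operands are independent, zero-mean and unit-variance. Each output entry is then a sum of `k` such products, where `k` is the shared inner dimension, so its variance is `k`. Dividing by `sqrt(k)` brings the output back to unit variance. The pure-Python helpers `plain_matmul` and `scaled_matmul` below are hypothetical, cover only the forward pass, and are not the library's API.

```python
import math
import random
import statistics


def plain_matmul(left, right):
    """Unscaled product of two matrices given as lists of rows."""
    cols = list(zip(*right))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in left]


def scaled_matmul(left, right):
    """Hypothetical sketch: divide the product by sqrt(k), the inner dimension."""
    k = len(right)
    scale = 1.0 / math.sqrt(k)
    return [[scale * v for v in row] for row in plain_matmul(left, right)]


def randn(rows, cols):
    """Matrix of independent standard-normal entries."""
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def flatten(mat):
    return [v for row in mat for v in row]


random.seed(0)
m, k, n = 64, 256, 64
left, right = randn(m, k), randn(k, n)

plain_std = statistics.pstdev(flatten(plain_matmul(left, right)))    # about sqrt(256) = 16
scaled_std = statistics.pstdev(flatten(scaled_matmul(left, right)))  # about 1
print(round(plain_std, 1), round(scaled_std, 2))
```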

mse_loss(input, target[, size_average, ...])

Computes the unit-scaled element-wise mean squared error.

residual_add(residual, skip[, tau])

Adds a residual tensor and a skip tensor together, with a relative weighting tau applied to the residual branch.
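One way to apply a relative weighting `tau` while keeping unit scale is to pick weights whose squares sum to 1 and whose squared ratio is `tau`. This sketch assumes independent, unit-variance residual and skip inputs. The weights `sqrt(tau / (1 + tau))` and `sqrt(1 / (1 + tau))` are an assumed choice that satisfies both conditions, and the helper `weighted_residual_add` is hypothetical. The library's exact weighting may differ.

```python
import math
import random
import statistics


def weighted_residual_add(residual, skip, tau=1.0):
    """Hypothetical tau-weighted residual add that preserves unit variance.

    Assumes `residual` and `skip` are independent, zero-mean and unit-variance.
    res_w**2 + skip_w**2 == 1 keeps the output at unit variance, and
    res_w**2 / skip_w**2 == tau sets the residual branch's relative weight.
    """
    res_w = math.sqrt(tau / (1.0 + tau))
    skip_w = math.sqrt(1.0 / (1.0 + tau))
    return [res_w * r + skip_w * s for r, s in zip(residual, skip)]


random.seed(0)
n = 20_000
residual = [random.gauss(0.0, 1.0) for _ in range(n)]
skip = [random.gauss(0.0, 1.0) for _ in range(n)]

for tau in (0.25, 1.0, 4.0):
    out = weighted_residual_add(residual, skip, tau)
    print(tau, round(statistics.pstdev(out), 2))  # each stays close to 1.0
```

Changing `tau` shifts how much of the output variance comes from the residual branch. The total scale does not change.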

residual_apply(fn, input[, tau])

Applies a weighted residual branch, maintaining unit scale.

residual_split(input[, tau])

Splits a tensor into a residual tensor and a skip tensor before they are used in a residual layer, with a relative weighting tau applied to the residual branch.

rms_norm(input, normalized_shape[, weight, eps])

Applies a unit-scaled RMS Normalization over the trailing dimensions given by normalized_shape.
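For reference, the forward computation of RMS normalization over a single trailing dimension can be sketched as follows. The input is divided by its root mean square, then multiplied elementwise by an optional weight. The helper `rms_norm_sketch` is hypothetical, handles only a 1-D list rather than a general `normalized_shape`, and uses an assumed `eps` default of `1e-6`. It does not model any extra scaling a unit-scaled implementation may apply, for example to gradients.

```python
import math


def rms_norm_sketch(x, weight=None, eps=1e-6):
    """Hypothetical RMS normalization of a 1-D list (forward pass only)."""
    # Root mean square of the input, with eps for numerical stability.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    y = [v / rms for v in x]
    # Optional elementwise weight, as in a learnable RMSNorm scale.
    if weight is not None:
        y = [w * v for w, v in zip(weight, y)]
    return y


x = [3.0, -4.0, 0.0, 5.0]  # mean square is 12.5
y = rms_norm_sketch(x)
print(math.sqrt(sum(v * v for v in y) / len(y)))  # approximately 1
```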

scaled_dot_product_attention(query, key, value)

A unit-scaled dot-product attention function.

silu(input[, mult, constraint, inplace])

Applies a unit-scaled SiLU function.

silu_glu(input, gate[, mult])

Applies a unit-scaled gated linear unit, computing input * silu(gate).

softmax(input, dim[, dtype, constraint, mult])

Applies a unit-scaled softmax function.