3.1.23. unit_scaling.optim

Optimizer wrappers that apply scaling rules for u-muP.

Provides Adam, AdamW, and SGD as out-of-the-box optimizers.

Alternatively, scaled_parameters() gives finer control: given a function that defines the LR scaling rules, it transforms a parameter group for use with any downstream optimizer.
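For the out-of-the-box path, the wrappers act as drop-in replacements for their torch.optim counterparts. A minimal sketch, assuming a model built from unit-scaled modules (the layer sizes and base lr are illustrative):

```python
import torch
import unit_scaling as uu
from unit_scaling.optim import AdamW

# A small model built from unit-scaled modules, ending in a readout layer.
model = torch.nn.Sequential(
    uu.Linear(256, 1024),
    torch.nn.GELU(),
    uu.LinearReadout(1024, 256),
)

# The wrapper applies the u-muP per-parameter LR scaling rules; under
# u-muP a base lr of order 1 is a sensible starting point.
opt = AdamW(model.parameters(), lr=1.0)

loss = model(torch.randn(32, 256)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```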

Functions

lr_scale_for_depth(param)

Calculate the LR scaling factor for depth only.
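The returned factor can be composed into a custom rule for scaled_parameters(). A sketch, where custom_width_scale is a hypothetical user-defined rule, not part of the library:

```python
from torch.nn import Parameter

from unit_scaling.optim import lr_scale_for_depth

def custom_width_scale(param: Parameter) -> float:
    # Hypothetical user-defined width rule (a no-op placeholder here).
    return 1.0

def custom_lr_scale(param: Parameter) -> float:
    # Combine the library's depth-only rule with the custom width rule above.
    return lr_scale_for_depth(param) * custom_width_scale(param)
```

custom_lr_scale can then be passed as the lr_scale_func argument of scaled_parameters().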

lr_scale_func_adam(param)

Calculate the LR scaling factor for torch.optim.Adam and torch.optim.AdamW.

lr_scale_func_sgd(readout_constraint)

Calculate the LR scaling factor for torch.optim.SGD.
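Unlike lr_scale_func_adam, the signature suggests this is called with readout_constraint up front and returns the per-parameter scale function. A sketch under that assumption ("to_output_scale" is an illustrative constraint value):

```python
from unit_scaling.optim import lr_scale_func_sgd

# Build the per-parameter SGD scaling rule; the value passed here should
# match the `constraint` arg of the model's LinearReadout (the specific
# value "to_output_scale" is illustrative).
sgd_scale_func = lr_scale_func_sgd("to_output_scale")
```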

scaled_parameters(params, lr_scale_func[, ...])

Create optimizer-appropriate lr-scaled parameter groups.
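For example, to keep a stock torch.optim optimizer while still applying the u-muP LR rules. A minimal sketch, assuming the base lr can be supplied to scaled_parameters() (the value is illustrative):

```python
import torch
import unit_scaling as uu
from unit_scaling.optim import lr_scale_func_adam, scaled_parameters

model = torch.nn.Sequential(
    uu.Linear(256, 1024),
    uu.LinearReadout(1024, 256),
)

# Transform the parameters into lr-scaled groups, then hand them to any
# downstream optimizer (here, the stock torch.optim.Adam).
params = scaled_parameters(model.parameters(), lr_scale_func_adam, lr=1.0)
opt = torch.optim.Adam(params)
```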

Classes

Adam(params[, lr, weight_decay, ...])

An lr-scaled version of torch.optim.Adam for u-muP.

AdamW(params[, lr, weight_decay, ...])

An lr-scaled version of torch.optim.AdamW for u-muP.

SGD(params[, lr, weight_decay, ...])

An lr-scaled version of torch.optim.SGD for u-muP. `readout_constraint` should match the constraint arg used in LinearReadout.
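For example, keeping the two constraints in sync. A sketch, assuming readout_constraint is accepted as a keyword argument and using "to_output_scale" as an illustrative value:

```python
import torch
import unit_scaling as uu
from unit_scaling.optim import SGD

model = torch.nn.Sequential(
    uu.Linear(256, 1024),
    uu.LinearReadout(1024, 256, constraint="to_output_scale"),
)

# `readout_constraint` must match the readout layer's `constraint` above.
opt = SGD(model.parameters(), lr=1.0, readout_constraint="to_output_scale")
```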