3.1.23.4. unit_scaling.optim.scaled_parameters
- unit_scaling.optim.scaled_parameters(params: Iterable[Tensor] | Iterable[Dict[str, Any]], lr_scale_func: Callable[[ParameterData], float], lr: None | float | Tensor = None, weight_decay: float = 0, independent_weight_decay: bool = True, allow_non_unit_scaling_params: bool = False) → Iterable[Tensor] | Iterable[Dict[str, Any]]
Create optimizer-appropriate lr-scaled parameter groups.
This function creates param_groups that apply the relevant scaling factors for u-muP models. For example:
```python
torch.optim.Adam(uu.optim.scaled_parameters(
    model.parameters(), uu.optim.adam_lr_scale_func, lr=1.0
))
```
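As a slightly fuller sketch: the model below is purely illustrative (it assumes only that uu.Linear mirrors torch.nn.Linear), and the lr_scale_func name follows the example above.

```python
import torch
import unit_scaling as uu

# Illustrative unit-scaled model; assumes uu.Linear mirrors torch.nn.Linear.
model = torch.nn.Sequential(uu.Linear(16, 32), uu.Linear(32, 16))

# Build lr-scaled param_groups, then hand them to a standard torch optimizer.
param_groups = uu.optim.scaled_parameters(
    model.parameters(),
    uu.optim.adam_lr_scale_func,  # name as used in the example above
    lr=1.0,
    weight_decay=1e-2,
)
opt = torch.optim.Adam(param_groups)
```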
- Parameters:
params (ParamsT) – an iterable of parameters or parameter groups, as passed to a torch optimizer.
lr_scale_func (Callable) – returns the optimizer-appropriate learning rate scale for a parameter tagged with mup_type and mup_scaling_depth, for example lr_scale_func_sgd(); a minimal custom sketch is given after this entry.
lr (float, optional) – global learning rate (may be overridden per parameter group).
weight_decay (float, optional) – weight decay value (may be overridden per parameter group).
independent_weight_decay (bool, optional) – enable lr-independent weight decay, which applies a per-step decay update that does not depend on lr.
allow_non_unit_scaling_params (bool, optional) – by default, this function raises an error if passed any regular, non-unit-scaled params; set to True to disable this check.
- Returns:
lr-scaled parameters or parameter groups, for passing on to the optimizer.
- Return type:
ParamsT
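To make the lr_scale_func contract concrete, here is a minimal custom sketch. It assumes only what is stated above (a parameter tagged with mup_type and mup_scaling_depth); the depth-based rule itself is an illustrative placeholder, not one of the library's built-in rules such as lr_scale_func_sgd().

```python
def my_lr_scale_func(param) -> float:
    # `param` is the ParameterData from the signature above: a parameter
    # tagged with `mup_type` and `mup_scaling_depth`.
    # Illustrative placeholder rule (not the library's built-in u-muP rule):
    # shrink the learning rate of depth-scaled parameters by 1/sqrt(depth).
    depth = getattr(param, "mup_scaling_depth", None)
    return depth ** -0.5 if depth else 1.0
```

Such a function could then be passed in place of a built-in helper, e.g. scaled_parameters(model.parameters(), my_lr_scale_func, lr=0.1).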