3.1.23.4. unit_scaling.optim.scaled_parameters

unit_scaling.optim.scaled_parameters(params: Iterable[Tensor] | Iterable[Dict[str, Any]], lr_scale_func: Callable[[ParameterData], float], lr: None | float | Tensor = None, weight_decay: float = 0, independent_weight_decay: bool = True, allow_non_unit_scaling_params: bool = False) → Iterable[Tensor] | Iterable[Dict[str, Any]]

Create optimizer-appropriate lr-scaled parameter groups.

This function creates param_groups that apply the relevant learning-rate scaling factors for u-muP models. For example:

import torch
import unit_scaling as uu

torch.optim.Adam(uu.optim.scaled_parameters(
    model.parameters(), uu.optim.lr_scale_func_adam, lr=1.0
))
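
The same pattern works with parameter groups, continuing from the imports above; per-group values override the global lr and weight_decay. This is a minimal sketch: the layer sizes, group split, and hyperparameter values are illustrative, and it assumes a model built from unit-scaled modules such as unit_scaling.Linear so that its parameters carry the mup_type / mup_scaling_depth tags:

# Illustrative model built from unit-scaled layers; sizes are arbitrary.
model = torch.nn.Sequential(uu.Linear(256, 1024), uu.Linear(1024, 256))

# Per-group options override the globals passed to scaled_parameters:
# the first group keeps lr=0.5, the second falls back to the global lr=1.0.
param_groups = [
    {"params": model[0].parameters(), "lr": 0.5},
    {"params": model[1].parameters()},
]

torch.optim.AdamW(uu.optim.scaled_parameters(
    param_groups, uu.optim.lr_scale_func_adam, lr=1.0, weight_decay=0.01
))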
Parameters:
  • params (ParamsT) – an iterable of parameters or parameter groups, as passed to a torch optimizer.

  • lr_scale_func (Callable) – returns the optimizer-appropriate learning-rate scale for a parameter tagged with mup_type and mup_scaling_depth. For example, lr_scale_func_sgd().

  • lr (float, optional) – global learning rate (overridden by groups).

  • weight_decay (float, optional) – weight decay value (overridden by groups).

  • independent_weight_decay (bool, optional) – enable lr-independent weight decay, which applies a per-step weight-decay update that does not depend on lr.

  • allow_non_unit_scaling_params (bool, optional) – by default, this function fails if passed any regular non-unit-scaled params; set to True to disable this check.

Returns:

scaled parameters or parameter groups, for passing on to the optimizer.

Return type:

ParamsT
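
For an optimizer not covered by the provided scale functions, a custom lr_scale_func can be supplied. The sketch below is illustrative only: it relies just on the mup_scaling_depth tag described above (a real scale function would typically also branch on mup_type), and the placeholder scale rule is an assumption, not the library's built-in behaviour.

import math

import torch
import unit_scaling as uu

def my_lr_scale_func(param) -> float:
    # Placeholder rule, not the library's built-in scaling: parameters with a
    # tagged residual scaling depth get their update scaled down with depth;
    # everything else keeps the global lr.
    depth = getattr(param, "mup_scaling_depth", None)
    return 1.0 if depth is None else 1.0 / math.sqrt(depth)

layer = uu.Linear(128, 128)
torch.optim.RMSprop(uu.optim.scaled_parameters(
    layer.parameters(), my_lr_scale_func, lr=0.01
))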