optimisers
Custom torch optimisers.
- class LARS(params: Iterable[Any] | dict[Any, Any], lr: float, momentum: float = 0.9, weight_decay: float = 0.0005, eta: float = 0.001, max_epoch: int = 200)
Implements layer-wise adaptive rate scaling for SGD.
Source: https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
- Parameters:
params (Iterable | dict) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – base learning rate (gamma_0)
momentum (float, optional) – momentum factor (default: 0.9) (“m”)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0.0005) (“beta”)
eta (float, optional) – LARS coefficient (default: 0.001)
max_epoch (int) – maximum training epoch, used to determine the polynomial LR decay (default: 200).
Based on Algorithm 1 of “Large Batch Training of Convolutional Networks” by You, Gitman, and Ginsburg: https://arxiv.org/abs/1708.03888. A minimal sketch of the per-layer update follows the example below.
Example
>>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
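To make the layer-wise scaling concrete, here is a minimal sketch of one LARS update on a single tensor, following Algorithm 1 of the paper cited above. This is not this module's implementation: the lars_update helper, its signature, and the small epsilon added for numerical safety are illustrative assumptions, and the globally decayed learning rate is simply passed in as global_lr.

import torch

def lars_update(param: torch.Tensor,
                grad: torch.Tensor,
                buf: torch.Tensor,
                global_lr: float,
                momentum: float = 0.9,
                weight_decay: float = 0.0005,
                eta: float = 0.001) -> None:
    """Apply one LARS update to `param` in place; `buf` is the momentum buffer."""
    w_norm = param.norm()
    g_norm = grad.norm()
    # Layer-wise trust ratio from the paper: eta * ||w|| / (||g|| + beta * ||w||).
    # The 1e-12 epsilon is an added safeguard, not part of Algorithm 1.
    local_lr = eta * w_norm / (g_norm + weight_decay * w_norm + 1e-12)
    # Momentum on the scaled, weight-decayed gradient, then the parameter update.
    buf.mul_(momentum).add_(global_lr * local_lr * (grad + weight_decay * param))
    param.sub_(buf)

# Example on a single weight tensor:
w = torch.randn(10, 10, requires_grad=True)
(w ** 2).sum().backward()
buf = torch.zeros_like(w)
with torch.no_grad():
    lars_update(w, w.grad, buf, global_lr=0.1)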
- step(epoch: int | None = None, closure: Callable[[...], Any] | None = None)
Performs a single optimization step.
- Parameters:
closure (callable, optional) – A closure that reevaluates the model and returns the loss.
epoch (int, optional) – Current epoch, used to calculate the polynomial LR decay schedule. If None, uses self.epoch and increments it. See the usage sketch below.
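For example, the epoch can be driven from the training loop by passing it to step() explicitly. The sketch below is a hedged usage example: the import path is assumed from this page's module name, the model and data are placeholders, and the decay is assumed to take the paper's power-2 polynomial form, gamma_t = lr * (1 - epoch / max_epoch) ** 2.

import torch
from torch import nn
from optimisers import LARS  # import path assumed from this page's module name

model = nn.Linear(4, 1)
optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3, max_epoch=10)
loss_fn = nn.MSELoss()
x, y = torch.randn(32, 4), torch.randn(32, 1)

for epoch in range(10):
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    # Passing epoch drives the polynomial decay explicitly; omitting it falls
    # back to the optimiser's internal counter (self.epoch), which is then
    # incremented.
    optimizer.step(epoch=epoch)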