optimisers

Custom torch optimisers.

class LARS(params: Iterable[Any] | Dict[Any, Any], lr: float, momentum: float = 0.9, weight_decay: float = 0.0005, eta: float = 0.001, max_epoch: int = 200)

Implements layer-wise adaptive rate scaling for SGD.

Source: https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py

Parameters:
  • params (Iterable | dict) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – base learning rate (gamma_0)

  • momentum (float, optional) – momentum factor (default: 0.9) ("m")

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0.0005) ("beta")

  • eta (float, optional) – LARS coefficient (default: 0.001)

  • max_epoch (int, optional) – maximum training epoch, used to determine the polynomial LR decay (default: 200)
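The global learning rate in Algorithm 1 of You, Gitman, and Ginsburg follows a polynomial decay of power 2, gamma_t = gamma_0 * (1 - t / T)^2. The sketch below assumes this class uses the same schedule with T = max_epoch; the function name `poly_lr` is hypothetical, and the sketch does not clamp epochs beyond max_epoch.

```python
def poly_lr(lr: float, epoch: int, max_epoch: int) -> float:
    """Hypothetical helper: power-2 polynomial decay of the base LR (gamma_0).

    Assumes the decay from Algorithm 1 of the LARS paper:
    gamma_t = gamma_0 * (1 - t / T) ** 2, with T = max_epoch.
    """
    return lr * (1 - epoch / max_epoch) ** 2


# The LR starts at lr, reaches a quarter of it halfway through, and 0 at max_epoch.
print(poly_lr(0.1, 0, 200), poly_lr(0.1, 100, 200), poly_lr(0.1, 200, 200))
```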

Based on Algorithm 1 of You, Gitman, and Ginsburg, "Large Batch Training of Convolutional Networks": https://arxiv.org/abs/1708.03888
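Per layer, Algorithm 1 computes a local learning rate (trust ratio) lambda = eta * ||w|| / (||g|| + beta * ||w||), then updates the velocity v <- m * v + gamma_t * lambda * (g + beta * w) and the weights w <- w - v. The following is a minimal pure-Python sketch of that single-layer update on flat lists, not this class's actual implementation; the function names and the zero-denominator fallback are assumptions.

```python
import math
from typing import List, Tuple


def l2_norm(x: List[float]) -> float:
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(xi * xi for xi in x))


def lars_layer_step(
    w: List[float],
    g: List[float],
    v: List[float],
    global_lr: float,
    momentum: float = 0.9,
    weight_decay: float = 0.0005,
    eta: float = 0.001,
) -> Tuple[List[float], List[float]]:
    """Hypothetical single-layer LARS update; returns (new_weights, new_velocity)."""
    w_norm = l2_norm(w)
    g_norm = l2_norm(g)
    # Local LR (trust ratio): eta * ||w|| / (||g|| + beta * ||w||).
    denom = g_norm + weight_decay * w_norm
    # Assumption: use 0 when the denominator vanishes (not covered by the paper).
    local_lr = eta * w_norm / denom if denom > 0 else 0.0
    # Velocity: v <- m * v + gamma_t * lambda * (g + beta * w).
    new_v = [
        momentum * vi + global_lr * local_lr * (gi + weight_decay * wi)
        for vi, gi, wi in zip(v, g, w)
    ]
    # Weights: w <- w - v.
    new_w = [wi - vi for wi, vi in zip(w, new_v)]
    return new_w, new_v
```

For example, with w = [3.0, 4.0], g = [0.6, 0.8], weight_decay = 0 and eta = 0.1, the trust ratio is 0.1 * 5 / 1 = 0.5 and the weights move to roughly [2.7, 3.6]. Scaling g by 10 leaves the step unchanged, which is the point of normalising by the gradient norm layer by layer.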

Example

>>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

step(epoch: int | None = None, closure: Callable[[...], Any] | None = None)

Performs a single optimization step.

Parameters:
  • epoch (int, optional) – Current epoch, used to compute the polynomial LR decay schedule. If None, uses self.epoch and increments it.

  • closure (callable, optional) – A closure that reevaluates the model and returns the loss.
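Because a call with epoch=None uses self.epoch and then increments it, every such call advances the decay schedule; if step() is called once per mini-batch, the counter counts batches rather than epochs. The hypothetical stand-in below reproduces only that documented counter logic, assuming the counter starts at 0 and the power-2 polynomial decay of Algorithm 1; it is not the real class.

```python
from typing import Optional


class EpochTracker:
    """Hypothetical stand-in mimicking only the documented epoch handling of LARS.step."""

    def __init__(self, lr: float, max_epoch: int) -> None:
        self.lr = lr
        self.max_epoch = max_epoch
        self.epoch = 0  # assumption: the internal counter starts at 0

    def step(self, epoch: Optional[int] = None) -> float:
        # Documented behaviour: if epoch is None, use self.epoch and increment it.
        if epoch is None:
            epoch = self.epoch
            self.epoch += 1
        # Assumed power-2 polynomial decay; returns the global LR for this step.
        return self.lr * (1 - epoch / self.max_epoch) ** 2


tracker = EpochTracker(lr=0.1, max_epoch=200)
per_call = [tracker.step() for _ in range(3)]  # schedule advances on every call
```

To decay once per epoch when stepping per mini-batch, pass the outer-loop epoch explicitly, e.g. optimizer.step(epoch=epoch).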