pytorchtools

Module containing EarlyStopping, a helper that tracks when the training of a model should stop.

Source: https://github.com/Bjarten/early-stopping-pytorch

class EarlyStopping(patience: int = 7, verbose: bool = False, delta: float = 0.0, path: str | pathlib.Path = 'checkpoint.pt', trace_func: Callable[..., None] = print)

Stops training early if the validation loss does not improve within a given number of epochs (the patience).
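The stopping rule can be sketched in plain Python. This is a simplified, stdlib-only illustration, not the module's source: the class name SimpleEarlyStopping, its step method, the strict "<" comparison, and the hard-coded loss sequence are all hypothetical, and checkpoint saving is left out.

```python
class SimpleEarlyStopping:
    """Hypothetical, minimal re-creation of the patience logic (no checkpointing)."""

    def __init__(self, patience: int = 7, delta: float = 0.0) -> None:
        self.patience = patience
        self.delta = delta
        self.counter = 0        # epochs since the last improvement
        self.best_loss = None   # lowest validation loss accepted so far
        self.early_stop = False

    def step(self, val_loss: float) -> None:
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement: remember the new best loss and reset the counter.
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No improvement by more than delta: count this epoch.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


# Hypothetical validation losses, one per epoch.
losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68]
stopper = SimpleEarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(losses):
    stopper.step(loss)
    if stopper.early_stop:
        stopped_at = epoch
        break

print(stopped_at)  # 5: three epochs in a row failed to beat 0.65
```

The training loop checks the early_stop flag after each validation pass and breaks out as soon as it becomes True.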

patience

Number of epochs to wait after the last improvement in validation loss before stopping.

Type:

int

verbose

If True, prints a message for each validation loss improvement.

Type:

bool

counter

Number of consecutive epochs since the validation loss last improved.

Type:

int

best_score

Best validation loss score recorded.

Type:

float

early_stop

Set to True once the validation loss has failed to improve for patience consecutive validation epochs.

Type:

bool

val_loss_min

The lowest validation loss recorded.

Type:

float

delta

Minimum change in the monitored quantity to qualify as an improvement.

Type:

float

path

Path to which the checkpoint is saved.

Type:

str

trace_func

Function used to emit status messages (print by default).

Type:

Callable[…, None]
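To see how delta interacts with counter and val_loss_min, the hypothetical helper below accepts a new loss as an improvement only if it beats the best loss so far by more than delta. The function name count_stale_epochs, the strict inequality, and the loss values are assumptions chosen for illustration.

```python
import math
from typing import List, Tuple


def count_stale_epochs(losses: List[float], delta: float) -> Tuple[float, int]:
    """Return (lowest accepted loss, epochs since the last accepted improvement).

    Hypothetical helper: a loss counts as an improvement only when it is
    lower than the best loss so far by more than `delta`.
    """
    val_loss_min = math.inf
    counter = 0
    for loss in losses:
        if loss < val_loss_min - delta:
            val_loss_min = loss  # accepted improvement
            counter = 0
        else:
            counter += 1         # too small a decrease, or an increase
    return val_loss_min, counter


losses = [1.00, 0.95, 0.94, 0.935]
print(count_stale_epochs(losses, delta=0.0))   # (0.935, 0): every decrease counts
print(count_stale_epochs(losses, delta=0.02))  # (0.95, 2): small decreases are ignored
```

With delta=0.02 the drops to 0.94 and 0.935 are smaller than the threshold, so they advance the counter instead of resetting it.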

Parameters:
  • patience (int) – Number of epochs to wait after the last improvement in validation loss before stopping. Default: 7

  • verbose (bool) – If True, prints a message for each validation loss improvement. Default: False

  • delta (float) – Minimum change in the monitored quantity to qualify as an improvement. Default: 0

  • path (str | pathlib.Path) – Path to which the checkpoint is saved. Default: 'checkpoint.pt'

  • trace_func (Callable[..., None]) – Function used to emit status messages. Default: print
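Because trace_func is any callable that accepts a message the way print does, output can be redirected, for example to the standard logging module or into a list for later inspection. The report function below is a hypothetical stand-in for the places where the class emits messages, and its message text is illustrative.

```python
import logging
from typing import Callable, List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("early_stopping")


def report(counter: int, patience: int, trace_func: Callable[..., None] = print) -> None:
    """Hypothetical stand-in: emit a progress message through trace_func."""
    trace_func(f"EarlyStopping counter: {counter} out of {patience}")


# Default: the message goes to stdout via print().
report(1, 7)

# Route the message through logging instead.
report(2, 7, trace_func=logger.info)

# Collect messages in a list, e.g. to assert on them in tests.
messages: List[str] = []
report(3, 7, trace_func=messages.append)
print(messages)  # ['EarlyStopping counter: 3 out of 7']
```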

save_checkpoint(val_loss: float, model: Module) -> None

Saves the model when the validation loss decreases.

Parameters:
  • val_loss (float) – Validation loss.

  • model (Module) – The model to checkpoint.
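A save-on-improvement step can be sketched without PyTorch. With a real torch.nn.Module, a checkpoint would typically be written with torch.save(model.state_dict(), path); here, as an assumption to keep the example dependency-free, a plain dict stands in for the state dict and pickle stands in for torch.save. The class CheckpointSaver and its log message are hypothetical.

```python
import math
import os
import pickle
import tempfile
from typing import Any, Callable, Dict


class CheckpointSaver:
    """Hypothetical sketch of save-on-improvement behaviour."""

    def __init__(self, path: str, verbose: bool = False,
                 trace_func: Callable[..., None] = print) -> None:
        self.path = path
        self.verbose = verbose
        self.trace_func = trace_func
        self.val_loss_min = math.inf  # no loss recorded yet

    def save_checkpoint(self, val_loss: float, state: Dict[str, Any]) -> None:
        """Write `state` to disk and record `val_loss` as the new minimum."""
        if self.verbose:
            self.trace_func(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
            )
        # Stand-in for torch.save(model.state_dict(), self.path).
        with open(self.path, "wb") as f:
            pickle.dump(state, f)
        self.val_loss_min = val_loss


path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
saver = CheckpointSaver(path, verbose=True)
saver.save_checkpoint(0.42, {"weight": [0.1, 0.2]})

# Reload the checkpoint to confirm it round-trips.
with open(path, "rb") as f:
    restored = pickle.load(f)
```

Saving only on improvement means the file on disk always holds the weights from the epoch with the lowest validation loss seen so far, which is what you would reload after early stopping triggers.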