Source code for galax.nn.utils.early_stopping

class EarlyStopping:
    """Early stopping on one or more monitored losses.

    An instance is called once per evaluation step with the current losses
    and the current parameters. It tracks the element-wise best loss seen so
    far, keeps the parameters from the most recent step at which *every*
    loss was at least as good as its best, and signals a stop after
    ``patience`` consecutive steps in which *no* loss was at least as good
    as its best.

    Parameters
    ----------
    patience : int = 10
        Patience for early stopping: the number of consecutive
        non-improving calls after which ``__call__`` returns ``True``.

    Attributes
    ----------
    best_losses : list or None
        Element-wise minimum of the losses seen so far; ``None`` before the
        first call.
    params : object or None
        Parameters saved at the first call, then replaced whenever all
        losses are no worse than ``best_losses``.
    counter : int
        Number of consecutive calls in which no loss matched or beat its
        best value.
    """

    # Class-level defaults; every update goes through ``self`` and therefore
    # creates a per-instance attribute instead of mutating shared state.
    best_losses = None
    params = None
    counter = 0

    def __init__(self, patience: int = 10):
        """Create a tracker that stops after ``patience`` stale calls."""
        self.patience = patience

    def __call__(self, losses, params):
        """Record one evaluation step and report whether to stop.

        Parameters
        ----------
        losses : iterable
            Current loss values, one per monitored quantity, in the same
            order on every call.
        params : object
            Current parameters; stored on ``self.params`` when this step is
            the first one or is no worse than the best on every loss.

        Returns
        -------
        bool
            ``True`` once ``counter`` has reached ``patience``, ``False``
            otherwise.
        """
        # Snapshot the input so later caller-side mutation (or a one-shot
        # iterable) cannot corrupt the comparison baseline.
        losses = list(losses)

        if self.best_losses is None:
            # First observation is the baseline. Keep its parameters too so
            # ``self.params`` is never None after the tracker has been used.
            self.best_losses = losses
            self.params = params
            self.counter = 0
            return False

        # A tie counts as "not worse" (``<=``), as in the original logic.
        not_worse = [
            loss <= best_loss
            for loss, best_loss in zip(losses, self.best_losses)
        ]

        if any(not_worse):
            # Only a step that is no worse on every loss replaces the saved
            # parameters; a partial improvement just resets the counter.
            if all(not_worse):
                self.params = params
            self.best_losses = [
                min(loss, best_loss)
                for loss, best_loss in zip(losses, self.best_losses)
            ]
            self.counter = 0
            return False

        self.counter += 1
        # ``>=`` rather than ``==``: keep signalling a stop if the caller
        # continues past the threshold, and honour ``patience <= 0``.
        return self.counter >= self.patience