Skip to content

ecnet.callbacks

Training callback objects/functions

Callback

Base Callback object

CallbackOperator

CallbackOperator: executes the individual callback steps at each stage of training

LRDecayLinear

__init__(self, init_lr, decay_rate, optimizer) special

Linear learning rate decay

Parameters:

Name Type Description Default
init_lr float

initial learning rate

required
decay_rate float

decay per epoch

required
optimizer torch.optim.Adam

optimizer used for training

required
Source code in ecnet/callbacks.py
def __init__(self, init_lr: float, decay_rate: float, optimizer):
    """
    Linear learning rate decay

    The learning rate for epoch ``e`` is ``init_lr - e * decay_rate``,
    clamped at zero by ``on_epoch_begin``.

    Args:
        init_lr (float): learning rate at epoch 0
        decay_rate (float): amount subtracted from the learning rate per epoch
        optimizer (torch.optim.Adam): optimizer whose param groups receive
            the decayed learning rate
    """
    super().__init__()
    # Optimizer first, then the schedule parameters it will be driven by
    self.optimizer = optimizer
    self._init_lr, self._decay = init_lr, decay_rate

on_epoch_begin(self, epoch)

Training halted if: new learning rate == 0.0

Source code in ecnet/callbacks.py
def on_epoch_begin(self, epoch: int) -> bool:
    """
    Apply the linearly decayed learning rate for this epoch.

    Training halted if:
        new learning rate == 0.0

    Args:
        epoch (int): current epoch index

    Returns:
        bool: False once the decayed learning rate has reached zero (the
            optimizer's param groups are then left untouched), True otherwise
    """

    # Argument order matters: with 0.0 first, a NaN schedule value yields 0.0
    decayed = max(0.0, self._init_lr - epoch * self._decay)
    if decayed != 0.0:
        for group in self.optimizer.param_groups:
            group['lr'] = decayed
        return True
    return False

Validator

__init__(self, loader, model, eval_iter, patience) special

Periodic validation using training data subset

Parameters:

Name Type Description Default
loader torch.utils.data.DataLoader

validation set

required
model ecnet.ECNet

model being trained

required
eval_iter int

validation set is evaluated every `eval_iter` epochs

required
patience int

if a new lowest validation loss is not found within this many epochs, terminate training and set model parameters to those observed at the lowest validation loss

required
Source code in ecnet/callbacks.py
def __init__(self, loader, model, eval_iter: int, patience: int):
    """
    Periodic validation using training data subset

    Args:
        loader (torch.utils.data.DataLoader): validation set
        model (ecnet.ECNet): model being trained
        eval_iter (int): validation set evaluated every `this` many epochs
        patience (int): if new lowest validation loss not found after `this` many epochs,
            terminate training, set model parameters to those observed @ lowest validation loss
    """

    import copy

    super().__init__()
    self.loader = loader
    self.model = model
    self._ei = eval_iter
    # Large sentinels so that the first real validation loss counts as an improvement
    self._best_loss = sys.maxsize
    self._most_recent_loss = sys.maxsize
    # Epochs elapsed since the last improvement in validation loss
    self._epoch_since_best = 0
    # state_dict() returns tensors that share storage with the live parameters.
    # Deep-copy so in-place optimizer updates do not silently overwrite the snapshot.
    self.best_state = copy.deepcopy(model.state_dict())
    self._patience = patience

on_epoch_end(self, epoch)

Training halted if: number of epochs since the last lowest validation MAE > specified patience

Source code in ecnet/callbacks.py
def on_epoch_end(self, epoch: int) -> bool:
    """
    Every `eval_iter` epochs, evaluate the model on the validation set and
    snapshot its parameters whenever the validation loss improves.

    Training halted if:
        number of epochs since last lowest validation loss > specified patience

    Args:
        epoch (int): current epoch index

    Returns:
        bool: False if training should stop (patience exhausted), True otherwise
    """

    import copy
    import torch

    # Only evaluate every `eval_iter` epochs; skipped epochs never halt training
    if epoch % self._ei != 0:
        return True

    # Evaluate in inference mode. eval() disables train-only layers (e.g. dropout),
    # and no_grad() avoids building an autograd graph that would otherwise stay
    # alive through self._best_loss / self._most_recent_loss. The caller's
    # train/eval mode is restored afterward.
    was_training = self.model.training
    self.model.eval()
    valid_loss = 0.0
    try:
        with torch.no_grad():
            for batch in self.loader:
                v_pred = self.model(batch['desc_vals'])
                v_target = batch['target_val']
                v_loss = self.model.loss(v_pred, v_target)
                # Weight by batch size. Assumes model.loss returns a per-batch
                # mean, so the total / dataset size is a dataset-wide mean.
                valid_loss += v_loss * len(v_target)
    finally:
        self.model.train(was_training)
    valid_loss /= len(self.loader.dataset)
    self._most_recent_loss = valid_loss

    if valid_loss < self._best_loss:
        self._best_loss = valid_loss
        # state_dict() shares storage with the live parameters; deep-copy so
        # later optimizer steps cannot overwrite the best-state snapshot
        self.best_state = copy.deepcopy(self.model.state_dict())
        self._epoch_since_best = 0
        return True

    # No improvement: `eval_iter` epochs have passed since the last check
    self._epoch_since_best += self._ei
    return self._epoch_since_best <= self._patience

on_train_end(self)

After training, restore the weights from the point at which the lowest validation MAE occurred

Source code in ecnet/callbacks.py
def on_train_end(self) -> bool:
    """
    Load the stored best parameters back into the model.

    After training, the model's weights are replaced by ``best_state``, i.e.
    the weights recorded at the lowest validation loss seen.

    Returns:
        bool: always True (never halts anything)
    """

    best = self.best_state
    model = self.model
    model.load_state_dict(best)
    return True