ecnet.callbacks
Training callback objects/functions
Callback
Base Callback object
CallbackOperator
Operator that executes each registered callback at every stage of training
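This page documents only the epoch-level hooks below. As a minimal sketch of how the pattern fits together (the `all()`-based dispatch and the exact hook set are assumptions, not ECNet's verbatim source), the base `Callback` exposes no-op hooks that return `True` to continue training, and `CallbackOperator` halts training as soon as any registered callback returns `False`:

```python
# Minimal sketch of the callback dispatch pattern (assumed, not verbatim
# ECNet source): every hook returns True to continue or False to halt.
from typing import List


class Callback:
    """Base callback: subclasses override only the hooks they need."""

    def on_epoch_begin(self, epoch: int) -> bool:
        return True

    def on_epoch_end(self, epoch: int) -> bool:
        return True

    def on_train_end(self) -> bool:
        return True


class CallbackOperator:
    """Runs every registered callback at each training stage."""

    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks

    def on_epoch_begin(self, epoch: int) -> bool:
        # Training halts if any callback signals False
        return all(cb.on_epoch_begin(epoch) for cb in self.callbacks)

    def on_epoch_end(self, epoch: int) -> bool:
        return all(cb.on_epoch_end(epoch) for cb in self.callbacks)

    def on_train_end(self) -> bool:
        return all(cb.on_train_end() for cb in self.callbacks)
```

Note that `all()` short-circuits, so callbacks after the first `False` are skipped for that stage; an operator that must run every callback regardless would accumulate the flags in a loop instead.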
LRDecayLinear
__init__(self, init_lr, decay_rate, optimizer)
Linear learning rate decay
Parameters:
Name | Type | Description | Default
---|---|---|---
init_lr | float | initial learning rate | required
decay_rate | float | decay per epoch | required
optimizer | torch.optim.Adam | optimizer used for training | required
Source code in ecnet/callbacks.py
```python
def __init__(self, init_lr: float, decay_rate: float, optimizer):
    """
    Linear learning rate decay

    Args:
        init_lr (float): initial learning rate
        decay_rate (float): decay per epoch
        optimizer (torch.optim.Adam): optimizer used for training
    """

    super().__init__()
    self._init_lr = init_lr
    self._decay = decay_rate
    self.optimizer = optimizer
```
on_epoch_begin(self, epoch)
Training halted if: new learning rate == 0.0
Source code in ecnet/callbacks.py
```python
def on_epoch_begin(self, epoch: int) -> bool:
    """
    Training halted if:
        new learning rate == 0.0
    """

    lr = max(0.0, self._init_lr - epoch * self._decay)
    if lr == 0.0:
        return False
    for g in self.optimizer.param_groups:
        g['lr'] = lr
    return True
```
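As a hypothetical usage sketch (the parameter tensor and hyperparameter values are illustrative only), the decayed rate follows `lr = max(0.0, init_lr - epoch * decay_rate)`, so training halts once `epoch` reaches `init_lr / decay_rate`:

```python
import torch

from ecnet.callbacks import LRDecayLinear

# Illustrative parameters and hyperparameters
params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.Adam(params, lr=0.01)
decay = LRDecayLinear(init_lr=0.01, decay_rate=0.001, optimizer=optimizer)

for epoch in range(12):
    if not decay.on_epoch_begin(epoch):
        break  # lr hits 0.0 at epoch 10 (0.01 - 10 * 0.001)
    # ... one epoch of training ...

print(optimizer.param_groups[0]['lr'])  # ~0.001, the last nonzero rate
```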
Validator
__init__(self, loader, model, eval_iter, patience)
Periodic validation using training data subset
Parameters:
Name | Type | Description | Default
---|---|---|---
loader | torch.utils.data.DataLoader | validation set | required
model | ecnet.ECNet | model being trained | required
eval_iter | int | validation set evaluated after this many epochs | required
patience | int | if a new lowest validation loss is not found after this many epochs, training terminates and model parameters are set to those observed at the lowest validation loss | required
Source code in ecnet/callbacks.py
```python
def __init__(self, loader, model, eval_iter: int, patience: int):
    """
    Periodic validation using training data subset

    Args:
        loader (torch.utils.data.DataLoader): validation set
        model (ecnet.ECNet): model being trained
        eval_iter (int): validation set evaluated after `this` many epochs
        patience (int): if new lowest validation loss not found after `this` many epochs,
            terminate training, set model parameters to those observed @ lowest validation loss
    """

    super().__init__()
    self.loader = loader
    self.model = model
    self._ei = eval_iter
    self._best_loss = sys.maxsize
    self._most_recent_loss = sys.maxsize
    self._epoch_since_best = 0
    self.best_state = model.state_dict()
    self._patience = patience
```
on_epoch_end(self, epoch)
Training halted if: number of epochs since last lowest valid. MAE > specified patience
Source code in ecnet/callbacks.py
```python
def on_epoch_end(self, epoch: int) -> bool:
    """
    Training halted if:
        number of epochs since last lowest valid. MAE > specified patience
    """

    if epoch % self._ei != 0:
        return True
    valid_loss = 0.0
    for batch in self.loader:
        v_pred = self.model(batch['desc_vals'])
        v_target = batch['target_val']
        v_loss = self.model.loss(v_pred, v_target)
        valid_loss += v_loss * len(batch['target_val'])
    valid_loss /= len(self.loader.dataset)
    self._most_recent_loss = valid_loss
    if valid_loss < self._best_loss:
        self._best_loss = valid_loss
        self.best_state = self.model.state_dict()
        self._epoch_since_best = 0
        return True
    self._epoch_since_best += self._ei
    if self._epoch_since_best > self._patience:
        return False
    return True
```
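The patience accounting is easy to trace with a standalone simulation (the numbers are hypothetical): validation runs every `eval_iter` epochs, the counter grows by `eval_iter` for each evaluation without a new best loss, and training halts once the counter exceeds `patience`:

```python
# Standalone trace of Validator's patience counter (hypothetical values)
eval_iter, patience = 3, 6
best_epoch = 3  # assume the lowest validation loss occurs here

epoch_since_best = 0
for epoch in range(1, 100):
    if epoch % eval_iter != 0:
        continue  # validation only runs every eval_iter epochs
    if epoch == best_epoch:
        epoch_since_best = 0  # new best loss: reset the counter
        continue
    epoch_since_best += eval_iter
    if epoch_since_best > patience:
        print(f'halted at epoch {epoch}')  # halted at epoch 12
        break
```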
on_train_end(self)
After training, recall the weights from when the lowest valid. MAE occurred
Source code in ecnet/callbacks.py
```python
def on_train_end(self) -> bool:
    """
    After training, recall the weights from when the lowest valid. MAE occurred
    """

    self.model.load_state_dict(self.best_state)
    return True
```
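Putting both callbacks together, a training loop might be driven as in the sketch below. `TinyNet`, `DescDataset`, and the loop wiring are illustrative stand-ins for `ecnet.ECNet` and its trainer, not ECNet's actual API; only the batch keys (`desc_vals`, `target_val`) and the callback signatures come from this page:

```python
import torch
from torch.utils.data import DataLoader, Dataset

from ecnet.callbacks import LRDecayLinear, Validator


class DescDataset(Dataset):
    """Hypothetical dataset yielding the batch keys used on this page."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return {'desc_vals': self.X[i], 'target_val': self.y[i]}


class TinyNet(torch.nn.Module):
    """Stand-in for ecnet.ECNet: a forward pass plus a loss (MAE) attribute."""

    def __init__(self, n_desc: int):
        super().__init__()
        self.fc = torch.nn.Linear(n_desc, 1)
        self.loss = torch.nn.functional.l1_loss

    def forward(self, x):
        return self.fc(x)


X, y = torch.randn(64, 8), torch.randn(64, 1)
model = TinyNet(8)
train_loader = DataLoader(DescDataset(X[:48], y[:48]), batch_size=16)
valid_loader = DataLoader(DescDataset(X[48:], y[48:]), batch_size=16)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
callbacks = [
    LRDecayLinear(init_lr=0.01, decay_rate=1e-5, optimizer=optimizer),
    Validator(valid_loader, model, eval_iter=5, patience=50),
]

for epoch in range(500):
    if not all(cb.on_epoch_begin(epoch) for cb in callbacks):
        break  # LRDecayLinear decayed the learning rate to 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        loss = model.loss(model(batch['desc_vals']), batch['target_val'])
        loss.backward()
        optimizer.step()
    if not all(cb.on_epoch_end(epoch) for cb in callbacks):
        break  # Validator ran out of patience

for cb in callbacks:
    cb.on_train_end()  # Validator restores the best-validation weights
```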