
ecnet.callbacks

ecnet.callbacks.CallbackOperator

Bases: object

CallbackOperator: executes each registered callback at every training stage

Source code in ecnet/callbacks.py
class CallbackOperator(object):
    """
    CallbackOperator: executes each registered callback at every training stage
    """

    def __init__(self):

        self.cb = []

    def add_cb(self, cb):

        self.cb.append(cb)

    def on_train_begin(self):

        for cb in self.cb:
            if not cb.on_train_begin():
                return False
        return True

    def on_train_end(self):

        for cb in self.cb:
            if not cb.on_train_end():
                return False
        return True

    def on_epoch_begin(self, epoch):

        for cb in self.cb:
            if not cb.on_epoch_begin(epoch):
                return False
        return True

    def on_epoch_end(self, epoch):

        for cb in self.cb:
            if not cb.on_epoch_end(epoch):
                return False
        return True

    def on_batch_begin(self, batch):

        for cb in self.cb:
            if not cb.on_batch_begin(batch):
                return False
        return True

    def on_batch_end(self, batch):

        for cb in self.cb:
            if not cb.on_batch_end(batch):
                return False
        return True

    def on_loss_begin(self, batch):

        for cb in self.cb:
            if not cb.on_loss_begin(batch):
                return False
        return True

    def on_loss_end(self, batch):

        for cb in self.cb:
            if not cb.on_loss_end(batch):
                return False
        return True

    def on_step_begin(self, batch):

        for cb in self.cb:
            if not cb.on_step_begin(batch):
                return False
        return True

    def on_step_end(self, batch):

        for cb in self.cb:
            if not cb.on_step_end(batch):
                return False
        return True
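
A minimal usage sketch of the operator inside a training loop. The loop, loader, and callback objects below are illustrative placeholders, not part of ecnet's public API, and the ordering of the loss/step hooks is an assumption about how a trainer might call them; the key point is that every hook returns a bool, and a False from any callback asks the caller to stop training.

from ecnet.callbacks import CallbackOperator

op = CallbackOperator()
op.add_cb(callback_a)    # placeholders: any objects implementing the Callback interface below
op.add_cb(callback_b)

if op.on_train_begin():
    for epoch in range(n_epochs):
        if not op.on_epoch_begin(epoch):
            break                        # a callback requested an early stop
        for batch in loader:
            op.on_batch_begin(batch)
            op.on_loss_begin(batch)
            # ... forward pass and loss computation ...
            op.on_loss_end(batch)
            op.on_step_begin(batch)
            # ... optimizer step ...
            op.on_step_end(batch)
            op.on_batch_end(batch)
        if not op.on_epoch_end(epoch):
            break                        # e.g. Validator patience exceeded
op.on_train_end()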

ecnet.callbacks.Callback

Bases: object

Base Callback object

Source code in ecnet/callbacks.py
class Callback(object):
    """
    Base Callback object
    """

    def __init__(self): pass
    def on_train_begin(self): return True
    def on_train_end(self): return True
    def on_epoch_begin(self, epoch): return True
    def on_epoch_end(self, epoch): return True
    def on_batch_begin(self, batch): return True
    def on_batch_end(self, batch): return True
    def on_loss_begin(self, batch): return True
    def on_loss_end(self, batch): return True
    def on_step_begin(self, batch): return True
    def on_step_end(self, batch): return True
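
Concrete callbacks subclass Callback, override only the hooks they need, and return False from a hook to halt training. A small illustrative sketch (the EpochLogger name and behavior are hypothetical, not part of ecnet):

from ecnet.callbacks import Callback

class EpochLogger(Callback):
    """Hypothetical callback: prints progress every `print_every` epochs."""

    def __init__(self, print_every: int = 10):
        super().__init__()
        self._print_every = print_every

    def on_epoch_end(self, epoch: int) -> bool:
        if epoch % self._print_every == 0:
            print(f'finished epoch {epoch}')
        return True    # returning False here would stop training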

ecnet.callbacks.LRDecayLinear

Bases: Callback

Source code in ecnet/callbacks.py
class LRDecayLinear(Callback):

    def __init__(self, init_lr: float, decay_rate: float, optimizer):
        """
        Linear learning rate decay

        Args:
            init_lr (float): initial learning rate
            decay_rate (float): decay per epoch
            optimizer (torch.optim.Adam): optimizer used for training
        """
        super().__init__()
        self._init_lr = init_lr
        self._decay = decay_rate
        self.optimizer = optimizer

    def on_epoch_begin(self, epoch: int) -> bool:
        """
        Training halted if:
            new learning rate == 0.0
        """

        lr = max(0.0, self._init_lr - epoch * self._decay)
        if lr == 0.0:
            return False
        for g in self.optimizer.param_groups:
            g['lr'] = lr
        return True

__init__(init_lr, decay_rate, optimizer)

Linear learning rate decay

Parameters:

Name          Type                Description                    Default
init_lr       float               initial learning rate          required
decay_rate    float               decay per epoch                required
optimizer     torch.optim.Adam    optimizer used for training    required
Source code in ecnet/callbacks.py
def __init__(self, init_lr: float, decay_rate: float, optimizer):
    """
    Linear learning rate decay

    Args:
        init_lr (float): initial learning rate
        decay_rate (float): decay per epoch
        optimizer (torch.optim.Adam): optimizer used for training
    """
    super().__init__()
    self._init_lr = init_lr
    self._decay = decay_rate
    self.optimizer = optimizer

on_epoch_begin(epoch)

Training is halted if:

new learning rate == 0.0

Source code in ecnet/callbacks.py
def on_epoch_begin(self, epoch: int) -> bool:
    """
    Training halted if:
        new learning rate == 0.0
    """

    lr = max(0.0, self._init_lr - epoch * self._decay)
    if lr == 0.0:
        return False
    for g in self.optimizer.param_groups:
        g['lr'] = lr
    return True
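
The schedule is lr(epoch) = max(0.0, init_lr - epoch * decay_rate); once it reaches 0.0 the callback returns False and training stops. For example, init_lr=0.01 with decay_rate=0.001 gives 0.010 at epoch 0, 0.005 at epoch 5, 0.001 at epoch 9, and halts at epoch 10. A hedged usage sketch (the SGD optimizer and dummy parameter are placeholders; inside ecnet the model's own optimizer would be passed instead):

import torch

from ecnet.callbacks import LRDecayLinear

params = [torch.nn.Parameter(torch.zeros(1))]          # placeholder parameters
optimizer = torch.optim.SGD(params, lr=0.01)

decay_cb = LRDecayLinear(init_lr=0.01, decay_rate=0.001, optimizer=optimizer)

for epoch in range(20):
    if not decay_cb.on_epoch_begin(epoch):              # False once lr reaches 0.0
        break
    # ... one epoch of training at optimizer.param_groups[0]['lr'] ...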

ecnet.callbacks.Validator

Bases: Callback

Source code in ecnet/callbacks.py
class Validator(Callback):

    def __init__(self, loader, model, eval_iter: int, patience: int):
        """
        Periodic validation using training data subset

        Args:
            loader (torch.utils.data.DataLoader): validation set
            model (ecnet.ECNet): model being trained
            eval_iter (int): validation set evaluated after `this` many epochs
            patience (int): if new lowest validation loss not found after `this` many epochs,
                terminate training, set model parameters to those observed @ lowest validation loss
        """

        super().__init__()
        self.loader = loader
        self.model = model
        self._ei = eval_iter
        self._best_loss = sys.maxsize
        self._most_recent_loss = sys.maxsize
        self._epoch_since_best = 0
        self.best_state = model.state_dict()
        self._patience = patience

    def on_epoch_end(self, epoch: int) -> bool:
        """
        Training halted if:
            number of epochs since last lowest valid. MAE > specified patience
        """

        if epoch % self._ei != 0:
            return True
        valid_loss = 0.0
        for batch in self.loader:
            v_pred = self.model(batch['desc_vals'])
            v_target = batch['target_val']
            v_loss = self.model.loss(v_pred, v_target)
            valid_loss += v_loss * len(batch['target_val'])
        valid_loss /= len(self.loader.dataset)
        self._most_recent_loss = valid_loss
        if valid_loss < self._best_loss:
            self._best_loss = valid_loss
            self.best_state = self.model.state_dict()
            self._epoch_since_best = 0
            return True
        self._epoch_since_best += self._ei
        if self._epoch_since_best > self._patience:
            return False
        return True

    def on_train_end(self) -> bool:
        """
        After training, recall weights when lowest valid. MAE occurred
        """

        self.model.load_state_dict(self.best_state)
        return True

__init__(loader, model, eval_iter, patience)

Periodic validation using training data subset

Parameters:

Name        Type                           Description                                               Default
loader      torch.utils.data.DataLoader    validation set                                            required
model       ecnet.ECNet                    model being trained                                       required
eval_iter   int                            validation set evaluated after this many epochs           required
patience    int                            if a new lowest validation loss is not found after this
                                           many epochs, terminate training and set model parameters
                                           to those observed at the lowest validation loss           required
Source code in ecnet/callbacks.py
def __init__(self, loader, model, eval_iter: int, patience: int):
    """
    Periodic validation using training data subset

    Args:
        loader (torch.utils.data.DataLoader): validation set
        model (ecnet.ECNet): model being trained
        eval_iter (int): validation set evaluated after `this` many epochs
        patience (int): if new lowest validation loss not found after `this` many epochs,
            terminate training, set model parameters to those observed @ lowest validation loss
    """

    super().__init__()
    self.loader = loader
    self.model = model
    self._ei = eval_iter
    self._best_loss = sys.maxsize
    self._most_recent_loss = sys.maxsize
    self._epoch_since_best = 0
    self.best_state = model.state_dict()
    self._patience = patience

on_epoch_end(epoch)

Training is halted if:

number of epochs since the last lowest validation MAE > specified patience

Source code in ecnet/callbacks.py
def on_epoch_end(self, epoch: int) -> bool:
    """
    Training halted if:
        number of epochs since last lowest valid. MAE > specified patience
    """

    if epoch % self._ei != 0:
        return True
    valid_loss = 0.0
    for batch in self.loader:
        v_pred = self.model(batch['desc_vals'])
        v_target = batch['target_val']
        v_loss = self.model.loss(v_pred, v_target)
        valid_loss += v_loss * len(batch['target_val'])
    valid_loss /= len(self.loader.dataset)
    self._most_recent_loss = valid_loss
    if valid_loss < self._best_loss:
        self._best_loss = valid_loss
        self.best_state = self.model.state_dict()
        self._epoch_since_best = 0
        return True
    self._epoch_since_best += self._ei
    if self._epoch_since_best > self._patience:
        return False
    return True

on_train_end()

After training, recall weights when lowest valid. MAE occurred

Source code in ecnet/callbacks.py
def on_train_end(self) -> bool:
    """
    After training, recall weights when lowest valid. MAE occurred
    """

    self.model.load_state_dict(self.best_state)
    return True
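
A hedged end-to-end sketch of wiring Validator into a run via CallbackOperator. Here model, valid_loader, and max_epochs are placeholders: Validator only assumes the model is callable on batch['desc_vals'], exposes loss(), state_dict(), and load_state_dict(), and that the loader yields dicts with 'desc_vals' and 'target_val' keys. With eval_iter=4 and patience=16, the validation set is scored every 4 epochs and training stops after 5 consecutive evaluations without a new best loss (20 > 16).

from ecnet.callbacks import CallbackOperator, Validator

# `model`, `valid_loader`, and `max_epochs` are placeholders, not ecnet API.
validator = Validator(loader=valid_loader, model=model, eval_iter=4, patience=16)

op = CallbackOperator()
op.add_cb(validator)

for epoch in range(max_epochs):
    # ... one epoch of training on the training loader ...
    if not op.on_epoch_end(epoch):    # False once patience is exceeded
        break
op.on_train_end()                     # restores the weights from the best epoch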