in src/trainer.py [0:0]
def __init__(self, model, params, ftmodel=None, teacher_model=None):
"""
Initialize trainer.
"""
# main model, optional fine-tuned / teacher models, and experiment parameters
self.model = model
self.ftmodel = ftmodel
self.teacher_model = teacher_model
self.params = params
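# mixed-precision (fp16) training is not supported by this trainer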
assert params.fp16 is False
# set parameters
self.set_parameters()
# set optimizers
self.set_optimizers()
# stopping criterion used for early stopping
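# expected format: "name,N" (e.g. "_valid_mse,10", hypothetical metric name);
# a leading "_" means the metric should decrease, and N is the number of
# non-improving evaluations tolerated before training is stopped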
if params.stopping_criterion != '':
split = params.stopping_criterion.split(',')
assert len(split) == 2 and split[1].isdigit()
self.decrease_counts_max = int(split[1])
self.decrease_counts = 0
if split[0][0] == '_':
self.stopping_criterion = (split[0][1:], False)
else:
self.stopping_criterion = (split[0], True)
self.best_stopping_criterion = -1e12 if self.stopping_criterion[1] else 1e12
else:
self.stopping_criterion = None
self.best_stopping_criterion = None
# validation metrics
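# comma-separated metric names; a leading "_" marks a metric that should
# decrease (lower is better), otherwise higher is better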
self.metrics = []
metrics = [m for m in params.validation_metrics.split(',') if m != '']
for m in metrics:
m = (m[1:], False) if m[0] == '_' else (m, True)
self.metrics.append(m)
self.best_metrics = {metric: (-1e12 if biggest else 1e12) for (metric, biggest) in self.metrics}
# training statistics
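# 'processed_i' counter plus per-iteration loss (MSE, cross-entropy,
# triplet) and timing lists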
self.epoch = 0
self.indices = []
self.n_iter = 0
self.embeddings = None
self.stats = OrderedDict([
    ('processed_i', 0),
    ('MSE', []),
    ('XE', []),
    ('triplet', []),
    ('time', []),
])
self.last_time = time.time()