in code/src/trainer.py [0:0]
def __init__(self, data, params):
    """
    Initialize trainer.

    Turns the textual options stored on ``params`` into trainer state: the
    epoch size, the early-stopping criterion, the tracked validation metrics
    and the periodic-save schedule.

    Args:
        data: dataset container. Only ``data['mono']['train']`` is read here,
            and only when ``params.epoch_size == -1`` (its ``len`` becomes the
            epoch size).
        params: option namespace. Reads ``otf_num_processes`` (optional),
            ``epoch_size``, ``stop_crit``, ``metrics`` and ``save_periodic``.
            ``params.epoch_size`` is overwritten in place when it is -1.

    Raises:
        AssertionError: if ``epoch_size``, ``stop_crit``, ``metrics`` or the
            period / condition format of ``save_periodic`` is malformed (these
            checks are ``assert`` statements, so ``python -O`` skips them).
        ValueError: if a ``save_periodic`` condition threshold is not a number.
    """
    # The parent initializer is only invoked when ``otf_num_processes`` is
    # present; it receives device ids 0 .. otf_num_processes - 1.
    if hasattr(params, 'otf_num_processes'):
        super().__init__(device_ids=tuple(range(params.otf_num_processes)))

    # epoch size: -1 means "use the size of data['mono']['train']"
    if params.epoch_size == -1:
        params.epoch_size = len(data['mono']['train'])
    assert params.epoch_size > 0, "epoch_size must be positive (or -1), got %r" % params.epoch_size

    # stopping criterion: "<sign><name>,<max_decreases>" where sign is '+'
    # (best value starts at -1e12, i.e. maximise) or '_' (starts at +1e12,
    # i.e. minimise). '_' presumably stands in for '-' so that argparse does
    # not mistake the value for an option flag.
    # All attributes are initialised up front so they exist even when early
    # stopping is disabled (stop_crit == '').
    self.crit = None
    self.crit_best = None
    self.crit_sign = None
    self.decrease = 0
    self.decrease_max = None
    if params.stop_crit != '':
        parts = params.stop_crit.split(',')
        # Validate the shape before unpacking so malformed values fail with a
        # clear message instead of an unpacking ValueError / IndexError.
        assert (len(parts) == 2 and len(parts[0]) >= 2
                and parts[0][0] in ['_', '+'] and parts[1].isdecimal()), \
            "stop_crit must look like '+name,N' or '_name,N', got %r" % params.stop_crit
        crit, dec_max = parts
        sign = 1 if crit[0] == '+' else -1
        self.crit = crit[1:]
        self.crit_best = sign * -1e12
        self.crit_sign = sign
        self.decrease_max = int(dec_max)

    # validation metrics: comma-separated "<sign><name>" entries, same sign
    # convention as the stopping criterion; empty entries are ignored.
    self.metrics_best = {}
    self.metrics_sign = {}
    for metric in params.metrics.split(','):
        if metric == '':
            continue
        assert len(metric) >= 2 and metric[0] in ['_', '+'], \
            "metrics entries must look like '+name' or '_name', got %r" % metric
        sign = 1 if metric[0] == '+' else -1
        self.metrics_best[metric[1:]] = sign * -1e12
        self.metrics_sign[metric[1:]] = sign

    # periodic save with optional conditions:
    # "<period>[,<name>:<sign><threshold>]..." where sign '+' gives +1 and
    # '-' (or '_', as in the options above) gives -1.
    if params.save_periodic == '':
        self.save_periodic_config = False
    else:
        period_str, *raw_conditions = params.save_periodic.split(',')
        assert period_str.isdecimal() and int(period_str) >= 1, \
            "save_periodic period must be a positive integer, got %r" % period_str
        period = int(period_str)
        conditions = []
        for raw in raw_conditions:
            fields = raw.split(':')
            assert (len(fields) == 2 and len(fields[0]) >= 1 and len(fields[1]) >= 2
                    and fields[1][0] in ['+', '-', '_']), \
                "save_periodic conditions must look like 'name:+value' or 'name:-value', got %r" % raw
            name, sign_value = fields
            try:
                threshold = float(sign_value[1:])
            except ValueError as err:
                raise ValueError(
                    "save_periodic condition threshold is not a number: %r" % raw) from err
            conditions.append((name, 1 if sign_value[0] == '+' else -1, threshold))
        self.save_periodic_config = (period, conditions)