in XLM/src/trainer.py
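# Module-level imports this excerpt relies on (a sketch based on the top of
# trainer.py in the original repo; apex is only needed on the AMP path):
import time
from collections import OrderedDict
from logging import getLogger

import numpy as np
import torch
from torch import nn
import apex

from .utils import parse_lambda_config

logger = getLogger()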
def __init__(self, data, params):
"""
Initialize trainer.
"""
# epoch / iteration size
self.epoch_size = params.epoch_size
if self.epoch_size == -1:
self.epoch_size = self.data
assert self.epoch_size > 0
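# e.g. setting --epoch_size 200000 makes one "epoch" 200k training sentences,
# decoupled from the actual dataset size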
# data iterators
self.iterators = {}
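# lazily populated by the get_iterator helpers, keyed by objective and
# language (pair), so each training step can resume its own stream of batches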
# set parameters
self.set_parameters()
# float16 / distributed (no AMP)
assert params.amp >= 1 or not params.fp16
assert params.amp >= 0 or params.accumulate_gradients == 1
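# i.e. --fp16 is only valid together with AMP (amp >= 1), and gradient
# accumulation (--accumulate_gradients > 1) is only supported on the AMP path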
if params.multi_gpu and params.amp == -1:
logger.info("Using nn.parallel.DistributedDataParallel ...")
for name in self.MODEL_NAMES:
setattr(self, name, [
    nn.parallel.DistributedDataParallel(model, device_ids=[params.local_rank], output_device=params.local_rank, broadcast_buffers=True)
    for model in getattr(self, name)
])
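# each process holds one model replica on its own GPU (params.local_rank);
# PyTorch DDP averages gradients across processes with an all-reduce
# during backward()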
# set optimizers
self.set_optimizers()
# float16 / distributed (AMP)
if params.amp >= 0:
self.init_amp()
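# init_amp (defined further down in this file) wraps the models and
# optimizers with apex.amp.initialize at opt_level 'O%i' % params.amp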
if params.multi_gpu:
logger.info("Using apex.parallel.DistributedDataParallel ...")
for name in self.MODEL_NAMES:
setattr(self, name, [
    apex.parallel.DistributedDataParallel(model, delay_allreduce=True)
    for model in getattr(self, name)
])
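# delay_allreduce=True makes apex defer all gradient all-reduces until the
# end of backward(), which is more robust when not every parameter
# participates in every iteration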
# stopping criterion used for early stopping
if params.stopping_criterion != '':
split = params.stopping_criterion.split(',')
assert len(split) == 2 and split[1].isdigit()
self.decrease_counts_max = int(split[1])
self.decrease_counts = 0
if split[0][0] == '_':
self.stopping_criterion = (split[0][1:], False)
else:
self.stopping_criterion = (split[0], True)
self.best_stopping_criterion = -1e12 if self.stopping_criterion[1] else 1e12
else:
self.stopping_criterion = None
self.best_stopping_criterion = None
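# example: --stopping_criterion '_valid_en_mlm_ppl,25' yields
# ('valid_en_mlm_ppl', False), a metric to minimize, with training stopped
# after 25 consecutive evaluations without improvement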
# probabilities of masking out / randomizing / keeping words to predict
params.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
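# with the repo's default --word_mask_keep_rand '0.8,0.1,0.1', a token picked
# for prediction is replaced by the mask token 80% of the time, left
# unchanged 10%, and swapped for a random word 10% (the BERT scheme)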
# probability of predicting a word
counts = np.array(list(self.data['dico'].counts.values()))
params.mask_scores = np.maximum(counts, 1) ** -params.sample_alpha
params.mask_scores[params.pad_index] = 0 # do not predict <PAD> index
# do not predict special tokens
params.mask_scores[counts == 0] = 0
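# worked example: with --sample_alpha 0 (the default) every non-special word
# gets score 1.0, i.e. uniform masking; with sample_alpha = 0.5 a word seen
# 10,000 times scores 0.01 vs. 0.1 for one seen 100 times, so rarer words are
# selected for prediction proportionally more often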
# validation metrics
self.metrics = []
metrics = [m for m in params.validation_metrics.split(',') if m != '']
for m in metrics:
m = (m[1:], False) if m[0] == '_' else (m, True)
self.metrics.append(m)
self.best_metrics = {metric: (-1e12 if biggest else 1e12) for (metric, biggest) in self.metrics}
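# e.g. --validation_metrics '_valid_en_mlm_ppl,valid_en-fr_mt_bleu' tracks a
# perplexity to minimize (leading '_') and a BLEU score to maximize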
# training statistics
self.epoch = 0
self.n_iter = 0
self.n_total_iter = 0
self.n_sentences = 0
self.stats = OrderedDict(
[('processed_s', 0), ('processed_w', 0)] +
[('CLM-%s' % l, []) for l in params.langs] +
[('CLM-%s-%s' % (l1, l2), []) for l1, l2 in data['para'].keys()] +
[('CLM-%s-%s' % (l2, l1), []) for l1, l2 in data['para'].keys()] +
[('MLM-%s' % l, []) for l in params.langs] +
[('MLM-%s-%s' % (l1, l2), []) for l1, l2 in data['para'].keys()] +
[('MLM-%s-%s' % (l2, l1), []) for l1, l2 in data['para'].keys()] +
[('AE-%s' % lang, []) for lang in params.ae_steps] +
[('MT-%s-%s' % (l1, l2), []) for l1, l2 in params.mt_steps] +
[('BT-%s-%s-%s' % (l1, l2, l3), []) for l1, l2, l3 in params.bt_steps]
)
self.last_time = time.time()
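# 'processed_s' / 'processed_w' accumulate sentences and words between log
# lines (reset when stats are printed), while each objective key collects
# per-step losses that are averaged for logging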
# reload potential checkpoints
self.reload_checkpoint()
# initialize lambda coefficients and their configurations
parse_lambda_config(params)
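# parse_lambda_config (from src/utils.py) accepts either a constant
# coefficient such as --lambda_mlm '1' or a schedule such as
# --lambda_ae '0:1,100000:0.1,300000:0', interpolating the value linearly
# between the listed (iteration, value) pairs during training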