in XLM/src/trainer.py [0:0]
def __init__(self, data, params) -> None:
    """
    Initialize trainer.

    Expects `self.MODEL_NAMES` and one model attribute per name to already be
    set by the subclass before this runs (models are wrapped in-place for
    distributed training below). Mutates `params` by attaching `pred_probs`
    and `mask_scores`.

    Args:
        data: dataset container; must expose `data['dico'].counts` (word
            frequencies) and `data['para'].keys()` (parallel language pairs).
        params: experiment configuration (argparse-style namespace).
    """
    # TensorBoard writer only when global_rank is -1 (non-distributed) or 0
    # (master); all other ranks get None and must not log.
    self.tb_writer = SummaryWriter(params.dump_path) if params.global_rank in [-1, 0] else None

    # epoch / iteration size
    self.epoch_size = params.epoch_size
    if self.epoch_size == -1:
        # NOTE(review): this assigns the data container itself, not a count —
        # if this branch is ever taken, the `> 0` assert below raises a
        # TypeError. Presumably a `len(...)` was intended; confirm upstream
        # before relying on epoch_size == -1.
        self.epoch_size = self.data
        assert self.epoch_size > 0

    # data iterators, created lazily elsewhere and cached here by task name
    self.iterators = {}

    # list memory components: collect HashingMemory and TransformerFFN
    # submodules of every model so they can be monitored/optimized separately
    self.memory_list = []
    self.ffn_list = []
    for name in self.MODEL_NAMES:
        find_modules(getattr(self, name), f'self.{name}', HashingMemory, self.memory_list)
        find_modules(getattr(self, name), f'self.{name}', TransformerFFN, self.ffn_list)
    logger.info("Found %i memories." % len(self.memory_list))
    logger.info("Found %i FFN." % len(self.ffn_list))

    # set parameters
    self.set_parameters()

    # float16 / distributed (no AMP)
    # fp16 requires an AMP optimization level; gradient accumulation requires AMP
    assert params.amp >= 1 or not params.fp16
    assert params.amp >= 0 or params.accumulate_gradients == 1
    if params.multi_gpu and params.amp == -1:
        # without AMP, wrap models with PyTorch's native DDP
        logger.info("Using nn.parallel.DistributedDataParallel ...")
        for name in self.MODEL_NAMES:
            setattr(self, name, nn.parallel.DistributedDataParallel(getattr(self, name), device_ids=[params.local_rank], output_device=params.local_rank, broadcast_buffers=True))

    # set optimizers
    self.set_optimizers()

    # float16 / distributed (AMP)
    # AMP initialization must happen after optimizers exist; with AMP,
    # apex's DDP is used instead of the native one above
    if params.amp >= 0:
        self.init_amp()
        if params.multi_gpu:
            logger.info("Using apex.parallel.DistributedDataParallel ...")
            for name in self.MODEL_NAMES:
                setattr(self, name, apex.parallel.DistributedDataParallel(getattr(self, name), delay_allreduce=True))

    # stopping criterion used for early stopping
    # format: "metric_name,N" — stop after the metric fails to improve N times;
    # a leading '_' on the metric name means lower is better
    if params.stopping_criterion != '':
        split = params.stopping_criterion.split(',')
        assert len(split) == 2 and split[1].isdigit()
        self.decrease_counts_max = int(split[1])
        self.decrease_counts = 0
        if split[0][0] == '_':
            self.stopping_criterion = (split[0][1:], False)
        else:
            self.stopping_criterion = (split[0], True)
        # start from the worst possible value for the chosen direction
        self.best_stopping_criterion = -1e12 if self.stopping_criterion[1] else 1e12
    else:
        self.stopping_criterion = None
        self.best_stopping_criterion = None

    # probability of masking out / randomize / not modify words to predict
    params.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])

    # probabilty to predict a word: rarer words get higher sampling scores
    # (frequency ** -sample_alpha), clamped to avoid division by zero
    counts = np.array(list(self.data['dico'].counts.values()))
    params.mask_scores = np.maximum(counts, 1) ** -params.sample_alpha
    params.mask_scores[params.pad_index] = 0  # do not predict <PAD> index
    params.mask_scores[counts == 0] = 0  # do not predict special tokens

    # validation metrics: same '_' prefix convention as the stopping criterion
    self.metrics = []
    metrics = [m for m in params.validation_metrics.split(',') if m != '']
    for m in metrics:
        m = (m[1:], False) if m[0] == '_' else (m, True)
        self.metrics.append(m)
    self.best_metrics = {metric: (-1e12 if biggest else 1e12) for (metric, biggest) in self.metrics}

    # training statistics
    self.epoch = 0
    self.n_iter = 0
    self.n_total_iter = 0
    self.n_sentences = 0
    # one loss history list per objective/language(-pair) combination;
    # CLM/MLM get both directions of every parallel pair plus monolingual
    self.stats = OrderedDict(
        [('processed_s', 0), ('processed_w', 0)] +
        [('CLM-%s' % l, []) for l in params.langs] +
        [('CLM-%s-%s' % (l1, l2), []) for l1, l2 in data['para'].keys()] +
        [('CLM-%s-%s' % (l2, l1), []) for l1, l2 in data['para'].keys()] +
        [('MLM-%s' % l, []) for l in params.langs] +
        [('MLM-%s-%s' % (l1, l2), []) for l1, l2 in data['para'].keys()] +
        [('MLM-%s-%s' % (l2, l1), []) for l1, l2 in data['para'].keys()] +
        [('PC-%s-%s' % (l1, l2), []) for l1, l2 in params.pc_steps] +
        [('AE-%s' % lang, []) for lang in params.ae_steps] +
        [('MT-%s-%s' % (l1, l2), []) for l1, l2 in params.mt_steps] +
        [('BT-%s-%s-%s' % (l1, l2, l3), []) for l1, l2, l3 in params.bt_steps]
    )
    self.last_time = time.time()

    # reload potential checkpoints
    self.reload_checkpoint()

    # initialize lambda coefficients and their configurations
    parse_lambda_config(params)