in codegen_sources/model/src/trainer.py [0:0]
def __init__(self, data, params, model_names):
"""
Initialize trainer.
"""
# epoch / iteration size
self.params = params
self.data = data
self.MODEL_NAMES = model_names
self.epoch_size = params.epoch_size
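        # an epoch_size of -1 means one full pass over the dataset per epoch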
if self.epoch_size == -1:
self.epoch_size = len(self.data)
assert self.epoch_size > 0
# data iterators
self.iterators = {}
# set parameters
self.set_parameters()
# float16 / distributed (no AMP)
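        # fp16 requires an AMP optimization level >= 1, and gradient accumulation requires AMP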
assert params.amp >= 1 or not params.fp16
assert params.amp >= 0 or params.accumulate_gradients == 1
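        # without AMP, wrap each model (or each model in a list) in PyTorch DistributedDataParallel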
if params.multi_gpu and params.amp == -1:
logger.info("Using nn.parallel.DistributedDataParallel ...")
for name in self.MODEL_NAMES:
model_attr = getattr(self, name)
if isinstance(model_attr, list):
setattr(
self,
name,
[
CustomTorchDDP(
model,
device_ids=[params.local_rank],
output_device=params.local_rank,
broadcast_buffers=True,
)
for model in model_attr
],
)
else:
setattr(
self,
name,
CustomTorchDDP(
model_attr,
device_ids=[params.local_rank],
output_device=params.local_rank,
broadcast_buffers=True,
),
)
# set optimizers
self.set_optimizers()
# float16 / distributed (AMP)
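        # with AMP, initialize mixed precision first, then wrap models in Apex DDP;
        # delay_allreduce=True postpones gradient all-reduce until the end of the backward pass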
if params.amp >= 0:
self.init_amp()
if params.multi_gpu:
logger.info("Using apex.parallel.DistributedDataParallel ...")
for name in self.MODEL_NAMES:
model_attr = getattr(self, name)
if isinstance(model_attr, list):
setattr(
self,
name,
[
CustomApexDDP(model, delay_allreduce=True)
for model in model_attr
],
)
else:
setattr(
self, name, CustomApexDDP(model_attr, delay_allreduce=True),
)
# stopping criterion used for early stopping
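        # expected format: "<metric>,<N>" where N is the maximum number of
        # evaluations without improvement; a leading "_" means the metric is minimized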
if params.stopping_criterion != "":
split = params.stopping_criterion.split(",")
assert len(split) == 2 and split[1].isdigit()
self.decrease_counts_max = int(split[1])
self.decrease_counts = 0
if split[0][0] == "_":
self.stopping_criterion = (split[0][1:], False)
else:
self.stopping_criterion = (split[0], True)
self.best_stopping_criterion = -1e12 if self.stopping_criterion[1] else 1e12
else:
self.stopping_criterion = None
self.best_stopping_criterion = None
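        # self-training: create unit-test runners (Python / C++) and load the Java unit tests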
if len(params.st_steps) > 0:
self.test_runners = {
"python": PythonTestRunner(timeout=params.st_test_timeout),
"cpp": CppTestRunner(timeout=params.st_test_timeout),
}
            self.unit_tests = data["java_st_unit_tests"]
        # probabilities of masking out / keeping / randomizing the words to predict
params.pred_probs = torch.FloatTensor(
[params.word_mask, params.word_keep, params.word_rand]
)
        # probability to predict a word
counts = np.array(list(self.data["dico"].counts.values()))
params.mask_scores = np.maximum(counts, 1) ** -params.sample_alpha
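        # rarer words get higher selection scores when sample_alpha > 0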
params.mask_scores[params.pad_index] = 0 # do not predict <PAD> index
# do not predict special tokens
params.mask_scores[counts == 0] = 0
# validation metrics
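        # a metric prefixed with "_" is minimized, otherwise it is maximized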
self.metrics = []
metrics = [m for m in params.validation_metrics.split(",") if m != ""]
for m in metrics:
m = (m[1:], False) if m[0] == "_" else (m, True)
self.metrics.append(m)
self.best_metrics = {
metric: (-1e12 if biggest else 1e12) for (metric, biggest) in self.metrics
}
# training statistics
self.epoch = 0
self.n_iter = 0
self.n_total_iter = 0
self.n_sentences = 0
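        # per-objective loss buffers, keyed by training step (CLM, MLM, AE, MT, DO, Classif, BT, ST)
        # and by language / language pair / direction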
self.stats = OrderedDict(
[("processed_s", 0), ("processed_w", 0)]
+ [("CLM-%s" % l, []) for l in params.langs]
+ [("CLM-%s" % ("-".join(keys)), []) for keys in data["para"].keys()]
+ [("CLM-%s" % "-".join(keys[::-1]), []) for keys in data["para"].keys()]
+ [("MLM-%s" % l, []) for l in params.langs]
+ [("MLM-%s" % ("-".join(keys)), []) for keys in data["para"].keys()]
+ [("MLM-%s" % "-".join(keys[::-1]), []) for keys in data["para"].keys()]
+ [("AE-%s" % lang, []) for lang in params.ae_steps]
+ [("MT-%s-%s" % (l1, l2), []) for l1, l2 in params.mt_steps]
+ [
("MT-%s-%s-%s" % (l1, l2, span), [])
for l1, l2, span in params.mt_spans_steps
]
+ [("DO-%s-%s" % (l1, l2), []) for l1, l2 in params.do_steps]
+ [("Classif-%s-%s" % (l1, l2), []) for l1, l2 in params.classif_steps]
+ [("BT-%s-%s-%s" % (l1, l2, l3), []) for l1, l2, l3 in params.bt_steps]
+ [
("ST-%s:%s-%s" % (l1, l1, l2), [])
for l1, langs2 in params.st_steps
for l2 in langs2
]
+ [
("ST-%s:%s-%s" % (l1, l2, l1), [])
for l1, langs2 in params.st_steps
for l2 in langs2
]
+ [
("ST-%s:%s-%s" % (l1, l2_1, l2_2), [])
for l1, langs2 in params.st_steps
for l2_1 in langs2
for l2_2 in langs2
if l2_1 != l2_2
]
)
self.last_time = time.time()
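        # self-training: one translation cache per unordered language pair appearing in st_steps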
self.st_langs = set()
for lang1, langs2 in params.st_steps:
for l1 in [lang1] + list(langs2):
for l2 in [lang1] + list(langs2):
if l1 < l2:
self.st_langs.add((l1, l2))
self.cache_class = RoundRobinCache if params.robin_cache else ListCache
self.st_cache = {
tuple([l1, l2]): self.cache_class(params=params) for l1, l2 in self.st_langs
}
self.number_consecutive_reads = 0
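        # optionally pre-fill the self-training cache from cache_init_path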
if params.cache_init_path != "":
self.load_initial_cache()
# reload potential checkpoints
self.reload_checkpoint()
# initialize lambda coefficients and their configurations
parse_lambda_config(params)