in codegen_sources/model/src/trainer.py [0:0]
def st_step(self, lang1, langs2, lambda_coeff, show_example=False):
    """
    Self-training step: translate a batch of `lang1` functions into the
    languages in `langs2`, keep the translations that pass the unit tests,
    and train on the resulting parallel examples (or replay already
    validated examples from the cache).
    """
    assert lambda_coeff >= 0
    if lambda_coeff == 0:
        return
    assert all([lang1 != lang2 and lang2 is not None for lang2 in langs2]), (
        lang1,
        langs2,
    )
    params = self.params
    lang1_id = params.lang2id[lang1]
    _encoder = self.encoder[0]
    if params.is_master and params.st_show_stats:
        for (l1, l2), cache in self.st_cache.items():
            logger.info(f"{l1}-{l2} cache size: {len(cache)}")
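    # Decide whether to replay validated examples from the cache or to generate
    # new ones: st_sample_cache_ratio in [0, 1) is treated as a probability of
    # reading from the cache, larger values as a maximum number of consecutive
    # cache reads. In both cases every cache must first hold at least
    # cache_warmup elements.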
    dico = self.data["dico"]
    if 0 <= params.st_sample_cache_ratio < 1:
        read_from_cache = random.random() < params.st_sample_cache_ratio and all(
            [len(cache) >= params.cache_warmup for cache in self.st_cache.values()]
        )
    else:
        if self.number_consecutive_reads < params.st_sample_cache_ratio and all(
            [len(cache) >= params.cache_warmup for cache in self.st_cache.values()]
        ):
            read_from_cache = True
            self.number_consecutive_reads += 1
        else:
            read_from_cache = False
            self.number_consecutive_reads = 0
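    # Cache path: sample a batch of previously validated parallel examples for
    # every self-training language pair and train on it directly.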
    if read_from_cache:
        if params.st_show_stats:
            logger.info(f"reading {params.st_sample_size} elements from the cache")
        for l1, l2 in [(l1, l2) for l1, l2 in self.st_langs]:
            (x1, len1), (x2, len2) = self.st_cache[(l1, l2)].sample_batch(
                params.st_sample_size
            )
            if params.st_show_stats:
                logger.info(f"actual batch size: {len(len2)}")
            x1, len1, x2, len2 = to_cuda(x1, len1, x2, len2)
            self.train_on_st_data(
                x1,
                len1,
                l1,
                x2,
                len2,
                l2,
                dico,
                params,
                lambda_coeff,
                show_example,
                lang_src=lang1,
            )
            # number of processed sentences / words
            self.n_sentences += params.batch_size
            self.stats["processed_s"] += len1.size(0)
            self.stats["processed_w"] += (len1 - 1).sum().item()
            del x1, len1, x2, len2
    else:
        # generate a source batch of functions with unit tests, filtered by a
        # minimum mutation score and a minimum number of asserts
        (x1, len1, id1, lenid1) = self.get_batch(
            "st",
            lang1,
            self_training=True,
            st_scores_cutoff=(
                params.st_min_mutation_score,
                self.params.st_min_asserts,
            ),
        )
        assert id1 is not None
        assert lenid1 is not None
        assert x1.shape[1] == len(len1) == id1.shape[1] == len(lenid1)
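        # id1 / lenid1 hold the tokenized identifiers of the source functions;
        # they are decoded back to plain text below, presumably so that
        # generate_parallel_examples can look up the corresponding unit tests.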
        sent_ids = convert_to_text(id1, lenid1, dico, params)
        sent_ids = [
            restore_segmentation_sentence(i, roberta_mode=params.roberta_mode)
            for i in sent_ids
        ]
        langs1 = x1.clone().fill_(lang1_id)
        # cuda
        x1, len1, langs1 = to_cuda(x1, len1, langs1)
        with torch.no_grad():
            # evaluation mode
            self.eval_mode()
            # encode source sentence and translate it
            enc1 = _encoder("fwd", x=x1, lengths=len1, langs=langs1, causal=False)
            enc1 = enc1.transpose(0, 1)
        # We generate data for every language in langs2 from the input in lang1
        generated_x2 = {}
        generated_x2_len = {}
        any_successful = {}
        for lang2 in langs2:
            (
                selected_x1,
                selected_len1,
                x2,
                len2,
                any_successful_beam,
            ) = self.generate_parallel_examples(
                x1, len1, enc1, lang1, lang2, sent_ids, params
            )
            if selected_x1 is None:
                continue
            generated_x2[lang2] = x2
            generated_x2_len[lang2] = len2
            any_successful[lang2] = any_successful_beam
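            # Train the lang1 -> lang2 direction on the kept source functions
            # and their generated translations.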
            self.train_on_st_data(
                selected_x1,
                selected_len1,
                lang1,
                generated_x2[lang2],
                generated_x2_len[lang2],
                lang2,
                dico,
                params,
                lambda_coeff,
                show_example,
                lang_src=lang1,
            )
            # if needed, train on pairs of langs2 elements
            for lang2_2 in [
                lang for lang in any_successful.keys() if lang != lang2
            ]:
                x2, len2, x2_2, len2_2 = self.cross_language_st_selection(
                    generated_x2,
                    generated_x2_len,
                    any_successful,
                    lang2,
                    lang2_2,
                    params,
                )
                if x2 is None:
                    continue
                self.train_on_st_data(
                    x2,
                    len2,
                    lang2,
                    x2_2,
                    len2_2,
                    lang2_2,
                    dico,
                    params,
                    lambda_coeff,
                    show_example,
                    lang_src=lang1,
                )
        # number of processed sentences / words
        self.n_sentences += params.batch_size
        self.stats["processed_s"] += len1.size(0)
        self.stats["processed_w"] += (len1 - 1).sum().item()
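
For context, a minimal sketch of how a step like this could be driven from an outer training loop. `trainer`, `st_pairs`, and `lambda_st` are hypothetical names used only for illustration; they are not taken from this file.

# Illustrative driver loop (not part of trainer.py); all names below are assumed.
st_pairs = [("java", ["python", "cpp"]), ("python", ["java", "cpp"])]
lambda_st = 1.0
for lang1, langs2 in st_pairs:
    # st_step returns immediately when the coefficient is zero (see the early return above).
    trainer.st_step(lang1, langs2, lambda_coeff=lambda_st, show_example=False)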