def st_step()

in codegen_sources/model/src/trainer.py [0:0]


    def st_step(self, lang1, langs2, lambda_coeff, show_example=False):
        """
        Training on self-trained examples using unit tests
        """
        assert lambda_coeff >= 0
        if lambda_coeff == 0:
            return
        assert all([lang1 != lang2 and lang2 is not None for lang2 in langs2]), (
            lang1,
            langs2,
        )
        params = self.params

        lang1_id = params.lang2id[lang1]

        _encoder = self.encoder[0]

        if params.is_master and params.st_show_stats:
            for (l1, l2), cache in self.st_cache.items():
                logger.info(f"{l1}-{l2} cache size: {len(cache)}")
        dico = self.data["dico"]
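        # decide whether to read from the cache of validated examples or to
        # generate new ones: a ratio in [0, 1) is used as a probability of
        # reading from the cache, while a ratio >= 1 caps the number of
        # consecutive cache reads; in both cases the cache is only used once
        # every language pair has at least `cache_warmup` entries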
        if 0 <= params.st_sample_cache_ratio < 1:
            read_from_cache = random.random() < params.st_sample_cache_ratio and all(
                [len(cache) >= params.cache_warmup for cache in self.st_cache.values()]
            )
        else:
            if self.number_consecutive_reads < params.st_sample_cache_ratio and all(
                [len(cache) >= params.cache_warmup for cache in self.st_cache.values()]
            ):
                read_from_cache = True
                self.number_consecutive_reads += 1
            else:
                read_from_cache = False
                self.number_consecutive_reads = 0

        # train on previously generated translations stored in the cache
        if read_from_cache:
            if params.st_show_stats:
                logger.info(f"reading {params.st_sample_size} elements from the cache")
            for l1, l2 in self.st_langs:
                (x1, len1), (x2, len2) = self.st_cache[(l1, l2)].sample_batch(
                    params.st_sample_size
                )
                if params.st_show_stats:
                    logger.info(f"actual batch size: {len(len2)}")
                x1, len1, x2, len2 = to_cuda(x1, len1, x2, len2)
                self.train_on_st_data(
                    x1,
                    len1,
                    l1,
                    x2,
                    len2,
                    l2,
                    dico,
                    params,
                    lambda_coeff,
                    show_example,
                    lang_src=lang1,
                )
                # number of processed sentences / words
                self.n_sentences += params.batch_size
                self.stats["processed_s"] += len1.size(0)
                self.stats["processed_w"] += (len1 - 1).sum().item()
                del x1, len1, x2, len2
        else:
            # generate source batch
            (x1, len1, id1, lenid1) = self.get_batch(
                "st",
                lang1,
                self_training=True,
                st_scores_cutoff=(
                    params.st_min_mutation_score,
                    params.st_min_asserts,
                ),
            )
            assert id1 is not None
            assert lenid1 is not None
            assert x1.shape[1] == len(len1) == id1.shape[1] == len(lenid1)
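            # decode the example identifiers back to strings and undo the BPE
            # segmentation; they are passed to generate_parallel_examples below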
            sent_ids = convert_to_text(id1, lenid1, dico, params)
            sent_ids = [
                restore_segmentation_sentence(i, roberta_mode=params.roberta_mode)
                for i in sent_ids
            ]
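            # language-id tensor with the same shape as the source batch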
            langs1 = x1.clone().fill_(lang1_id)

            # cuda
            x1, len1, langs1 = to_cuda(x1, len1, langs1)

            with torch.no_grad():

                # evaluation mode
                self.eval_mode()

                # encode source sentence and translate it
                enc1 = _encoder("fwd", x=x1, lengths=len1, langs=langs1, causal=False)
                enc1 = enc1.transpose(0, 1)

            # We generate data for every language in langs2 from the input in lang1
            generated_x2 = {}
            generated_x2_len = {}
            any_successful = {}
            for lang2 in langs2:
                (
                    selected_x1,
                    selected_len1,
                    x2,
                    len2,
                    any_successful_beam,
                ) = self.generate_parallel_examples(
                    x1, len1, enc1, lang1, lang2, sent_ids, params
                )
                if selected_x1 is None:
                    continue
                generated_x2[lang2] = x2
                generated_x2_len[lang2] = len2
                any_successful[lang2] = any_successful_beam

                self.train_on_st_data(
                    selected_x1,
                    selected_len1,
                    lang1,
                    generated_x2[lang2],
                    generated_x2_len[lang2],
                    lang2,
                    dico,
                    params,
                    lambda_coeff,
                    show_example,
                    lang_src=lang1,
                )

                # if needed, train on pairs of langs2 elements
                for lang2_2 in [
                    lang for lang in any_successful.keys() if lang != lang2
                ]:
                    x2, len2, x2_2, len2_2 = self.cross_language_st_selection(
                        generated_x2,
                        generated_x2_len,
                        any_successful,
                        lang2,
                        lang2_2,
                        params,
                    )
                    if x2 is None:
                        continue

                    self.train_on_st_data(
                        x2,
                        len2,
                        lang2,
                        x2_2,
                        len2_2,
                        lang2_2,
                        dico,
                        params,
                        lambda_coeff,
                        show_example,
                        lang_src=lang1,
                    )
            # number of processed sentences / words
            self.n_sentences += params.batch_size
            self.stats["processed_s"] += len1.size(0)
            self.stats["processed_w"] += (len1 - 1).sum().item()