def __init__()

in codegen_sources/model/src/trainer.py [0:0]


    def __init__(self, data, params, model_names):
        """
        Initialize trainer.
        """
        # epoch / iteration size
        self.params = params
        self.data = data
        self.MODEL_NAMES = model_names
        self.epoch_size = params.epoch_size
        if self.epoch_size == -1:
            self.epoch_size = len(self.data)
            assert self.epoch_size > 0

        # data iterators
        self.iterators = {}

        # set parameters
        self.set_parameters()

        # float16 / distributed (no AMP)
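        # sanity checks: fp16 requires an AMP optimization level >= 1, and
        # gradient accumulation is only supported together with AMP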
        assert params.amp >= 1 or not params.fp16
        assert params.amp >= 0 or params.accumulate_gradients == 1
        if params.multi_gpu and params.amp == -1:
            logger.info("Using nn.parallel.DistributedDataParallel ...")
            for name in self.MODEL_NAMES:
                model_attr = getattr(self, name)
                if isinstance(model_attr, list):
                    setattr(
                        self,
                        name,
                        [
                            CustomTorchDDP(
                                model,
                                device_ids=[params.local_rank],
                                output_device=params.local_rank,
                                broadcast_buffers=True,
                            )
                            for model in model_attr
                        ],
                    )
                else:
                    setattr(
                        self,
                        name,
                        CustomTorchDDP(
                            model_attr,
                            device_ids=[params.local_rank],
                            output_device=params.local_rank,
                            broadcast_buffers=True,
                        ),
                    )

        # set optimizers
        self.set_optimizers()

        # float16 / distributed (AMP)
        if params.amp >= 0:
            self.init_amp()
            if params.multi_gpu:
                logger.info("Using apex.parallel.DistributedDataParallel ...")
                for name in self.MODEL_NAMES:
                    model_attr = getattr(self, name)
                    if isinstance(model_attr, list):
                        setattr(
                            self,
                            name,
                            [
                                CustomApexDDP(model, delay_allreduce=True)
                                for model in model_attr
                            ],
                        )
                    else:
                        setattr(
                            self, name, CustomApexDDP(model_attr, delay_allreduce=True),
                        )

        # stopping criterion used for early stopping
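        # expected format "<metric_name>,<N>": N is the maximum number of times the
        # metric may fail to improve; a leading "_" means the metric is minimized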
        if params.stopping_criterion != "":
            split = params.stopping_criterion.split(",")
            assert len(split) == 2 and split[1].isdigit()
            self.decrease_counts_max = int(split[1])
            self.decrease_counts = 0
            if split[0][0] == "_":
                self.stopping_criterion = (split[0][1:], False)
            else:
                self.stopping_criterion = (split[0], True)
            self.best_stopping_criterion = -1e12 if self.stopping_criterion[1] else 1e12
        else:
            self.stopping_criterion = None
            self.best_stopping_criterion = None

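        # self-training: unit-test runners per target language, plus the unit tests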
        if len(params.st_steps) > 0:
            self.test_runners = {
                "python": PythonTestRunner(timeout=params.st_test_timeout),
                "cpp": CppTestRunner(timeout=params.st_test_timeout),
            }
            self.unit_tests = data["java_st_unit_tests"]

        # probability of masking out / randomizing / not modifying words to predict
        params.pred_probs = torch.FloatTensor(
            [params.word_mask, params.word_keep, params.word_rand]
        )

        # probability to predict a word
        counts = np.array(list(self.data["dico"].counts.values()))
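        # score each word as count ** -sample_alpha: with sample_alpha > 0,
        # frequent words get lower masking scores than rare ones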
        params.mask_scores = np.maximum(counts, 1) ** -params.sample_alpha
        params.mask_scores[params.pad_index] = 0  # do not predict <PAD> index
        # do not predict special tokens
        params.mask_scores[counts == 0] = 0

        # validation metrics
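        # comma-separated metric names; a "_" prefix marks a metric to be minimized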
        self.metrics = []
        metrics = [m for m in params.validation_metrics.split(",") if m != ""]
        for m in metrics:
            m = (m[1:], False) if m[0] == "_" else (m, True)
            self.metrics.append(m)
        self.best_metrics = {
            metric: (-1e12 if biggest else 1e12) for (metric, biggest) in self.metrics
        }

        # training statistics
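        # counters for processed sentences / words, plus one list per training
        # objective (CLM, MLM, AE, MT, DO, Classif, BT, ST) used to log statistics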
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sentences = 0
        self.stats = OrderedDict(
            [("processed_s", 0), ("processed_w", 0)]
            + [("CLM-%s" % l, []) for l in params.langs]
            + [("CLM-%s" % ("-".join(keys)), []) for keys in data["para"].keys()]
            + [("CLM-%s" % "-".join(keys[::-1]), []) for keys in data["para"].keys()]
            + [("MLM-%s" % l, []) for l in params.langs]
            + [("MLM-%s" % ("-".join(keys)), []) for keys in data["para"].keys()]
            + [("MLM-%s" % "-".join(keys[::-1]), []) for keys in data["para"].keys()]
            + [("AE-%s" % lang, []) for lang in params.ae_steps]
            + [("MT-%s-%s" % (l1, l2), []) for l1, l2 in params.mt_steps]
            + [
                ("MT-%s-%s-%s" % (l1, l2, span), [])
                for l1, l2, span in params.mt_spans_steps
            ]
            + [("DO-%s-%s" % (l1, l2), []) for l1, l2 in params.do_steps]
            + [("Classif-%s-%s" % (l1, l2), []) for l1, l2 in params.classif_steps]
            + [("BT-%s-%s-%s" % (l1, l2, l3), []) for l1, l2, l3 in params.bt_steps]
            + [
                ("ST-%s:%s-%s" % (l1, l1, l2), [])
                for l1, langs2 in params.st_steps
                for l2 in langs2
            ]
            + [
                ("ST-%s:%s-%s" % (l1, l2, l1), [])
                for l1, langs2 in params.st_steps
                for l2 in langs2
            ]
            + [
                ("ST-%s:%s-%s" % (l1, l2_1, l2_2), [])
                for l1, langs2 in params.st_steps
                for l2_1 in langs2
                for l2_2 in langs2
                if l2_1 != l2_2
            ]
        )
        self.last_time = time.time()
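        # language pairs (sorted) involved in self-training; they key the ST caches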
        self.st_langs = set()
        for lang1, langs2 in params.st_steps:
            for l1 in [lang1] + list(langs2):
                for l2 in [lang1] + list(langs2):
                    if l1 < l2:
                        self.st_langs.add((l1, l2))
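        # one cache per self-training language pair, either round-robin or list-based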
        self.cache_class = RoundRobinCache if params.robin_cache else ListCache
        self.st_cache = {
            tuple([l1, l2]): self.cache_class(params=params) for l1, l2 in self.st_langs
        }
        self.number_consecutive_reads = 0
        if params.cache_init_path != "":
            self.load_initial_cache()
        # reload potential checkpoints
        self.reload_checkpoint()

        # initialize lambda coefficients and their configurations
        parse_lambda_config(params)
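
For reference, a minimal sketch of how the stopping-criterion string is interpreted by this constructor; the value used here is a hypothetical example, not one of the repository's configurations:

    # hypothetical value, used only to illustrate the parsing done in __init__
    stopping_criterion = "_valid_mt_loss,10"

    metric, patience = stopping_criterion.split(",")
    assert patience.isdigit()
    if metric.startswith("_"):
        # leading "_": the metric should be minimized; best value starts at +1e12
        criterion, maximize = metric[1:], False
    else:
        # no prefix: the metric should be maximized; best value starts at -1e12
        criterion, maximize = metric, True
    print(criterion, maximize, int(patience))  # -> valid_mt_loss False 10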