def _init_state()

in mdr/qa/qa_trainer.py [0:0]


    def _init_state(self) -> None:
        """
        Build the trainer state from scratch, then resume from a checkpoint if one exists.

        In order, this method:

        * writes the training config as JSON to ``<output_dir>/config.json``
          (global rank 0 only);
        * seeds ``random``, ``numpy`` and ``torch`` with the configured seed;
        * creates the train and test data loaders (stored on
          ``self._train_loader`` / ``self._test_loader``);
        * creates the model on this process' local GPU, the optimizer with two
          weight-decay groups, optional apex ``amp`` mixed precision, and a
          linear warmup schedule;
        * wraps the model in ``torch.nn.DataParallel`` and stores everything in
          ``self._state``; a ``SummaryWriter`` is stored on ``self.tb_logger``;
        * replaces ``self._state`` with the contents of
          ``<output_dir>/<job_id>/checkpoint.pth`` when that file exists.
        """
        cfg = self._train_cfg
        job_env = submitit.JobEnvironment()

        # Only the global rank-0 process dumps the run configuration.
        if job_env.global_rank == 0:
            os.makedirs(cfg.output_dir, exist_ok=True)
            with open(Path(cfg.output_dir) / 'config.json', "w") as config_file:
                config_file.write(json.dumps(cfg._asdict()))

        seed = cfg.seed
        print(f"Setting random seed {seed}", flush=True)
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)

        print("Create data loaders", flush=True)
        tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
        collate = partial(rank_collate, pad_id=tokenizer.pad_token_id)

        train_data = RankingDataset(
            tokenizer,
            cfg.train_file,
            cfg.max_seq_len,
            cfg.max_q_len,
            train=True,
        )
        sampler = MhopSampler(train_data, num_neg=cfg.neg_num)

        # Each question contributes one positive plus `neg_num` negatives.
        per_gpu_batch = (1 + cfg.neg_num) * cfg.num_q_per_gpu
        gpu_count = torch.cuda.device_count()
        node_batch = per_gpu_batch * gpu_count
        print(f"Number of GPUs: {gpu_count}", flush=True)
        print(f"Batch size per node: {node_batch}", flush=True)

        self._train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=node_batch,
            num_workers=cfg.num_workers,
            collate_fn=collate,
            sampler=sampler,
        )

        eval_data = RankingDataset(
            tokenizer,
            cfg.predict_file,
            cfg.max_seq_len,
            cfg.max_q_len,
        )
        self._test_loader = torch.utils.data.DataLoader(
            eval_data,
            batch_size=cfg.predict_batch_size,
            num_workers=cfg.num_workers,
            collate_fn=collate,
        )

        print("Create model", flush=True)
        local_rank = job_env.local_rank
        print(f"Local rank {local_rank}", flush=True)
        model = QAModel(AutoConfig.from_pretrained(cfg.model_name), cfg)
        model.cuda(local_rank)

        # Split parameters into two optimizer groups: those whose name contains
        # one of the markers below get weight_decay=0.0, the rest get the
        # configured weight decay. Parameter order within each group follows
        # `named_parameters()`.
        decay_exempt_markers = ('bias', 'LayerNorm.weight')
        decayed_params, exempt_params = [], []
        for param_name, param in model.named_parameters():
            if any(marker in param_name for marker in decay_exempt_markers):
                exempt_params.append(param)
            else:
                decayed_params.append(param)
        param_groups = [
            {'params': decayed_params, 'weight_decay': cfg.weight_decay},
            {'params': exempt_params, 'weight_decay': 0.0},
        ]

        optimizer_cls = optim.Adam if cfg.use_adam else AdamW
        optimizer = optimizer_cls(param_groups, lr=cfg.learning_rate)

        if cfg.fp16:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=cfg.fp16_opt_level
            )

        # Scheduler horizon: optimizer updates per epoch times number of epochs.
        updates_per_epoch = len(self._train_loader) // cfg.gradient_accumulation_steps
        total_updates = updates_per_epoch * cfg.num_train_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=total_updates * cfg.warmup_ratio,
            num_training_steps=total_updates,
        )

        self._state = TrainerState(
            epoch=0,
            model=torch.nn.DataParallel(model),
            optimizer=optimizer,
            lr_scheduler=scheduler,
            global_step=0,
        )
        self.tb_logger = SummaryWriter(cfg.output_dir.replace("logs", "tflogs"))

        # NOTE(review): the checkpoint is looked up under <output_dir>/<job_id>/
        # while config.json above is written directly into <output_dir> --
        # confirm this matches where checkpoints are actually saved.
        checkpoint_fn = osp.join(cfg.output_dir, str(job_env.job_id), "checkpoint.pth")
        if os.path.isfile(checkpoint_fn):
            print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
            self._state = TrainerState.load(
                checkpoint_fn, default=self._state, gpu=local_rank
            )