def _init_state()

in mdr/retrieval/single_trainer.py [0:0]


    def _init_state(self) -> None:
        """
        Initialize the training state and load it from an existing checkpoint, if any.
        """
        job_env = submitit.JobEnvironment()

        if job_env.global_rank == 0:
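            # Only the global rank-0 task writes the run config to output_dir.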
            os.makedirs(self._train_cfg.output_dir, exist_ok=True)
            config_path = Path(self._train_cfg.output_dir) / 'config.json'
            with open(config_path, "w") as g:
                g.write(json.dumps(self._train_cfg._asdict()))

        print(f"Setting random seed {self._train_cfg.seed}", flush=True)
        random.seed(self._train_cfg.seed)
        np.random.seed(self._train_cfg.seed)
        torch.manual_seed(self._train_cfg.seed)

        print("Create data loaders", flush=True)
        tokenizer = BertTokenizer.from_pretrained(self._train_cfg.bert_model_name)
        collate_fc = sp_collate
        train_set = SPDataset(tokenizer, self._train_cfg.train_file, self._train_cfg.max_q_len, self._train_cfg.max_c_len, train=True)
        # NOTE: no DistributedSampler is used; this single-task trainer loads the full
        # dataset and parallelizes across local GPUs with DataParallel below.
        self._train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=self._train_cfg.train_batch_size,
            num_workers=4,
            collate_fn=collate_fc,
        )
        test_set = SPDataset(tokenizer, self._train_cfg.predict_file, self._train_cfg.max_q_len, self._train_cfg.max_c_len)
        self._test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=self._train_cfg.predict_batch_size,
            num_workers=4, collate_fn=collate_fc
        )
        print(f"Per Node batch_size: {self._train_cfg.train_batch_size // job_env.num_tasks}", flush=True)

        print("Create model", flush=True)
        print(f"Local rank {job_env.local_rank}", flush=True)
        bert_config = BertConfig.from_pretrained(self._train_cfg.bert_model_name)
        model = BertForRetrieverSP(bert_config, self._train_cfg)
        model.cuda(job_env.local_rank)

        # Standard BERT fine-tuning: exclude biases and LayerNorm weights from weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(
                nd in n for nd in no_decay)], 'weight_decay': self._train_cfg.weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_parameters,
                          lr=self._train_cfg.learning_rate)
        # Halve the learning rate when the monitored validation metric (mode='max') plateaus.
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5)

        # Optional mixed-precision training via NVIDIA Apex AMP.
        if self._train_cfg.fp16:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self._train_cfg.fp16_opt_level)
        model = torch.nn.DataParallel(model)  # single-process multi-GPU via DataParallel
        self._state = TrainerState(
            epoch=0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, global_step=0
        )

        self.tb_logger = SummaryWriter(os.path.join(self._train_cfg.output_dir, "tblog"))

        # Resume from this job's checkpoint under output_dir/<job_id>, if one exists.
        checkpoint_fn = osp.join(self._train_cfg.output_dir, str(job_env.job_id), "checkpoint.pth")
        if os.path.isfile(checkpoint_fn):
            print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
            self._state = TrainerState.load(
                checkpoint_fn, default=self._state, gpu=job_env.local_rank)
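
For context, the sketch below shows one way a submitit entry point might launch this trainer so that _init_state runs on the allocated node before training. The launch helper, the trainer_cls argument, and the executor parameters are illustrative assumptions, not code from this repository; all that is assumed is that the trainer's __call__ invokes _init_state before its training loop.

import submitit

def launch(train_cfg, trainer_cls):
    # Hypothetical launcher: submitit pickles the trainer object and calls it on the
    # allocated node, where __call__ is expected to run _init_state() first so that
    # seeds, data loaders, model, optimizer and any resumed checkpoint are in place.
    executor = submitit.AutoExecutor(folder=train_cfg.output_dir)
    executor.update_parameters(timeout_min=60, gpus_per_node=1)
    job = executor.submit(trainer_cls(train_cfg))
    return job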