in mdr/retrieval/single_trainer.py
def _init_state(self) -> None:
    """
    Initialize the training state and load it from an existing checkpoint, if any.
    """
    job_env = submitit.JobEnvironment()
    if job_env.global_rank == 0:
        # Only the primary process writes the run configuration to disk.
        os.makedirs(self._train_cfg.output_dir, exist_ok=True)
        config_path = Path(self._train_cfg.output_dir) / 'config.json'
        with open(config_path, "w") as g:
            g.write(json.dumps(self._train_cfg._asdict()))

    print(f"Setting random seed {self._train_cfg.seed}", flush=True)
    random.seed(self._train_cfg.seed)
    np.random.seed(self._train_cfg.seed)
    torch.manual_seed(self._train_cfg.seed)
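    # NOTE: every rank seeds Python, NumPy and PyTorch with the same value, so
    # shuffling and any random augmentation behave identically across
    # processes; a common variant offsets the seed by job_env.global_rank to
    # decorrelate workers.
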
print("Create data loaders", flush=True)
tokenizer = BertTokenizer.from_pretrained(self._train_cfg.bert_model_name)
collate_fc = sp_collate
train_set = SPDataset(tokenizer, self._train_cfg.train_file, self._train_cfg.max_q_len, self._train_cfg.max_c_len, train=True)
# train_sampler = torch.utils.data.distributed.DistributedSampler(
# train_set, num_replicas=job_env.num_tasks, rank=job_env.global_rank
# )
# self._train_loader = torch.utils.data.DataLoader(
# train_set,
# batch_size=self._train_cfg.train_batch_size,
# num_workers=4,
# sampler=train_sampler, collate_fn=collate_fc
# )
self._train_loader = torch.utils.data.DataLoader(train_set, batch_size=self._train_cfg.train_batch_size, num_workers=4, collate_fn=collate_fc)
test_set = SPDataset(tokenizer, self._train_cfg.predict_file, self._train_cfg.max_q_len, self._train_cfg.max_c_len)
self._test_loader = torch.utils.data.DataLoader(
test_set,
batch_size=self._train_cfg.predict_batch_size,
num_workers=4, collate_fn=collate_fc
)
print(f"Per Node batch_size: {self._train_cfg.train_batch_size // job_env.num_tasks}", flush=True)
print("Create model", flush=True)
print(f"Local rank {job_env.local_rank}", flush=True)
bert_config = BertConfig.from_pretrained(self._train_cfg.bert_model_name)
model = BertForRetrieverSP(bert_config, self._train_cfg)
model.cuda(job_env.local_rank)
    # Exclude biases and LayerNorm weights from weight decay, as is standard
    # for BERT-style fine-tuning.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': self._train_cfg.weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_parameters, lr=self._train_cfg.learning_rate)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5)
    if self._train_cfg.fp16:
        # Mixed-precision training via NVIDIA apex.
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=self._train_cfg.fp16_opt_level)
    model = torch.nn.DataParallel(model)
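    # amp.initialize is applied before the DataParallel wrapper, which matches
    # apex's recommended ordering: initialize the bare model and optimizer
    # first, then apply a parallel wrapper.
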
    self._state = TrainerState(
        epoch=0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, global_step=0
    )
    self.tb_logger = SummaryWriter(os.path.join(self._train_cfg.output_dir, "tblog"))
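    # Every process opens its own SummaryWriter on the shared "tblog"
    # directory (each writer creates a distinct event file); in multi-process
    # runs it is common to log from rank 0 only to avoid duplicated scalars.
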
    checkpoint_fn = osp.join(self._train_cfg.output_dir, str(job_env.job_id), "checkpoint.pth")
    if os.path.isfile(checkpoint_fn):
        # Resume from the checkpoint written by a previous run of this job.
        print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
        self._state = TrainerState.load(
            checkpoint_fn, default=self._state, gpu=job_env.local_rank)
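
A minimal sketch of how a trainer built around this `_init_state` might be submitted with submitit. The `Trainer` callable and `train_cfg` object below are hypothetical placeholders (neither name comes from this file); the submitit calls themselves (`AutoExecutor`, `update_parameters`, `submit`) are the library's standard API.

import submitit

def launch(train_cfg):
    # AutoExecutor picks the SLURM backend when available, otherwise runs locally.
    executor = submitit.AutoExecutor(folder="submitit_logs")
    executor.update_parameters(
        timeout_min=24 * 60,   # wall-clock limit in minutes
        nodes=1,
        tasks_per_node=1,
        gpus_per_node=1,
        name="mdr-retrieval",
    )
    trainer = Trainer(train_cfg)      # hypothetical callable: __call__ runs _init_state + the training loop
    job = executor.submit(trainer)    # the callable is pickled and executed on the cluster
    print("Submitted job", job.job_id)
    return job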