in mdr/qa/qa_trainer.py [0:0]
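# NOTE: this excerpt assumes the module-level imports used below are defined elsewhere in
# qa_trainer.py: os, os.path as osp, json, random, numpy as np, torch, torch.optim as optim,
# submitit, functools.partial, pathlib.Path, an amp implementation (e.g. NVIDIA apex),
# transformers (AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup),
# a SummaryWriter, and the project-local RankingDataset, rank_collate, MhopSampler,
# QAModel and TrainerState.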
def _init_state(self) -> None:
    """
    Initialize the state and load it from an existing checkpoint if any.
    """
    job_env = submitit.JobEnvironment()
    if job_env.global_rank == 0:
        # config_path = Path(args.save_folder) / str(job_env.job_id) / 'config.json'
        os.makedirs(self._train_cfg.output_dir, exist_ok=True)
        config_path = Path(self._train_cfg.output_dir) / 'config.json'
        with open(config_path, "w") as g:
            g.write(json.dumps(self._train_cfg._asdict()))
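    # Fix the Python, NumPy and torch seeds so runs are reproducible.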
print(f"Setting random seed {self._train_cfg.seed}", flush=True)
random.seed(self._train_cfg.seed)
np.random.seed(self._train_cfg.seed)
torch.manual_seed(self._train_cfg.seed)
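    # Tokenizer plus the training dataset and sampler; rank_collate pads each batch with
    # the tokenizer's pad id.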
print("Create data loaders", flush=True)
tokenizer = AutoTokenizer.from_pretrained(self._train_cfg.model_name)
collate_fc = partial(rank_collate, pad_id=tokenizer.pad_token_id)
train_set = RankingDataset(tokenizer, self._train_cfg.train_file, self._train_cfg.max_seq_len, self._train_cfg.max_q_len, train=True)
train_sampler = MhopSampler(train_set, num_neg=self._train_cfg.neg_num)
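    # The sampler groups every question with its negatives, so each question contributes
    # (1 + neg_num) examples and the per-GPU batch covers num_q_per_gpu questions.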
    batch_size_per_gpu = (1 + self._train_cfg.neg_num) * self._train_cfg.num_q_per_gpu
    n_gpu = torch.cuda.device_count()
    print(f"Number of GPUs: {n_gpu}", flush=True)
    print(f"Batch size per node: {batch_size_per_gpu * n_gpu}", flush=True)
    self._train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size_per_gpu * n_gpu,
        num_workers=self._train_cfg.num_workers,
        collate_fn=collate_fc,
        sampler=train_sampler
    )
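    # Evaluation data: same collate function, fixed predict_batch_size, no custom sampler.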
    test_set = RankingDataset(
        tokenizer, self._train_cfg.predict_file,
        self._train_cfg.max_seq_len, self._train_cfg.max_q_len
    )
    self._test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=self._train_cfg.predict_batch_size,
        num_workers=self._train_cfg.num_workers,
        collate_fn=collate_fc
    )
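    # Build the model from the pretrained encoder config and move it to this process's
    # local GPU.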
print("Create model", flush=True)
print(f"Local rank {job_env.local_rank}", flush=True)
bert_config = AutoConfig.from_pretrained(self._train_cfg.model_name)
model = QAModel(bert_config, self._train_cfg)
model.cuda(job_env.local_rank)
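    # Usual transformer fine-tuning split: no weight decay on biases and LayerNorm weights.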
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': self._train_cfg.weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
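    # Plain Adam or transformers' AdamW (decoupled weight decay), depending on the config.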
    if self._train_cfg.use_adam:
        optimizer = optim.Adam(optimizer_parameters, lr=self._train_cfg.learning_rate)
    else:
        optimizer = AdamW(optimizer_parameters, lr=self._train_cfg.learning_rate)
    # lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
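    # Optional mixed-precision training via the amp wrapper (NVIDIA apex style).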
    if self._train_cfg.fp16:
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=self._train_cfg.fp16_opt_level)
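    # Total optimizer steps = batches per epoch // accumulation steps * epochs; warm the
    # learning rate up for warmup_ratio of them, then decay it linearly to zero.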
    t_total = len(self._train_loader) // self._train_cfg.gradient_accumulation_steps * self._train_cfg.num_train_epochs
    warmup_steps = t_total * self._train_cfg.warmup_ratio
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
    )
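    # DataParallel replicates the model over the local GPUs and splits each node-level
    # batch; the (model, optimizer, scheduler) triple is tracked as the trainer state.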
    model = torch.nn.DataParallel(model)
    self._state = TrainerState(
        epoch=0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, global_step=0
    )
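    # TensorBoard events go to a "tflogs" variant of output_dir; checkpoints are stored
    # under the submitit job id inside output_dir.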
    self.tb_logger = SummaryWriter(self._train_cfg.output_dir.replace("logs", "tflogs"))
    checkpoint_fn = osp.join(self._train_cfg.output_dir, str(job_env.job_id), "checkpoint.pth")
    # checkpoint_fn = osp.join(self._train_cfg.output_dir, "checkpoint.pth")
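    # Resume transparently if this job already saved a checkpoint (e.g. after preemption).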
    if os.path.isfile(checkpoint_fn):
        print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
        self._state = TrainerState.load(
            checkpoint_fn, default=self._state, gpu=job_env.local_rank)
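
# --- Hypothetical usage sketch (not part of qa_trainer.py) ----------------------------------
# _init_state() is typically called from the trainer's __call__ when submitit starts or
# requeues the job. The class name `QATrainer` and config object `cfg` below are assumptions
# for illustration; only the submitit API calls themselves are standard.
#
#   import submitit
#
#   executor = submitit.AutoExecutor(folder=cfg.output_dir)
#   executor.update_parameters(gpus_per_node=8, nodes=1, timeout_min=4320)
#   trainer = QATrainer(cfg)          # checkpointable callable that owns _init_state()
#   job = executor.submit(trainer)    # submitit re-runs `trainer` on preemption/requeue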