in blink/crossencoder/train_cross.py [0:0]
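# Imports reconstructed from how this function uses them; the blink.* module
# paths below follow the standard BLINK repository layout but are best-effort
# assumptions for this excerpt.
import os
import random
import time

import numpy as np
import torch
from torch.utils.data import (
    DataLoader,
    RandomSampler,
    SequentialSampler,
    TensorDataset,
)
from tqdm import tqdm, trange

import blink.candidate_ranking.utils as utils
from blink.crossencoder.crossencoder import CrossEncoderRanker

# `modify`, `evaluate`, `get_optimizer`, and `get_scheduler` are helpers
# defined elsewhere in this module.
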
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = CrossEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model
    # utils.save_model(model, tokenizer, model_output_path)

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                params["gradient_accumulation_steps"]
            )
        )
    # To reach an effective batch size of `x` while accumulating gradients
    # across `y` batches, each forward pass must use a batch size of `z = x / y`.
    # args.gradient_accumulation_steps = args.gradient_accumulation_steps // n_gpu
    params["train_batch_size"] = (
        params["train_batch_size"] // params["gradient_accumulation_steps"]
    )
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]
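    # Worked example (illustrative): with train_batch_size=128 and
    # gradient_accumulation_steps=8, each forward pass sees 16 examples and
    # gradients accumulate over 8 batches before each optimizer update, so the
    # effective batch size is still 128.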
    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
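    # Note: seeding alone does not make cuDNN kernels deterministic; fully
    # reproducible GPU runs would also need torch.backends.cudnn.deterministic
    # set to True.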
    max_seq_length = params["max_seq_length"]
    context_length = params["max_context_length"]

    fname = os.path.join(params["data_path"], "train.t7")
    train_data = torch.load(fname)
    context_input = train_data["context_vecs"]
    candidate_input = train_data["candidate_vecs"]
    label_input = train_data["labels"]
    if params["debug"]:
        max_n = 200
        context_input = context_input[:max_n]
        candidate_input = candidate_input[:max_n]
        label_input = label_input[:max_n]
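    # `modify` concatenates each context with every one of its candidates,
    # yielding roughly (num_mentions, num_candidates, max_seq_length) input
    # sequences for the cross-encoder.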
    context_input = modify(context_input, candidate_input, max_seq_length)
    if params["zeshel"]:
        src_input = train_data["worlds"][:len(context_input)]
        train_tensor_data = TensorDataset(context_input, label_input, src_input)
    else:
        train_tensor_data = TensorDataset(context_input, label_input)
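    # In Zeshel mode, each example keeps its source-world id so evaluation can
    # break accuracy down per world.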
    train_sampler = RandomSampler(train_tensor_data)
    train_dataloader = DataLoader(
        train_tensor_data,
        sampler=train_sampler,
        batch_size=train_batch_size,
    )
    fname = os.path.join(params["data_path"], "valid.t7")
    valid_data = torch.load(fname)
    context_input = valid_data["context_vecs"]
    candidate_input = valid_data["candidate_vecs"]
    label_input = valid_data["labels"]
    if params["debug"]:
        max_n = 200
        context_input = context_input[:max_n]
        candidate_input = candidate_input[:max_n]
        label_input = label_input[:max_n]

    context_input = modify(context_input, candidate_input, max_seq_length)
    if params["zeshel"]:
        src_input = valid_data["worlds"][:len(context_input)]
        valid_tensor_data = TensorDataset(context_input, label_input, src_input)
    else:
        valid_tensor_data = TensorDataset(context_input, label_input)
    valid_sampler = SequentialSampler(valid_tensor_data)
    valid_dataloader = DataLoader(
        valid_tensor_data,
        sampler=valid_sampler,
        batch_size=eval_batch_size,
    )

    # Evaluate before training
    results = evaluate(
        reranker,
        valid_dataloader,
        device=device,
        logger=logger,
        context_length=context_length,
        zeshel=params["zeshel"],
        silent=params["silent"],
    )
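    # Evaluating before any training establishes a baseline score for the
    # untrained reranker, against which later epochs can be compared.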
    number_of_samples_per_dataset = {}

    time_start = time.time()
    utils.write_to_file(
        os.path.join(model_output_path, "training_params.txt"), str(params)
    )

    logger.info("Starting training")
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}".format(device, n_gpu, False)
    )

    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data), logger)
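    # `get_scheduler` sizes the learning-rate schedule from the number of
    # training samples, so that warmup spans a fixed proportion of the total
    # optimization steps (see its definition elsewhere in this file).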
    model.train()

    best_epoch_idx = -1
    best_score = -1

    num_train_epochs = params["num_train_epochs"]
    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        tr_loss = 0
        results = None

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        part = 0
        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_input = batch[0]
            label_input = batch[1]
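            # The reranker scores each (context, candidate) pair and returns a
            # cross-entropy loss against the index of the gold candidate.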
            loss, _ = reranker(context_input, label_input, context_length)

            # if n_gpu > 1:
            #     loss = loss.mean()  # mean() to average on multi-gpu.

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info(
                    "Step {} - epoch {} average loss: {}\n".format(
                        step,
                        epoch_idx,
                        tr_loss / (params["print_interval"] * grad_acc_steps),
                    )
                )
                tr_loss = 0

            loss.backward()
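            # Step the optimizer only every grad_acc_steps batches, clipping
            # the accumulated gradient first.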
            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), params["max_grad_norm"]
                )
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(
                    reranker,
                    valid_dataloader,
                    device=device,
                    logger=logger,
                    context_length=context_length,
                    zeshel=params["zeshel"],
                    silent=params["silent"],
                )

                logger.info("***** Saving fine-tuned model *****")
                epoch_output_folder_path = os.path.join(
                    model_output_path, "epoch_{}_{}".format(epoch_idx, part)
                )
                part += 1
                utils.save_model(model, tokenizer, epoch_output_folder_path)
                model.train()
                logger.info("\n")
logger.info("***** Saving fine - tuned model *****")
epoch_output_folder_path = os.path.join(
model_output_path, "epoch_{}".format(epoch_idx)
)
utils.save_model(model, tokenizer, epoch_output_folder_path)
# reranker.save(epoch_output_folder_path)
output_eval_file = os.path.join(epoch_output_folder_path, "eval_results.txt")
results = evaluate(
reranker,
valid_dataloader,
device=device,
logger=logger,
context_length=context_length,
zeshel=params["zeshel"],
silent=params["silent"],
)
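        # Keep a running best: compare this epoch's normalized accuracy to the
        # best seen so far and remember the winning epoch index.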
        ls = [best_score, results["normalized_accuracy"]]
        li = [best_epoch_idx, epoch_idx]
        best_score = ls[np.argmax(ls)]
        best_epoch_idx = li[np.argmax(ls)]
        logger.info("\n")
    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # Save the best model in the parent directory
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(
        model_output_path, "epoch_{}".format(best_epoch_idx)
    )
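    # Downstream code can reload the best checkpoint from
    # params["path_to_model"].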