in elq/biencoder/train_biencoder.py [0:0]
def _mine_hard_negatives(
    reranker, cand_encs, cand_encs_index, num_neighbors,
    context_input, label_ids, mention_idxs, mention_idx_mask, device,
):
    """Build the positive + hard-negative inputs for one adversarial step.

    Looks up the gold candidate encodings for ``label_ids``, encodes the
    context with ``reranker.encode_context``, retrieves ``num_neighbors``
    nearest candidates per mention from ``cand_encs_index`` and drops every
    negative slot that coincides with a gold label.  Positives and the
    remaining negatives are concatenated along the first dimension.

    Shape comments below are carried over from the original author.

    Returns:
        Tuple ``(cand_encs_input, label_input, mention_reps_input,
        mention_logits, mention_bounds, hard_negs_mask)`` in the form the
        reranker forward call expects for its ``cand_encs``, ``label_input``,
        ``text_encs``, ``mention_logits``, ``mention_bounds`` and
        ``hard_negs_mask`` keyword arguments.
    """
    # GET CLOSEST N CANDIDATES (AND APPROPRIATE LABELS)
    # (bs, num_spans, embed_size)
    pos_cand_encs_input = cand_encs[label_ids.to("cpu")]
    pos_cand_encs_input[~mention_idx_mask] = 0

    context_outs = reranker.encode_context(
        context_input, gold_mention_bounds=mention_idxs,
        gold_mention_bounds_mask=mention_idx_mask,
        get_mention_scores=True,
    )
    mention_logits = context_outs['all_mention_logits']
    mention_bounds = context_outs['all_mention_bounds']
    mention_reps = context_outs['mention_reps']
    # mention_reps: (bs, max_num_spans, embed_size) -> masked_mention_reps: (all_pred_mentions_batch, embed_size)
    masked_mention_reps = mention_reps[context_outs['mention_masks']]

    # neg_cand_encs_input_idxs: (all_pred_mentions_batch, num_negatives)
    _, neg_cand_encs_input_idxs = cand_encs_index.search_knn(
        masked_mention_reps.detach().cpu().numpy(), num_neighbors
    )
    neg_cand_encs_input_idxs = torch.from_numpy(neg_cand_encs_input_idxs)
    # set "correct" closest entities to -1
    # masked_label_ids: (all_pred_mentions_batch)
    masked_label_ids = label_ids[mention_idx_mask]
    # neg_cand_encs_input_idxs: (max_spans_in_batch, num_negatives)
    neg_cand_encs_input_idxs[neg_cand_encs_input_idxs - masked_label_ids.to("cpu").unsqueeze(-1) == 0] = -1

    # reshape back tensor (extract num_spans dimension)
    # (bs, num_spans, num_negatives)
    neg_cand_encs_input_idxs_reconstruct = torch.zeros(
        label_ids.size(0), label_ids.size(1), neg_cand_encs_input_idxs.size(-1),
        dtype=neg_cand_encs_input_idxs.dtype,
    )
    neg_cand_encs_input_idxs_reconstruct[mention_idx_mask] = neg_cand_encs_input_idxs
    neg_cand_encs_input_idxs = neg_cand_encs_input_idxs_reconstruct

    # create neg_example_idx (corresponding example (in batch) for each negative)
    # neg_example_idx: (bs * num_negatives)
    neg_example_idx = torch.arange(neg_cand_encs_input_idxs.size(0)).unsqueeze(-1)
    neg_example_idx = neg_example_idx.expand(neg_cand_encs_input_idxs.size(0), neg_cand_encs_input_idxs.size(2))
    neg_example_idx = neg_example_idx.flatten()

    # flatten and filter -1 (i.e. any correct/positive entities)
    # neg_cand_encs_input_idxs: (bs * num_negatives, num_spans)
    neg_cand_encs_input_idxs = neg_cand_encs_input_idxs.permute(0, 2, 1)
    neg_cand_encs_input_idxs = neg_cand_encs_input_idxs.reshape(-1, neg_cand_encs_input_idxs.size(-1))
    # mask invalid negatives (actually the positive example)
    # (bs * num_negatives)
    mask = ~((neg_cand_encs_input_idxs == -1).sum(1).bool())  # rows without any -1 entry
    # deletes corresponding negative for *all* spans in that example (deletes at most 3 of 10 negatives / example)
    # neg_cand_encs_input_idxs: (bs * num_negatives - invalid_negs, num_spans)
    neg_cand_encs_input_idxs = neg_cand_encs_input_idxs[mask]
    # neg_example_idx: (bs * num_negatives - invalid_negs)
    neg_example_idx = neg_example_idx[mask]
    # (bs * num_negatives - invalid_negs, num_spans, embed_size)
    neg_cand_encs_input = cand_encs[neg_cand_encs_input_idxs]
    # (bs * num_negatives - invalid_negs, num_spans)
    neg_mention_idx_mask = mention_idx_mask[neg_example_idx]
    neg_cand_encs_input[~neg_mention_idx_mask] = 0

    # create input tensors (concat [pos examples, neg examples])
    # (bs + bs * num_negatives, num_spans, embed_size)
    mention_reps_input = torch.cat([
        mention_reps, mention_reps[neg_example_idx.to(device)],
    ])
    assert mention_reps.size(0) == pos_cand_encs_input.size(0)

    # (bs + bs * num_negatives, num_spans)
    label_input = torch.cat([
        torch.ones(pos_cand_encs_input.size(0), pos_cand_encs_input.size(1), dtype=label_ids.dtype),
        torch.zeros(neg_cand_encs_input.size(0), neg_cand_encs_input.size(1), dtype=label_ids.dtype),
    ]).to(device)
    # (bs + bs * num_negatives, num_spans, embed_size)
    cand_encs_input = torch.cat([
        pos_cand_encs_input, neg_cand_encs_input,
    ]).to(device)
    hard_negs_mask = torch.cat([mention_idx_mask, neg_mention_idx_mask])

    return (
        cand_encs_input, label_input, mention_reps_input,
        mention_logits, mention_bounds, hard_negs_mask,
    )


def main(params):
    """Train the bi-encoder and checkpoint it after every epoch.

    The function builds a ``BiEncoderRanker`` from ``params``, evaluates it
    once on (a subset of) the validation split, then trains for the
    configured number of epochs.  After each epoch it saves the model and
    the optimizer/scheduler state under ``<output_path>/epoch_<i>/`` and
    evaluates on both the validation and the current training data.  It
    also writes ``training_params.txt`` and ``training_time.txt`` into
    ``output_path`` and finally saves the model into ``output_path``.

    Args:
        params: dict of settings.  Keys read here: ``output_path``,
            ``gradient_accumulation_steps``, ``train_batch_size``,
            ``eval_batch_size``, ``seed``, ``data_path``,
            ``max_context_length``, ``max_cand_length``, ``context_key``,
            ``title_key``, ``silent``, ``debug``, ``freeze_cand_enc``,
            ``cand_enc_path``, ``index_path``, ``num_train_epochs``,
            ``dont_distribute_train_samples``, ``last_epoch``, ``shuffle``,
            ``adversarial_training``, ``print_interval``, ``max_grad_norm``,
            ``eval_interval``, ``get_losses``, ``evaluate`` and, optionally,
            ``path_to_trainer_state``.  ``train_batch_size`` is overwritten
            with the per-step batch size and ``path_to_model`` is set.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` is < 1, or if it is
            larger than ``train_batch_size`` (per-step batch size of 0).
    """
    model_output_path = params["output_path"]
    os.makedirs(model_output_path, exist_ok=True)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                params["gradient_accumulation_steps"]
            )
        )

    # An effective batch size of `x`, when we are accumulating the gradient
    # across `y` batches, is achieved by having a batch size of `z = x / y`.
    params["train_batch_size"] = (
        params["train_batch_size"] // params["gradient_accumulation_steps"]
    )
    if params["train_batch_size"] < 1:
        # Fail before any data is loaded instead of inside the DataLoader.
        raise ValueError(
            "train_batch_size must be >= gradient_accumulation_steps ({})".format(
                params["gradient_accumulation_steps"]
            )
        )
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    def _process_samples(samples, known_candidate_token_ids):
        """Tokenize/tensorize `samples` with the settings shared by all splits."""
        # NOTE(review): `args` is not defined in this function; this relies on
        # a module-level `args` existing when main() runs -- confirm.
        return process_mention_data(
            samples=samples,
            tokenizer=tokenizer,
            max_context_length=params["max_context_length"],
            max_cand_length=params["max_cand_length"],
            context_key=params["context_key"],
            title_key=params["title_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            add_mention_bounds=(not args.no_mention_bounds),
            candidate_token_ids=known_candidate_token_ids,
            params=params,
        )

    # Load train data
    train_samples = utils.read_dataset("train", params["data_path"])
    logger.info("Read %d train samples." % len(train_samples))
    logger.info("Finished reading all train samples")

    # Load eval data
    try:
        valid_samples = utils.read_dataset("valid", params["data_path"])
    except FileNotFoundError:
        valid_samples = utils.read_dataset("dev", params["data_path"])
    # MUST BE DIVISIBLE BY n_gpus
    if len(valid_samples) > 1024:
        valid_subset = 1024
    else:
        # device_count() is 0 on CPU-only machines; guard the modulo.
        gpu_count = max(1, torch.cuda.device_count())
        valid_subset = len(valid_samples) - len(valid_samples) % gpu_count
    logger.info("Read %d valid samples, choosing %d subset" % (len(valid_samples), valid_subset))

    # use subset of valid data
    _, valid_tensor_data, extra_ret_values = _process_samples(
        valid_samples[:valid_subset], None
    )
    candidate_token_ids = extra_ret_values["candidate_token_ids"]
    valid_tensor_data = TensorDataset(*valid_tensor_data)
    valid_sampler = SequentialSampler(valid_tensor_data)
    valid_dataloader = DataLoader(
        valid_tensor_data, sampler=valid_sampler, batch_size=eval_batch_size
    )

    # load candidate encodings
    cand_encs = None
    cand_encs_index = None
    # number of nearest candidates retrieved per mention as hard negatives
    num_neighbors = 10
    if params["freeze_cand_enc"]:
        # NOTE: torch.load unpickles the file; only point this at trusted paths.
        cand_encs = torch.load(params['cand_enc_path'])
        logger.info("Loaded saved entity encodings")
        if params["debug"]:
            cand_encs = cand_encs[:200]

        # build FAISS index
        cand_encs_index = DenseHNSWFlatIndexer(1)
        cand_encs_index.deserialize_from(params['index_path'])
        logger.info("Loaded FAISS index on entity encodings")

    def _run_evaluation(dataloader, **eval_kwargs):
        """Call evaluate() on `dataloader` with the shared arguments."""
        return evaluate(
            reranker, dataloader, params,
            cand_encs=cand_encs, device=device,
            logger=logger, faiss_index=cand_encs_index,
            **eval_kwargs
        )

    # evaluate before training
    _run_evaluation(valid_dataloader)

    time_start = time.time()

    utils.write_to_file(
        os.path.join(model_output_path, "training_params.txt"), str(params)
    )

    logger.info("Starting training")
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}".format(device, n_gpu, False)
    )

    num_train_epochs = params["num_train_epochs"]
    if params["dont_distribute_train_samples"]:
        # Every epoch sees the full training set, so tensorize it once.
        num_samples_per_batch = len(train_samples)
        _, train_tensor_data_tuple, _ = _process_samples(
            train_samples, candidate_token_ids
        )
        logger.info("Finished preparing training data")
    else:
        # Each epoch gets its own disjoint slice of the training set.
        num_samples_per_batch = len(train_samples) // num_train_epochs

    trainer_path = params.get("path_to_trainer_state", None)
    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(
        params, optimizer, num_samples_per_batch,
        logger
    )
    if trainer_path is not None and os.path.exists(trainer_path):
        training_state = torch.load(trainer_path)
        optimizer.load_state_dict(training_state["optimizer"])
        scheduler.load_state_dict(training_state["scheduler"])
        logger.info("Loaded saved training state")

    best_epoch_idx = -1
    best_score = -1
    logger.info("Num samples per batch : %d" % num_samples_per_batch)
    for epoch_idx in trange(params["last_epoch"] + 1, int(num_train_epochs), desc="Epoch"):
        # Re-enter train mode every epoch: the end-of-epoch evaluations may
        # leave the model in eval mode (the mid-epoch evaluation below
        # restores train mode for the same reason).
        model.train()
        tr_loss = 0

        if not params["dont_distribute_train_samples"]:
            start_idx = epoch_idx * num_samples_per_batch
            end_idx = (epoch_idx + 1) * num_samples_per_batch

            _, train_tensor_data_tuple, _ = _process_samples(
                train_samples[start_idx:end_idx], candidate_token_ids
            )
            logger.info("Finished preparing training data for epoch {}: {} samples".format(epoch_idx, len(train_tensor_data_tuple[0])))

        batch_train_tensor_data = TensorDataset(
            *list(train_tensor_data_tuple)
        )
        if params["shuffle"]:
            train_sampler = RandomSampler(batch_train_tensor_data)
        else:
            train_sampler = SequentialSampler(batch_train_tensor_data)

        train_dataloader = DataLoader(
            batch_train_tensor_data, sampler=train_sampler, batch_size=train_batch_size
        )

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_input = batch[0]
            candidate_input = batch[1]
            label_ids = batch[2] if params["freeze_cand_enc"] else None
            mention_idxs = batch[-2]
            mention_idx_mask = batch[-1]
            if params["debug"] and label_ids is not None:
                # debug mode keeps only the first 200 candidate encodings
                label_ids[label_ids > 199] = 199

            cand_encs_input = None
            label_input = None
            mention_reps_input = None
            mention_logits = None
            mention_bounds = None
            hard_negs_mask = None
            if params["adversarial_training"]:
                # due to params["freeze_cand_enc"] being set
                assert cand_encs is not None and label_ids is not None
                (
                    cand_encs_input, label_input, mention_reps_input,
                    mention_logits, mention_bounds, hard_negs_mask,
                ) = _mine_hard_negatives(
                    reranker, cand_encs, cand_encs_index, num_neighbors,
                    context_input, label_ids, mention_idxs, mention_idx_mask,
                    device,
                )

            loss, _, _, _ = reranker(
                context_input, candidate_input,
                cand_encs=cand_encs_input, text_encs=mention_reps_input,
                mention_logits=mention_logits, mention_bounds=mention_bounds,
                label_input=label_input, gold_mention_bounds=mention_idxs,
                gold_mention_bounds_mask=mention_idx_mask,
                hard_negs_mask=hard_negs_mask,
                return_loss=True,
            )

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info(
                    "Step {} - epoch {} average loss: {}\n".format(
                        step,
                        epoch_idx,
                        tr_loss / (params["print_interval"] * grad_acc_steps),
                    )
                )
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), params["max_grad_norm"]
                )
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                # Drop references to this step's tensors for GPU mem management
                loss = None
                mention_reps_input = None
                label_input = None
                cand_encs_input = None
                mention_logits = None
                mention_bounds = None
                hard_negs_mask = None

                _run_evaluation(valid_dataloader, get_losses=params["get_losses"])
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine - tuned model *****")
        epoch_output_folder_path = os.path.join(
            model_output_path, "epoch_{}".format(epoch_idx)
        )
        utils.save_model(model, tokenizer, epoch_output_folder_path)
        torch.save({
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }, os.path.join(epoch_output_folder_path, "training_state.th"))

        logger.info("Valid data evaluation")
        valid_results = _run_evaluation(
            valid_dataloader, get_losses=params["get_losses"]
        )
        # The train-set evaluation is run for its logged output only.
        logger.info("Train data evaluation")
        _run_evaluation(train_dataloader, get_losses=params["get_losses"])

        # Model selection uses the validation score (ties keep the earlier
        # epoch), not the score on the data the model was just trained on.
        epoch_score = valid_results["normalized_f1"]
        if epoch_score > best_score:
            best_score = epoch_score
            best_epoch_idx = epoch_idx
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(
        model_output_path, "epoch_{}".format(best_epoch_idx)
    )
    # NOTE(review): this saves the in-memory (last-epoch) weights; the best
    # epoch's checkpoint is not reloaded first -- confirm this is intended.
    utils.save_model(reranker.model, tokenizer, model_output_path)

    if params["evaluate"]:
        params["path_to_model"] = model_output_path
        # Same call shape as every other evaluation in this function.
        _run_evaluation(valid_dataloader)