in train_dense_encoder.py [0:0]
def validate_average_rank(self) -> float:
    """
    Validates the biencoder model using each question's gold passage rank across the pooled set of passages from the dataset.
    It generates vectors for the specified number of negative passages for each question (see --val_av_rank_xxx params)
    and stores them in RAM, along with the question vectors.
    The similarity scores are then calculated for the entire
    num_questions x (num_questions x num_passages_per_question) matrix and sorted per question.
    Each question's gold passage rank in that sorted list of scores is averaged across all questions.
    :return: averaged rank number
    """
logger.info("Average rank validation ...")
cfg = self.cfg
self.biencoder.eval()
distributed_factor = self.distributed_factor
if not self.dev_iterator:
self.dev_iterator = self.get_data_iterator(
cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
)
data_iterator = self.dev_iterator
sub_batch_size = cfg.train.val_av_rank_bsz
sim_score_f = BiEncoderNllLoss.get_similarity_function()
q_represenations = []
ctx_represenations = []
positive_idx_per_question = []
num_hard_negatives = cfg.train.val_av_rank_hard_neg
num_other_negatives = cfg.train.val_av_rank_other_neg
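    # each question contributes its gold passage plus num_hard_negatives + num_other_negatives
    # negative passages, so the pooled context set grows linearly with the number of questions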
    log_result_step = cfg.train.log_batch_step
    dataset = 0
    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        if len(q_represenations) > cfg.train.val_av_rank_max_qs / distributed_factor:
            break
        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch
        biencoder_input = BiEncoder.create_biencoder_input2(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )
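        # biencoder_input packs the question/context token ids and segment ids, plus the
        # per-batch positions of each question's gold passage (is_positive)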
        total_ctxs = len(ctx_represenations)
        ctxs_ids = biencoder_input.context_ids
        ctxs_segments = biencoder_input.ctx_segments
        bsz = ctxs_ids.size(0)
        # get the token to be used for representation selection
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        encoder_type = ds_cfg.encoder_type
        rep_positions = ds_cfg.selector.get_positions(biencoder_input.question_ids, self.tensorizer)
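        # rep_positions: per-sample token positions (typically the [CLS] token) whose
        # hidden states are taken as the sequence representations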
        # split the contexts batch into sub-batches since it is typically too large to be processed in one pass
        for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):
            q_ids, q_segments = (
                (biencoder_input.question_ids, biencoder_input.question_segments) if j == 0 else (None, None)
            )
            if j == 0 and cfg.n_gpu > 1 and q_ids.size(0) == 1:
                # if we are in DP (but not in DDP) mode, all model input tensors should have batch size > 1 or 0,
                # otherwise the other input tensors will be split but only the first split will be processed
                continue
            ctx_ids_batch = ctxs_ids[batch_start : batch_start + sub_batch_size]
            ctx_seg_batch = ctxs_segments[batch_start : batch_start + sub_batch_size]
            q_attn_mask = self.tensorizer.get_attn_mask(q_ids)
            ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_batch)
            with torch.no_grad():
                q_dense, ctx_dense = self.biencoder(
                    q_ids,
                    q_segments,
                    q_attn_mask,
                    ctx_ids_batch,
                    ctx_seg_batch,
                    ctx_attn_mask,
                    encoder_type=encoder_type,
                    representation_token_pos=rep_positions,
                )
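            # the question encoder only runs for the first sub-batch (j == 0); later
            # sub-batches pass q_ids=None and therefore get q_dense=None back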
            if q_dense is not None:
                q_represenations.extend(q_dense.cpu().split(1, dim=0))
            ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0))
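        # shift the gold passage indices by the number of contexts accumulated so far,
        # so they index into the global list of context vectors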
        batch_positive_idxs = biencoder_input.is_positive
        positive_idx_per_question.extend([total_ctxs + v for v in batch_positive_idxs])
        if (i + 1) % log_result_step == 0:
            logger.info(
                "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
                i,
                len(ctx_represenations),
                len(q_represenations),
            )
    ctx_represenations = torch.cat(ctx_represenations, dim=0)
    q_represenations = torch.cat(q_represenations, dim=0)
    logger.info("Av.rank validation: total q_vectors size=%s", q_represenations.size())
    logger.info("Av.rank validation: total ctx_vectors size=%s", ctx_represenations.size())
    q_num = q_represenations.size(0)
    assert q_num == len(positive_idx_per_question)
    scores = sim_score_f(q_represenations, ctx_represenations)
    values, indices = torch.sort(scores, dim=1, descending=True)
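    # indices[i] lists passage indices for question i, sorted from highest to lowest similarity score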
    rank = 0
    for i, idx in enumerate(positive_idx_per_question):
        # aggregate the rank of the known gold passage in the sorted results for each question
        gold_idx = (indices[i] == idx).nonzero()
        rank += gold_idx.item()
    if distributed_factor > 1:
        # each node calculated its own rank; exchange that information between nodes and compute the "global" average rank
        # NOTE: the set of passages is still unique for every node
        eval_stats = all_gather_list([rank, q_num], max_size=100)
        for i, item in enumerate(eval_stats):
            remote_rank, remote_q_num = item
            if i != cfg.local_rank:
                rank += remote_rank
                q_num += remote_q_num
    av_rank = float(rank / q_num)
    logger.info("Av.rank validation: average rank %s, total questions=%d", av_rank, q_num)
    return av_rank
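
# A minimal, standalone sketch (not part of DPR) of the average-rank computation performed
# above, assuming a plain dot-product similarity in place of sim_score_f; the tensor shapes
# and gold passage indices below are made up purely for illustration.
import torch

q_vectors = torch.randn(3, 8)             # 3 question vectors of dim 8
ctx_vectors = torch.randn(12, 8)          # 12 pooled passage vectors (gold + negatives)
positive_idx_per_question = [0, 5, 9]     # hypothetical gold passage index per question

scores = q_vectors @ ctx_vectors.t()                       # (3, 12) similarity matrix
_, indices = torch.sort(scores, dim=1, descending=True)    # passage ids, best score first

rank = 0
for i, gold_idx in enumerate(positive_idx_per_question):
    # position of the gold passage in the sorted candidate list for question i
    rank += (indices[i] == gold_idx).nonzero().item()

av_rank = rank / len(positive_idx_per_question)            # lower is better; 0 means always ranked first
print(av_rank)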