in train_dense_encoder.py [0:0]
def validate_nll(self) -> float:
    """Run NLL (negative log-likelihood) validation over the dev set.

    Performs one full pass over the dev data iterator, doing a biencoder
    forward pass per batch, accumulating loss and correct-prediction
    counts, and logging a summary at the end.

    Returns:
        Average per-batch NLL loss over the dev set.

    Raises:
        RuntimeError: if the dev iterator produces no batches (previously
            this surfaced as a bare ZeroDivisionError).
    """
    logger.info("NLL validation ...")
    cfg = self.cfg
    self.biencoder.eval()

    # Lazily create the dev iterator on the first validation call.
    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(
            cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
        )
    data_iterator = self.dev_iterator

    total_loss = 0.0
    start_time = time.time()
    total_correct_predictions = 0
    num_hard_negatives = cfg.train.hard_negatives
    num_other_negatives = cfg.train.other_negatives
    log_result_step = cfg.train.log_batch_step
    batches = 0
    dataset = 0  # index of the dev dataset the current batch came from

    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        # Multi-dataset iterators yield (batch, dataset_idx) pairs.
        # Use the builtin `tuple` here: isinstance against typing.Tuple
        # is deprecated.
        if isinstance(samples_batch, tuple):
            samples_batch, dataset = samples_batch
        logger.info("Eval step: %d ,rnk=%s", i, cfg.local_rank)
        biencoder_input = BiEncoder.create_biencoder_input2(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )
        # Get the token positions used for representation selection
        # (dataset-specific selector, e.g. a special token's position).
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        rep_positions = ds_cfg.selector.get_positions(biencoder_input.question_ids, self.tensorizer)
        encoder_type = ds_cfg.encoder_type
        loss, correct_cnt = _do_biencoder_fwd_pass(
            self.biencoder,
            biencoder_input,
            self.tensorizer,
            cfg,
            encoder_type=encoder_type,
            rep_positions=rep_positions,
        )
        total_loss += loss.item()
        total_correct_predictions += correct_cnt
        batches += 1
        if (i + 1) % log_result_step == 0:
            logger.info(
                "Eval step: %d , used_time=%f sec., loss=%f ",
                i,
                time.time() - start_time,
                loss.item(),
            )

    # An empty dev set previously crashed with a bare ZeroDivisionError;
    # fail with an explicit message instead.
    if batches == 0:
        raise RuntimeError("NLL validation: dev iterator produced no batches")

    total_loss = total_loss / batches
    # NOTE(review): assumes every batch is full-sized across all ranks —
    # the last (possibly partial) batch slightly skews the ratio; verify
    # against the iterator's padding behavior.
    total_samples = batches * cfg.train.dev_batch_size * self.distributed_factor
    correct_ratio = float(total_correct_predictions / total_samples)
    logger.info(
        "NLL Validation: loss = %f. correct prediction ratio %d/%d ~ %f",
        total_loss,
        total_correct_predictions,
        total_samples,
        correct_ratio,
    )
    return total_loss