in model/train.py [0:0]
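# NOTE: this function relies on module-level state defined elsewhere in
# model/train.py: `model`, `cfg`, `dataset`, the metric objects (`bleu`,
# `self_bleu`, `classifier`, `eval_metrics`, `test_metrics`), and the
# helpers `generate_tokens` / `tensor_to_tokens`. The `dis_val_iter`
# argument is accepted but unused in this function.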
def evaluate(eval_iter, dis_val_iter=None, mode="eval", temperature=1):
    # Turn on evaluation mode, which disables dropout.
    model.eval()
    # Switch the generator to the evaluation context sizes: a longer
    # mem_len lets each token attend to more cached history during eval.
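    # Unwrap DistributedDataParallel so we can reach model attributes
    # (e.g. `generator`) directly on the underlying module.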
    if isinstance(model, DDP):
        eval_model = model.module
    else:
        eval_model = model
    eval_model.generator.reset_length(
        tgt_len=cfg.EVALUATE.tgt_length, mem_len=cfg.EVALUATE.mem_length)
    # Use the same attention length for every position during evaluation.
    eval_model.generator.same_length = True
    # Evaluation loop: accumulate token counts and total NLL.
    total_token_num = 0
    total_nll = 0.0
    with torch.no_grad():
        mems = None
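        # `mems` carries the Transformer-XL-style segment memory across
        # batches; the iterator signals when it must be cleared.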
        for i, (data, target, all_reset_mem, batch_token_num, status_vec) in enumerate(eval_iter()):
            if all_reset_mem:
                mems = None
            ret = model(data, target, None, "mle", mems, status_vec=status_vec)
            loss, mems = ret["mle"], ret["mems"]
            # Average the NLL over non-padding positions only; `batch_token_num`
            # is taken to be the matching non-padding token count.
            loss = loss[target != dataset.vocab.pad_id]
            loss = loss.mean()
            total_nll += batch_token_num * loss.float().item()
            total_token_num += batch_token_num
        # Generation metrics: compare samples against the reference corpus
        # (validation text in "eval" mode, test text otherwise).
        gen_tokens = None
        if cfg.METRICS.use_bleu:
            gen_tokens = tensor_to_tokens(
                generate_tokens(625, temperature).transpose(0, 1)
            )
            if mode == "eval":
                all_text = [el.tolist() for el in dataset._valid_data]
            else:
                all_text = [el.tolist() for el in dataset._test_data]
            bleu.reset(test_text=gen_tokens, real_text=all_text)
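        # Self-BLEU scores one pool of generated samples against another;
        # lower values indicate more diverse generations.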
        if cfg.METRICS.use_self_bleu and mode == "eval":
            if not cfg.METRICS.use_bleu:
                gen_tokens = tensor_to_tokens(
                    generate_tokens(625, temperature).transpose(0, 1)
                )
            gen_tokens_s = tensor_to_tokens(
                generate_tokens(2500, temperature).transpose(0, 1)
            )
            self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
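        # The classifier metric scores generated samples against validation
        # text with a separately configured classifier (see cfg.METRICS.CLASSIFIER).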
        if cfg.METRICS.CLASSIFIER.use_classifier and mode == "eval":
            gen_tokens = generate_tokens(
                cfg.METRICS.CLASSIFIER.gen_num_samples, temperature,
                batch_size=cfg.METRICS.CLASSIFIER.gen_batch_size,
                seq_len=cfg.METRICS.CLASSIFIER.gen_seq_len,
            ).transpose(0, 1)
            all_text = list(dataset._valid_data)
            classifier.reset(test_text=gen_tokens, real_text=all_text)
        if mode == "eval":
            results = [metric.get_score() for metric in eval_metrics]
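            # Drop the reference to the generated samples so the (possibly
            # large) tensors can be freed before training resumes.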
            if gen_tokens is not None:
                del gen_tokens
        else:
            results = [metric.get_score() for metric in test_metrics]
    # Restore the training-time context lengths and switch back to train mode.
    eval_model.generator.reset_length(cfg.TRAIN.tgt_length, cfg.TRAIN.mem_length)
    eval_model.generator.same_length = cfg.MODEL.same_length
    model.train()
    return total_token_num, total_nll, results
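
# Usage sketch (illustrative, not from the source): compute validation
# perplexity from the returned totals, assuming `valid_iter` is an
# evaluation iterator compatible with the loop above.
#
#     import math
#     token_num, nll, results = evaluate(valid_iter, mode="eval")
#     val_ppl = math.exp(nll / token_num)  # ppl = exp(mean per-token NLL)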