in mm_action_prediction/models/assistant.py [0:0]
def forward(self, batch, mode=None):
    """Forward propagation.

    Args:
        batch: Dict of batch input variables.
        mode: None for training or teacher-forcing evaluation;
            BEAMSEARCH / SAMPLE / MAX to generate text.
    """
    outputs = self.encoder(batch)
    action_output = self.action_executor(batch, outputs)
    outputs.update(action_output)
    decoder_output = self.decoder(batch, outputs)
    if mode:
        generation_output = self.decoder.forward_beamsearch_multiple(
            batch, outputs, mode
        )
        outputs.update(generation_output)
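    # After this point, `outputs` also carries the generated responses
    # whenever a decoding mode was requested.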
    # If evaluating by retrieval, construct a fake batch for each candidate.
    # Inputs from batch used in the decoder:
    #   assist_in, assist_out, assist_in_len, assist_mask
    if self.params["retrieval_evaluation"] and not self.training:
        option_scores = []
        batch_size, num_rounds, num_candidates, _ = batch["candidate_in"].shape
        replace_keys = ("assist_in", "assist_out", "assist_in_len", "assist_mask")
        for ii in range(num_candidates):
            # Swap the ii-th candidate in for the ground-truth response.
            for key in replace_keys:
                new_key = key.replace("assist", "candidate")
                batch[key] = batch[new_key][:, :, ii]
            # Score the candidate under teacher forcing; use a separate name so
            # the ground-truth decoder_output above is not overwritten.
            candidate_output = self.decoder(batch, outputs)
            log_probs = torch_support.unflatten(
                candidate_output["loss_token"], batch_size, num_rounds
            )
            option_scores.append(-1 * log_probs.sum(-1))
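            # Each appended entry is (batch_size, num_rounds): the negated sum
            # of token losses, so higher scores mean more likely candidates.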
        option_scores = torch.stack(option_scores, 2)
        outputs["candidate_scores"] = [
            {
                "dialog_id": batch["dialog_id"][ii].item(),
                "candidate_scores": [
                    {
                        "scores": [
                            float(kk) for kk in option_scores[ii, jj].cpu()
                        ],
                        "turn_id": jj,
                    }
                    for jj in range(batch["dialog_len"][ii])
                ],
            }
            for ii in range(batch_size)
        ]
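        # The structure above is JSON-serializable:
        #   [{"dialog_id": ..., "candidate_scores":
        #       [{"scores": [...], "turn_id": 0}, ...]}, ...]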
    # Local aliases, taken from the teacher-forcing pass on the ground truth.
    loss_token = decoder_output["loss_token"]
    pad_mask = decoder_output["pad_mask"]
    if self.training:
        # Mean token loss over non-pad positions.
        loss_token = loss_token.sum() / (~pad_mask).sum().item()
        loss_action = action_output["action_loss"]
        loss_action_attr = action_output["action_attr_loss"]
        loss_total = loss_action + loss_token + loss_action_attr
        return {
            "token": loss_token,
            "action": loss_action,
            "action_attr": loss_action_attr,
            "total": loss_total,
        }
    else:
        outputs.update(
            {"loss_sum": loss_token.sum(), "num_tokens": (~pad_mask).sum()}
        )
        return outputs
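
# --- Usage sketch (illustration only; not part of the original file) ---
# The retrieval-evaluation branch above turns per-token decoder losses into
# per-round candidate scores. The snippet below reproduces that arithmetic in
# isolation with random stand-in tensors; the local `unflatten` is assumed to
# behave like torch_support.unflatten (restore the flattened batch/round dims),
# and all shapes and values are made up.
import torch


def unflatten(flat, batch_size, num_rounds):
    """Reshape (batch_size * num_rounds, ...) to (batch_size, num_rounds, ...)."""
    return flat.view(batch_size, num_rounds, *flat.shape[1:])


if __name__ == "__main__":
    batch_size, num_rounds, num_candidates, max_len = 2, 3, 4, 10
    option_scores = []
    for _ in range(num_candidates):
        # Stand-in for candidate_output["loss_token"]: per-token NLL,
        # flattened over (batch, round) pairs.
        loss_token = torch.rand(batch_size * num_rounds, max_len)
        log_probs = unflatten(loss_token, batch_size, num_rounds)
        option_scores.append(-1 * log_probs.sum(-1))  # higher = more likely
    option_scores = torch.stack(option_scores, 2)  # (batch, rounds, candidates)
    retrieved = option_scores.argmax(-1)  # best candidate index per round
    print(option_scores.shape, retrieved.shape)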