in curiosity/baseline_models.py [0:0]
def forward(self,
            messages: Dict[str, torch.Tensor],
            # (batch_size, n_turns, n_facts, n_words)
            facts: Dict[str, torch.Tensor],
            # (batch_size, n_turns)
            senders: torch.Tensor,
            # (batch_size, n_turns, n_acts)
            dialog_acts: torch.Tensor,
            # (batch_size, n_turns)
            dialog_acts_mask: torch.Tensor,
            # (batch_size, n_entities)
            known_entities: Dict[str, torch.Tensor],
            # (batch_size, 1)
            focus_entity: Dict[str, torch.Tensor],
            # (batch_size, n_turns, n_facts)
            fact_labels: Optional[torch.Tensor] = None,
            # (batch_size, n_turns, 2)
            # NOTE(review): likes is multiplied by a (batch_size, n_turns)
            # mask and used as the targets of a sequence cross entropy below,
            # which suggests (batch_size, n_turns) -- TODO confirm
            likes: Optional[torch.Tensor] = None,
            metadata: Optional[Dict] = None):
    """Encode a dialog and compute the enabled task outputs and losses.

    Every utterance is encoded, a dialog-level context is built on top of
    the utterance encodings, and the context is shifted right by one turn
    so that turn ``t`` is paired with the context of turn ``t - 1``.
    Depending on the ``_disable_*`` flags the method then computes the
    dialog act and policy losses, ranks facts, and classifies likes.

    Args:
        messages: Text field tensors for the utterances of each dialog.
        facts: Text field tensors for the candidate facts of each turn.
        senders: Not read by this method.
        dialog_acts: Dialog act labels, forwarded to the dialog act and
            policy loss helpers.
        dialog_acts_mask: Mask forwarded together with ``dialog_acts``.
        known_entities: Not read by this method.
        focus_entity: Not read by this method.
        fact_labels: Optional gold labels for fact ranking. When ``None``
            the fact logits are still returned but no fact loss or fact
            metrics are computed.
        likes: Optional gold like labels. When ``None`` the like logits
            are still returned but no like loss or like metrics are
            computed.
        metadata: Not read by this method.

    Returns:
        A dict that contains ``'fact_logits'`` unless facts are disabled,
        ``'like_logits'`` unless likes are disabled, and ``'loss'`` when
        at least one loss term was computed. The dict is also handed to
        the dialog act / policy loss helpers, which presumably add their
        own entries -- verify against those helpers.
    """
    output = {}

    # Take care of the easy stuff first
    if self._use_bert:
        # Original note: (batch_size, n_turns, n_words, emb_dim)
        # NOTE(review): the shift below needs a 3-d dialog context, so this
        # is presumably already pooled to (batch_size, n_turns, dim) --
        # TODO confirm against the bert encoder
        context, utter_mask = self._bert_encoder(messages)
        context = self._dropout(context)
    else:
        # (batch_size, n_turns)
        # This is the mask since not all dialogs have same number
        # of turns
        utter_mask = get_text_field_mask(messages)

        # (batch_size, n_turns, n_words)
        # Mask since not all utterances have same number of words
        # Wrapping dim skips over n_messages dim
        text_mask = get_text_field_mask(messages, num_wrapping_dims=1)

        # (batch_size, n_turns, n_words, emb_dim)
        embed = self._dropout(self._utter_embedder(messages))

        # (batch_size, n_turns, hidden_dim)
        context = self._dist_utter_context(embed, text_mask)

    # (batch_size, n_turns, hidden_dim)
    # This assumes dialog_context does not peek into future.
    # Note that the dialog encoder is called without utter_mask.
    dialog_context = self._dialog_context(context)

    # shift context one right, pad with zeros at front
    # This makes it so that utter_t is paired with context_t-1
    # which is what we want
    # This is useful in a few different places, so compute it here once
    context_shape = dialog_context.shape
    shifted_context = torch.cat((
        dialog_context.new_zeros([context_shape[0], 1, context_shape[2]]),
        dialog_context[:, :-1, :]
    ), dim=1)

    has_loss = False

    if self._disable_dialog_acts:
        da_loss = 0
        policy_loss = 0
    else:
        # Dialog act per utter loss
        has_loss = True
        da_loss = self._compute_da_loss(
            output,
            context, shifted_context, utter_mask,
            dialog_acts, dialog_acts_mask
        )
        # Policy loss
        policy_loss = self._compute_policy_loss(
            output,
            shifted_context, utter_mask,
            dialog_acts, dialog_acts_mask
        )

    if self._disable_facts:
        # If facts are disabled, don't output anything related
        # to them
        fact_loss = 0
    else:
        if self._use_bert:
            # Original note: (batch_size, n_turns, n_words, emb_dim)
            fact_repr, fact_mask = self._bert_encoder(facts)
            fact_repr = self._dropout(fact_repr)
        else:
            # (batch_size, n_turns, n_facts)
            # Wrapping dim skips over n_messages
            fact_mask = get_text_field_mask(facts, num_wrapping_dims=1)

            # (batch_size, n_turns, n_facts, n_words)
            # Wrapping dim skips over n_turns and n_facts
            fact_text_mask = get_text_field_mask(
                facts, num_wrapping_dims=2
            )

            # (batch_size, n_turns, n_facts, n_words, emb_dim)
            # Share encoder with utter encoder
            fact_embed = self._dropout(self._utter_embedder(facts))
            fact_embed_shape = fact_embed.shape
            word_dim = fact_embed_shape[-2]
            emb_dim = fact_embed_shape[-1]

            # Flatten every leading dimension so that the encoder sees a
            # plain (n_sequences, n_words, emb_dim) batch
            reshaped_facts = fact_embed.view(-1, word_dim, emb_dim)
            reshaped_fact_text_mask = fact_text_mask.view(-1, word_dim)
            reshaped_fact_repr = self._utter_context(
                reshaped_facts, reshaped_fact_text_mask
            )

            # No more emb dimension or word/seq dim
            fact_repr = reshaped_fact_repr.view(
                fact_embed_shape[:-2] + (-1,)
            )

        # In addition to masking padded facts, also explicitly mask
        # user turns (the even positions) just in case
        fact_mask[:, ::2] = 0

        fact_logits = self._fact_ranker(
            shifted_context,
            fact_repr,
        )
        output['fact_logits'] = fact_logits

        if fact_labels is not None:
            has_loss = True
            fact_loss = self._compute_fact_loss(
                fact_logits, fact_labels, fact_mask
            )
            self._fact_loss_metric(fact_loss.item())
            self._fact_mrr(fact_logits, fact_labels, mask=fact_mask)
        else:
            fact_loss = 0

    if self._disable_likes:
        like_loss = 0
    else:
        # (batch_size, n_turns, 2)
        like_logits = self._like_classifier(dialog_context)
        output['like_logits'] = like_logits

        # Only touch `likes` once we know labels were supplied; without
        # labels (e.g. at prediction time) just return the logits and do
        # not claim that a loss was computed.
        if likes is not None:
            has_loss = True

            # There are several masks here to get the loss/metrics correct
            # - utter_mask: mask out positions that do not have an utterance
            # - user_mask: mask out positions that have a user utterances
            #              since their turns are never liked
            # Using new_ones() preserves the type of the tensor
            user_mask = utter_mask.new_ones(utter_mask.shape)

            # Since the user is always even, this masks out user positions
            user_mask[:, ::2] = 0
            final_mask = utter_mask * user_mask
            masked_likes = likes * final_mask

            like_loss = sequence_cross_entropy_with_logits(
                like_logits, masked_likes, final_mask
            )
            self._like_accuracy(like_logits, masked_likes, final_mask)
            self._like_loss_metric(like_loss.item())
        else:
            like_loss = 0

    if has_loss:
        output['loss'] = (
            self._fact_loss_weight * fact_loss
            + like_loss
            + da_loss + policy_loss
        )

    return output