def forward()

in curiosity/baseline_models.py


    def forward(self,
                messages: Dict[str, torch.Tensor],
                # (batch_size, n_turns, n_facts, n_words)
                facts: Dict[str, torch.Tensor],
                # (batch_size, n_turns)
                senders: torch.Tensor,
                # (batch_size, n_turns, n_acts)
                dialog_acts: torch.Tensor,
                # (batch_size, n_turns)
                dialog_acts_mask: torch.Tensor,
                # (batch_size, n_entities)
                known_entities: Dict[str, torch.Tensor],
                # (batch_size, 1)
                focus_entity: Dict[str, torch.Tensor],
                # (batch_size, n_turns, n_facts)
                fact_labels: Optional[torch.Tensor] = None,
                # (batch_size, n_turns), values in {0, 1}
                likes: Optional[torch.Tensor] = None,
                metadata: Optional[Dict] = None):
        output = {}
        # Take care of the easy stuff first

        if self._use_bert:
            # (batch_size, n_turns, n_words, emb_dim)
            context, utter_mask = self._bert_encoder(messages)
            context = self._dropout(context)
        else:
            # (batch_size, n_turns)
            # This is the mask since not all dialogs have the same
            # number of turns
            utter_mask = get_text_field_mask(messages)

            # (batch_size, n_turns, n_words)
            # Mask since not all utterances have the same number of words
            # Wrapping dim skips over n_messages dim
            text_mask = get_text_field_mask(messages, num_wrapping_dims=1)
            # (batch_size, n_turns, n_words, emb_dim)
            embed = self._dropout(self._utter_embedder(messages))
            # (batch_size, n_turns, hidden_dim)
            context = self._dist_utter_context(embed, text_mask)

        # (batch_size, n_turns, hidden_dim)
        # n_turns = context.shape[1]
        # This assumes the dialog context encoder does not peek into
        # future turns; an alternative is to pass the utterance mask:
        # dialog_context = self._dialog_context(context, utter_mask)
        dialog_context = self._dialog_context(context)

        # shift context one right, pad with zeros at front
        # This makes it so that utter_t is paired with context_t-1
        # which is what we want
        # This is useful in a few different places, so compute it here once
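        # e.g. [c_0, c_1, c_2, c_3] -> [0, c_0, c_1, c_2]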
        shape = dialog_context.shape
        shifted_context = torch.cat((
            dialog_context.new_zeros([shape[0], 1, shape[2]]),
            dialog_context[:, :-1, :]
        ), dim=1)
        has_loss = False

        if self._disable_dialog_acts:
            da_loss = 0
            policy_loss = 0
        else:
            # Dialog act per utter loss
            has_loss = True
            da_loss = self._compute_da_loss(
                output,
                context, shifted_context, utter_mask,
                dialog_acts, dialog_acts_mask
            )
            # Policy loss
            policy_loss = self._compute_policy_loss(
                output,
                shifted_context, utter_mask,
                dialog_acts, dialog_acts_mask
            )

        if self._disable_facts:
            # If facts are disabled, don't output anything related
            # to them
            fact_loss = 0
        else:
            if self._use_bert:
                # (batch_size, n_turns, n_facts, emb_dim)
                fact_repr, fact_mask = self._bert_encoder(facts)
                fact_repr = self._dropout(fact_repr)
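                # As in the non-BERT branch below, explicitly mask user
                # turns just in case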
                fact_mask[:, ::2] = 0
            else:
                # (batch_size, n_turns, n_facts)
                # Wrapping dim skips over n_messages
                fact_mask = get_text_field_mask(facts, num_wrapping_dims=1)
                # In addition to masking padded facts, also explicitly mask
                # user turns just in case
                fact_mask[:, ::2] = 0

                # (batch_size, n_turns, n_facts, n_words)
                # Wrapping dim skips over n_turns and n_facts
                fact_text_mask = get_text_field_mask(facts, num_wrapping_dims=2)
                # (batch_size, n_turns, n_facts, n_words, emb_dim)
                # Shared with the utterance embedder (self._utter_embedder);
                # the n_turns and n_facts dims have to be flattened before
                # encoding and restored afterwards
                fact_embed = self._dropout(self._utter_embedder(facts))
                shape = fact_embed.shape
                word_dim = shape[-2]
                emb_dim = shape[-1]
                reshaped_facts = fact_embed.view(-1, word_dim, emb_dim)
                reshaped_fact_text_mask = fact_text_mask.view(-1, word_dim)
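                # reshaped_facts: (batch_size * n_turns * n_facts, n_words, emb_dim)
                # reshaped_fact_text_mask: (batch_size * n_turns * n_facts, n_words)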
                reshaped_fact_repr = self._utter_context(
                    reshaped_facts, reshaped_fact_text_mask
                )
                # No more emb dimension or word/seq dim
                fact_repr = reshaped_fact_repr.view(shape[:-2] + (-1,))

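            # Score each candidate fact against the shifted (previous-turn)
            # dialog context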
            fact_logits = self._fact_ranker(
                shifted_context,
                fact_repr,
            )
            output['fact_logits'] = fact_logits
            if fact_labels is not None:
                has_loss = True
                fact_loss = self._compute_fact_loss(
                    fact_logits, fact_labels, fact_mask
                )
                self._fact_loss_metric(fact_loss.item())
                self._fact_mrr(fact_logits, fact_labels, mask=fact_mask)
            else:
                fact_loss = 0

        if self._disable_likes:
            like_loss = 0
        else:
            # (batch_size, n_turns, 2)
            like_logits = self._like_classifier(dialog_context)
            output['like_logits'] = like_logits

            # There are several masks here to get the loss/metrics correct
            # - utter_mask: mask out positions that do not have an utterance
            # - user_mask: mask out positions that have a user utterance
            #              since their turns are never liked
            # Using new_ones() preserves the type of the tensor
            user_mask = utter_mask.new_ones(utter_mask.shape)

            # Since user turns are always at even indices, this masks them out
            user_mask[:, ::2] = 0
            final_mask = utter_mask * user_mask
            if likes is not None:
                has_loss = True
                masked_likes = likes * final_mask
                like_loss = sequence_cross_entropy_with_logits(
                    like_logits, masked_likes, final_mask
                )
                self._like_accuracy(like_logits, masked_likes, final_mask)
                self._like_loss_metric(like_loss.item())
            else:
                like_loss = 0

        if has_loss:
            output['loss'] = (
                self._fact_loss_weight * fact_loss
                + like_loss
                + da_loss + policy_loss
            )

        return output
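
The two tensor manipulations that recur in this method can be checked in isolation: shifting the dialog context right by one turn so that turn t is scored against the context through turn t - 1, and zeroing even-indexed (user) turns in a mask. The sketch below uses toy shapes, depends only on torch, and is not part of curiosity/baseline_models.py.

import torch

batch_size, n_turns, hidden_dim = 2, 4, 3

# Stand-in for dialog_context: (batch_size, n_turns, hidden_dim)
dialog_context = torch.arange(
    batch_size * n_turns * hidden_dim, dtype=torch.float
).view(batch_size, n_turns, hidden_dim)

# Shift right by one turn, padding with zeros at the front, exactly as in
# forward(): position t now holds the context up to turn t - 1
shape = dialog_context.shape
shifted_context = torch.cat((
    dialog_context.new_zeros([shape[0], 1, shape[2]]),
    dialog_context[:, :-1, :]
), dim=1)
assert torch.equal(shifted_context[:, 0], torch.zeros(batch_size, hidden_dim))
assert torch.equal(shifted_context[:, 1:], dialog_context[:, :-1])

# Zero out even-indexed (user) turns in a mask, as done for fact_mask
# and user_mask above
utter_mask = torch.ones(batch_size, n_turns, dtype=torch.long)
user_mask = utter_mask.new_ones(utter_mask.shape)
user_mask[:, ::2] = 0
final_mask = utter_mask * user_mask
print(final_mask[0].tolist())  # [0, 1, 0, 1]

Turn 0 ends up with an all-zero context (nothing has been said yet), and only odd (assistant) positions survive the combined mask, which is why the fact and like losses in forward() are effectively computed on assistant turns only.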