def forward()

in minihack/agent/polybeast/models/intrinsic.py

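Forward pass of `RNDNet`, the Random Network Distillation (RND) intrinsic-reward model. It encodes the per-step feature vector, a crop of the glyph map around the agent's coordinates, the full glyph map, and (optionally) the in-game message, then produces policy logits, an extrinsic baseline, a sampled action, the RND target and predictor embeddings, and an intrinsic-reward baseline. When called with `learning=False`, it falls back to the parent class's forward pass and skips the RND computation entirely.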

    def forward(self, inputs, core_state, learning=False):
        if not learning:
            # no need to compute the RND outputs when not in a learning step
            return super(RNDNet, self).forward(inputs, core_state, learning)
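        # T: unroll (time) length, B: batch size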
        T, B, *_ = inputs["glyphs"].shape

        glyphs, features = self.prepare_input(inputs)

        # -- [B x 2] x,y coordinates
        coordinates = features[:, :2]

        features = features.view(T * B, -1).float()
        # -- [B x K]
        features_emb = self.embed_features(features)

        assert features_emb.shape[0] == T * B

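        # Accumulate per-modality representations; they are concatenated and
        # projected through self.fc below.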
        reps = [features_emb]

        # -- [B x H' x W']
        crop = self.glyph_embedding.GlyphTuple(
            *[self.crop(g, coordinates) for g in glyphs]
        )
        # -- [B x H' x W' x K]
        crop_emb = self.glyph_embedding(crop)

        if self.crop_model == "transformer":
            # -- [B x W' x H' x K]
            crop_rep = self.extract_crop_representation(crop_emb, mask=None)
        elif self.crop_model == "cnn":
            # -- [B x K x W' x H']
            crop_emb = crop_emb.transpose(1, 3)
            # -- [B x W' x H' x K]
            crop_rep = self.extract_crop_representation(crop_emb)
        # -- [B x K']
        crop_rep = crop_rep.view(T * B, -1)
        assert crop_rep.shape[0] == T * B

        reps.append(crop_rep)

        # -- [B x H x W x K]
        glyphs_emb = self.glyph_embedding(glyphs)
        # -- [B x K x W x H]
        glyphs_emb = glyphs_emb.transpose(1, 3)
        # -- [B x W x H x K]
        glyphs_rep = self.extract_representation(glyphs_emb)

        # -- [B x K']
        glyphs_rep = glyphs_rep.view(T * B, -1)
        if self.equalize_input_dim:
            glyphs_rep = self.project_glyph_dim(glyphs_rep)

        assert glyphs_rep.shape[0] == T * B

        # -- [B x K'']
        reps.append(glyphs_rep)

        # MESSAGING MODEL
        if self.msg_model != "none":
            # [T x B x 256] -> [T * B x 256]
            messages = inputs["message"].long().view(T * B, -1)
            if self.msg_model == "cnn":
                # convert messages to one-hot, [T * B x 96 x 256]
                one_hot = F.one_hot(messages, num_classes=NUM_CHARS).transpose(
                    1, 2
                )
                char_rep = self.conv2_6_fc(self.conv1(one_hot.float()))
            elif self.msg_model == "lt_cnn":
                # [ T * B x E x 256 ]
                char_emb = self.char_lt(messages).transpose(1, 2)
                char_rep = self.conv2_6_fc(self.conv1(char_emb))
            else:  # lstm, gru
                char_emb = self.char_lt(messages)
                output = self.char_rnn(char_emb)[0]
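                # bidirectional RNN: take the last step of the forward
                # direction and the first step of the backward direction
                # (h_dim // 2 features each)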
                fwd_rep = output[:, -1, : self.h_dim // 2]
                bwd_rep = output[:, 0, self.h_dim // 2 :]
                char_rep = torch.cat([fwd_rep, bwd_rep], dim=1)

            if self.equalize_input_dim:
                char_rep = self.project_msg_dim(char_rep)
            reps.append(char_rep)

        st = torch.cat(reps, dim=1)

        # -- [B x K]
        st = self.fc(st)

        # TARGET NETWORK
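        # Computed under no_grad: no gradients flow into the target encoder,
        # which serves as the fixed embedding the predictor below must match.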
        with torch.no_grad():
            if self.intrinsic_input == "crop_only":
                tgt_crop_emb = self.rndtgt_embed(crop).transpose(1, 3)
                tgt_crop_rep = self.rndtgt_extract_crop_representation(
                    tgt_crop_emb
                )
                tgt_st = self.rndtgt_fc(tgt_crop_rep.view(T * B, -1))
            elif self.intrinsic_input == "glyph_only":
                tgt_glyphs_emb = self.rndtgt_embed(glyphs).transpose(1, 3)
                tgt_glyphs_rep = self.rndtgt_extract_representation(
                    tgt_glyphs_emb
                )
                tgt_st = self.rndtgt_fc(tgt_glyphs_rep.view(T * B, -1))
            else:  # full
                tgt_reps = []
                tgt_feats = self.rndtgt_embed_features(features)
                tgt_reps.append(tgt_feats)

                tgt_crop_emb = self.rndtgt_embed(crop).transpose(1, 3)
                tgt_crop_rep = self.rndtgt_extract_crop_representation(
                    tgt_crop_emb
                )
                tgt_reps.append(tgt_crop_rep.view(T * B, -1))

                tgt_glyphs_emb = self.rndtgt_embed(glyphs).transpose(1, 3)
                tgt_glyphs_rep = self.rndtgt_extract_representation(
                    tgt_glyphs_emb
                )
                tgt_reps.append(tgt_glyphs_rep.view(T * B, -1))

                if self.msg_model != "none":
                    tgt_char_emb = self.rndtgt_char_lt(messages).transpose(
                        1, 2
                    )
                    # use the target network's own first conv layer so the
                    # target embedding does not depend on predictor weights
                    tgt_char_rep = self.rndtgt_conv2_6_fc(
                        self.rndtgt_conv1(tgt_char_emb)
                    )
                    tgt_reps.append(tgt_char_rep)

                tgt_st = self.rndtgt_fc(torch.cat(tgt_reps, dim=1))

        # PREDICTOR NETWORK
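        # The predictor mirrors the target architecture; the mismatch between
        # its output and the target embedding is what drives the RND
        # intrinsic reward.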
        if self.intrinsic_input == "crop_only":
            prd_crop_emb = self.rndprd_embed(crop).transpose(1, 3)
            prd_crop_rep = self.rndprd_extract_crop_representation(
                prd_crop_emb
            )
            prd_st = self.rndprd_fc(prd_crop_rep.view(T * B, -1))
        elif self.intrinsic_input == "glyph_only":
            prd_glyphs_emb = self.rndprd_embed(glyphs).transpose(1, 3)
            prd_glyphs_rep = self.rndprd_extract_representation(prd_glyphs_emb)
            prd_st = self.rndprd_fc(prd_glyphs_rep.view(T * B, -1))
        else:  # full
            prd_reps = []
            prd_feats = self.rndprd_embed_features(features)
            prd_reps.append(prd_feats)

            prd_crop_emb = self.rndprd_embed(crop).transpose(1, 3)
            prd_crop_rep = self.rndprd_extract_crop_representation(
                prd_crop_emb
            )
            prd_reps.append(prd_crop_rep.view(T * B, -1))

            prd_glyphs_emb = self.rndprd_embed(glyphs).transpose(1, 3)
            prd_glyphs_rep = self.rndprd_extract_representation(prd_glyphs_emb)
            prd_reps.append(prd_glyphs_rep.view(T * B, -1))

            if self.msg_model != "none":
                prd_char_emb = self.rndprd_char_lt(messages).transpose(1, 2)
                prd_char_rep = self.rndprd_conv2_6_fc(
                    self.rndprd_conv1(prd_char_emb)
                )
                prd_reps.append(prd_char_rep)

            prd_st = self.rndprd_fc(torch.cat(prd_reps, dim=1))

        assert tgt_st.size() == prd_st.size()

        if self.use_lstm:
            core_input = st.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for input, nd in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                core_state = tuple(nd * t for t in core_state)
                output, core_state = self.core(input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = st

        # -- [B x A]
        policy_logits = self.policy(core_output)
        # -- [B x 1]
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(
                F.softmax(policy_logits, dim=1), num_samples=1
            )
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

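        # `target` and `predicted` feed the RND prediction-error loss, while
        # `int_baseline` is a separate value head for the intrinsic return.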
        output = dict(
            policy_logits=policy_logits,
            baseline=baseline,
            action=action,
            target=tgt_st.view(T, B, -1),
            predicted=prd_st.view(T, B, -1),
            int_baseline=self.int_baseline(core_output).view(T, B),
        )
        return (output, core_state)
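
For context, the `target` and `predicted` tensors returned above are the two RND embeddings the learner compares to form the intrinsic reward. Below is a minimal sketch of that comparison, assuming the standard RND formulation (mean squared error between the two embeddings per time step); the helper name `rnd_intrinsic_reward` is illustrative and not part of the repository.

    import torch


    def rnd_intrinsic_reward(output):
        # output["target"] and output["predicted"] are [T x B x K] embeddings
        # from the forward pass above. The target was computed under no_grad,
        # so only the predictor receives gradients from this error.
        target = output["target"].detach()
        predicted = output["predicted"]
        # -- [T x B] per-step intrinsic reward
        return ((predicted - target) ** 2).mean(dim=-1)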