def forward()

in mmf/models/butd.py [0:0]


    def forward(self, sample_list):
        """Run the captioning model over ``sample_list`` one timestep at a time.

        The mode is chosen by ``config.inference.type``:

        * ``"beam_search"`` / ``"nucleus_sampling"``: a decoder class obtained
          from ``registry.get_decoder_class`` prepares the batch and consumes
          the classifier output at every step. Its result is right-padded with
          0 along the last dimension up to ``text_processor.max_length`` and
          returned under ``"captions"``, together with a zero-valued dummy loss.
        * any other value: the token at position ``t`` of ``data["texts"]`` is
          fed at step ``t`` (presumably teacher forcing — confirm against
          ``prepare_data``). Each classifier output is written into a float
          tensor of shape ``(answers.size(0), max_length, vocab_size)``, which
          is returned under ``"scores"``.

        Args:
            sample_list: batch container. This method reads ``answers`` and
                ``image_feature_0``, plus ``dataset_name`` / ``dataset_type``
                in decoder mode.

        Returns:
            dict: keys ``"captions"`` and ``"losses"`` in decoder mode,
            otherwise the single key ``"scores"``.
        """
        inference_type = self.config.inference.type
        use_decoder = inference_type in ["beam_search", "nucleus_sampling"]

        decoder = None
        scores = None
        if use_decoder:
            decoder = registry.get_decoder_class(inference_type)(
                self.vocab, self.config
            )
            # init_batch returns the sample_list used for the rest of the pass.
            sample_list = decoder.init_batch(sample_list)
        else:
            # Stores the output probabilities. They are only consumed in this
            # mode, so the (batch, max_length, vocab_size) buffer is no longer
            # allocated (and discarded) when a decoder drives generation.
            # Entries the loop below never writes keep their initial value 1.
            scores = sample_list.answers.new_ones(
                (
                    sample_list.answers.size(0),
                    self.text_processor.max_length,
                    self.vocab_size,
                ),
                dtype=torch.float,
            )

        batch_size = sample_list.image_feature_0.size(0)
        data, sample_list, timesteps = self.prepare_data(sample_list, batch_size)
        output = None
        batch_size_t = batch_size
        for t in range(timesteps):
            # get_data_t returns the data and the effective batch size for step t.
            data, batch_size_t = self.get_data_t(t, data, batch_size_t, output)
            if use_decoder:
                # In decoder mode data["texts"] is used as-is as this step's input.
                pi_t = data["texts"]
            else:
                pi_t = data["texts"][:, t].unsqueeze(-1)
            embedding = self.word_embedding(pi_t)
            attention_feature, _ = self.process_feature_embedding(
                "image", sample_list, embedding[:, 0, :], batch_size_t=batch_size_t
            )
            output = self.classifier(attention_feature)
            if use_decoder:
                # Compute decoding; the decoder reports when to stop early.
                finish, data, batch_size_t = decoder.decode(t, data, output)
                if finish:
                    break
            else:
                scores[:batch_size_t, t] = output

        model_output = {}
        if not use_decoder:
            model_output["scores"] = scores
            return model_output

        results = decoder.get_result()
        # Right-pad the last dimension with zeros up to max_length.
        results = torch.nn.functional.pad(
            results,
            (0, self.text_processor.max_length - results.size(-1)),
            "constant",
            0,
        )
        model_output["captions"] = results
        loss_key = f"{sample_list.dataset_name}/{sample_list.dataset_type}"
        # Add a dummy loss so that loss calculation is not required.
        model_output["losses"] = {
            loss_key + "/dummy_loss": torch.zeros(
                batch_size, device=sample_list.answers.device
            )
        }
        return model_output