in mmf/models/butd.py [0:0]
def forward(self, sample_list):
    # Stores the output probabilities.
    scores = sample_list.answers.new_ones(
        (
            sample_list.answers.size(0),
            self.text_processor.max_length,
            self.vocab_size,
        ),
        dtype=torch.float,
    )
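
    # For beam search or nucleus sampling, token selection is delegated to a
    # decoder fetched from the registry; otherwise the per-step scores are
    # written directly into `scores` below.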
    if self.config["inference"]["type"] in ["beam_search", "nucleus_sampling"]:
        decoder = registry.get_decoder_class(self.config["inference"]["type"])(
            self.vocab, self.config
        )
        sample_list = decoder.init_batch(sample_list)
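
    # Set up decoding inputs: `prepare_data` returns the per-step data and
    # the number of timesteps to run.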
    batch_size = sample_list.image_feature_0.size(0)
    data, sample_list, timesteps = self.prepare_data(sample_list, batch_size)
    output = None
    batch_size_t = batch_size
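
    # Decode step by step; `batch_size_t` tracks the sequences still active
    # at step t and may shrink as sequences finish.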
    for t in range(timesteps):
        data, batch_size_t = self.get_data_t(t, data, batch_size_t, output)
        if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
            pi_t = data["texts"]
        else:
            pi_t = data["texts"][:, t].unsqueeze(-1)
        embedding = self.word_embedding(pi_t)
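        # Attend over the image features, conditioned on the token embedding.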
        attention_feature, _ = self.process_feature_embedding(
            "image", sample_list, embedding[:, 0, :], batch_size_t=batch_size_t
        )
        output = self.classifier(attention_feature)
        # Decode the current step: sampling-based decoders consume the
        # scores and may signal completion; otherwise record the scores.
        if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
            finish, data, batch_size_t = decoder.decode(t, data, output)
            if finish:
                break
        else:
            scores[:batch_size_t, t] = output
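
    # Assemble the output: generated captions (with a dummy loss) for
    # sampling-based inference, raw scores otherwise.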
    model_output = {}
    if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
        results = decoder.get_result()
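        # Pad the generated sequences to `max_length` so the caption tensor
        # has a fixed shape.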
        results = torch.nn.functional.pad(
            results,
            (0, self.text_processor.max_length - results.size()[-1]),
            "constant",
            0,
        )
        model_output["captions"] = results
        model_output["losses"] = {}
        loss_key = "{}/{}".format(
            sample_list.dataset_name, sample_list.dataset_type
        )
        # Add a dummy loss so that loss calculation is not required
        model_output["losses"][loss_key + "/dummy_loss"] = torch.zeros(
            batch_size, device=sample_list.answers.device
        )
    else:
        model_output["scores"] = scores

    return model_output