in flsim/utils/sample_model.py [0:0]
def fl_forward(self, batch) -> FLBatchMetrics:
    """Run one forward pass of the sample model over ``batch``.

    Indexes ``self.dummy_embedding`` with the batch's text column, feeds the
    result through ``self.model`` and computes a scaled negative
    log-likelihood loss against the batch's labels.

    Args:
        batch: Mapping providing ``TestDataSetting.TEXT_COL_NAME`` (used to
            index rows of ``self.dummy_embedding``) and
            ``TestDataSetting.LABEL_COL_NAME`` (a label tensor, flattened and
            cast to ``torch.long`` for the loss).

    Returns:
        FLBatchMetrics with the scaled loss, the example count from
        ``self.get_num_examples(batch)``, the model output as predictions,
        the (unflattened) labels as targets, and the embedded inputs.
    """
    text = batch[TestDataSetting.TEXT_COL_NAME]
    batch_label = batch[TestDataSetting.LABEL_COL_NAME]
    # Flatten and cast labels for nll_loss. Using ``.long()`` instead of
    # ``torch.tensor(<tensor>)`` avoids PyTorch's copy-construct UserWarning
    # and an unnecessary copy.
    stacked_label = batch_label.view(-1).long()
    text_embeddings = self.dummy_embedding[text, :]
    if self.use_cuda_if_available:
        text_embeddings = text_embeddings.cuda()
    out = self.model(text_embeddings)
    if self.use_cuda_if_available:
        # Keep predictions and both label views on the same device as the
        # inputs; ``.cuda()`` is a no-op for tensors already on the GPU.
        out = out.cuda()
        batch_label = batch_label.cuda()
        stacked_label = stacked_label.cuda()
    # NOTE(review): F.nll_loss expects log-probabilities, so this assumes
    # self.model ends in a log-softmax — confirm against the model definition.
    loss = F.nll_loss(out, stacked_label)
    # Produce a large loss, so gradients are large; this prevents unit tests
    # from failing because of numerical issues. Scaled out-of-place so it is
    # a regular autograd op rather than an in-place edit of a graph output.
    loss = loss * 100.0
    num_examples = self.get_num_examples(batch)
    return FLBatchMetrics(
        loss=loss,
        num_examples=num_examples,
        predictions=out,
        targets=batch_label,
        model_inputs=text_embeddings,
    )