in flsim/utils/example_utils.py [0:0]
def fl_forward(self, batch) -> FLBatchMetrics:
    """Run one forward pass over ``batch`` and package the results as metrics.

    Args:
        batch: Mapping with a ``"features"`` tensor (presumably shaped
            [B, C, 28, 28] -- TODO confirm against the data loader) and a
            ``"labels"`` tensor of integer class labels.

    Returns:
        FLBatchMetrics where ``loss`` is the cross-entropy between the model
        output and the flattened labels (not detached, so it can still be
        back-propagated), ``predictions`` / ``targets`` are detached CPU
        copies of the output and the flattened int64 labels,
        ``num_examples`` comes from ``self.get_num_examples(batch)``, and
        ``model_inputs`` is an empty list.
    """
    features = batch["features"]  # presumably [B, C, 28, 28]
    # Flatten labels to 1-D int64 as F.cross_entropy expects; clone/detach so
    # the returned targets never alias the caller's batch tensor (view() and
    # a no-op long() would otherwise share storage).
    stacked_label = batch["labels"].view(-1).long().clone().detach()
    if self.device is not None:
        features = features.to(self.device)
        # Labels travel with the inputs so the loss is computed on one device.
        # (The raw batch labels are not moved: they are never used again.)
        stacked_label = stacked_label.to(self.device)
    output = self.model(features)
    if self.device is not None:
        # Presumably a no-op when the model lives on self.device; kept so the
        # output is guaranteed to share the labels' device.
        output = output.to(self.device)
    loss = F.cross_entropy(output, stacked_label)
    num_examples = self.get_num_examples(batch)
    # Drop the local reference to the (possibly device-resident) input early.
    del features
    return FLBatchMetrics(
        loss=loss,
        num_examples=num_examples,
        predictions=output.detach().cpu(),
        targets=stacked_label.detach().cpu(),
        model_inputs=[],
    )