def compute_softmax()

in privacy_lint/attacks/shadow.py [0:0]


def compute_softmax(
    model: nn.Module,
    dataloader: DataLoader,
    device: "torch.device | str | None" = None,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Run ``model`` over ``dataloader`` and collect softmax outputs and labels.

    Args:
        model: Network applied to each input batch. It is moved to ``device``
            (the move is left in place after returning) and temporarily put in
            eval mode; its previous train/eval mode is restored on exit.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs.
        device: Device to run inference on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        A ``(softmaxes, labels)`` pair. ``softmaxes`` is the softmax over the
        last dimension of each batch's model output, concatenated along dim 0
        and moved to CPU. ``labels`` is the targets as yielded by the loader,
        concatenated along dim 0.

    Raises:
        RuntimeError: raised by ``torch.cat`` if ``dataloader`` yields no batches.
    """
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    softmaxes, labels = [], []
    was_training = model.training
    # Eval mode so dropout / batch-norm layers give deterministic inference outputs.
    model.eval()
    try:
        # no_grad: without it every stored output keeps its batch's autograd
        # graph alive, so memory grows with the size of the dataset.
        with torch.no_grad():
            for img, target in tqdm(dataloader):
                outputs = F.softmax(model(img.to(device)), dim=-1)
                softmaxes.append(outputs.cpu())
                labels.append(target)
    finally:
        # Restore the caller's train/eval mode even if inference fails.
        model.train(was_training)

    return torch.cat(softmaxes, dim=0), torch.cat(labels, dim=0)