sample_info/modules/sgd.py [46:60]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    for inputs_batch, labels_batch in tqdm(loader, desc='Computing sgd noise covariance...'):
        # Normalize inputs and labels to lists so single-input and
        # multi-input models share the same code path below.
        if isinstance(inputs_batch, torch.Tensor):
            inputs_batch = [inputs_batch]
        if not isinstance(labels_batch, list):
            labels_batch = [labels_batch]

        # Make sure autograd is recording, even if the caller disabled it.
        with torch.set_grad_enabled(True):
            outputs = model.forward(inputs=inputs_batch, labels=labels_batch, loader=loader, **kwargs)
            batch_losses, outputs = model.compute_loss(inputs=inputs_batch, labels=labels_batch, outputs=outputs,
                                                       loader=loader, dataset=loader.dataset)
            # The total batch loss is the sum of all named loss terms.
            batch_total_loss = sum(batch_losses.values())

        # Gradient of the batch loss with respect to all model parameters,
        # returned as a tuple of tensors (one per parameter).
        grad = torch.autograd.grad(batch_total_loss, model.parameters())
        if cpu:
            # Optionally move gradients to the CPU so that accumulating
            # them across many batches does not exhaust GPU memory.
            grad = [utils.to_cpu(v) for v in grad]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



sample_info/modules/sgd.py [103:117]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    for inputs_batch, labels_batch in tqdm(loader, desc='Computing sgd noise covariance...'):
        # Normalize inputs and labels to lists so single-input and
        # multi-input models share the same code path below.
        if isinstance(inputs_batch, torch.Tensor):
            inputs_batch = [inputs_batch]
        if not isinstance(labels_batch, list):
            labels_batch = [labels_batch]

        # Make sure autograd is recording, even if the caller disabled it.
        with torch.set_grad_enabled(True):
            outputs = model.forward(inputs=inputs_batch, labels=labels_batch, loader=loader, **kwargs)
            batch_losses, outputs = model.compute_loss(inputs=inputs_batch, labels=labels_batch, outputs=outputs,
                                                       loader=loader, dataset=loader.dataset)
            # The total batch loss is the sum of all named loss terms.
            batch_total_loss = sum(batch_losses.values())

        # Gradient of the batch loss with respect to all model parameters,
        # returned as a tuple of tensors (one per parameter).
        grad = torch.autograd.grad(batch_total_loss, model.parameters())
        if cpu:
            # Optionally move gradients to the CPU so that accumulating
            # them across many batches does not exhaust GPU memory.
            grad = [utils.to_cpu(v) for v in grad]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
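
Both excerpts stop at the per-batch gradient `grad` without showing how it is
consumed. Below is a minimal sketch of one common way such per-batch gradients
feed an empirical SGD noise covariance estimate; the function name and the
accumulation convention are assumptions for illustration, not the repository's
API:

    import torch

    def empirical_noise_covariance(batch_grads):
        # batch_grads: list over batches; each entry is a tuple of tensors
        # shaped like model.parameters() (e.g., the `grad` computed above).
        # Hypothetical helper, not part of sample_info.
        flat = torch.stack([torch.cat([g.reshape(-1) for g in gs])
                            for gs in batch_grads])
        # Center the flattened per-batch gradients around their mean.
        centered = flat - flat.mean(dim=0, keepdim=True)
        # Unbiased empirical covariance of the flattened batch gradients.
        return centered.t() @ centered / (flat.shape[0] - 1)

For a model with d parameters the full d x d covariance is usually infeasible
to store; in practice one keeps a summary instead, such as the trace,
centered.pow(2).sum() / (flat.shape[0] - 1), or the per-parameter diagonal.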



