train_curve.py [322:355]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    if regime_params.get("bn_accuracy_stats", True):
        hooks, mean_dict, var_dict = utils.register_bn_tracking_hooks(model)

    with torch.no_grad():

        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.to(device, non_blocking=True), target.to(
                device, non_blocking=True
            )

            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    if regime_params.get("bn_accuracy_stats", True):
        utils.unregister_bn_tracking_hooks(hooks)
        extra_metrics = utils.get_bn_accuracy_metrics(
            model, mean_dict, var_dict
        )

        # Mean_dict and var_dict contain a mapping from modules to their
    else:
        extra_metrics = {}

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    return logging(
        model,
        test_loss,
        test_acc,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



train_indep.py [347:380]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    if regime_params.get("bn_accuracy_stats", True):
        hooks, mean_dict, var_dict = utils.register_bn_tracking_hooks(model)

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_loader):

            data, target = data.to(device, non_blocking=True), target.to(
                device, non_blocking=True
            )

            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    if regime_params.get("bn_accuracy_stats", True):
        utils.unregister_bn_tracking_hooks(hooks)
        extra_metrics = utils.get_bn_accuracy_metrics(
            model, mean_dict, var_dict
        )

        # Mean_dict and var_dict contain a mapping from modules to their
    else:
        extra_metrics = {}

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    return logging(
        model,
        test_loss,
        test_acc,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



