def iterate_val_data()

in tbsm_pytorch.py [0:0]


def iterate_val_data(val_ld, tbsm, use_gpu, device):
    """Run ``tbsm`` over every batch of ``val_ld`` and accumulate metrics.

    Args:
        val_ld: iterable yielding ``(X, lS_o, lS_i, T_test)`` tuples; the
            batch size of each batch is read from ``X[0].shape[0]``.
        tbsm: model, called on the unpacked result of ``data_wrap``.
        use_gpu: forwarded to ``data_wrap`` and ``loss_fn_wrap``.
        device: forwarded to ``data_wrap`` and ``loss_fn_wrap``.

    Returns:
        Tuple ``(total_accu_test, total_samp_test, total_loss_val)``:
        the number of predictions equal to their target after rounding to
        the nearest integer, the total number of samples seen, and the sum
        of each batch's loss multiplied by that batch's size. All three are
        0 when ``val_ld`` yields no batches.
    """
    # Local import so this block does not depend on a module-level import
    # that is outside this view.
    import torch

    # NOTE: call to tbsm.eval() not needed here, see
    # https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615
    total_loss_val = 0
    total_accu_test = 0
    total_samp_test = 0

    # Every result below is detached and converted to numpy, so gradients
    # are never needed; disabling autograd avoids recording a graph for
    # each validation batch. Harmless if the caller already uses no_grad.
    with torch.no_grad():
        for X, lS_o, lS_i, T_test in val_ld:
            batch_size = X[0].shape[0]

            Z_test = tbsm(*data_wrap(X, lS_o, lS_i, use_gpu, device))

            # accuracy: count outputs that equal the target once rounded to
            # the nearest integer (presumably binary 0/1 targets — confirm)
            z = Z_test.detach().cpu().numpy()  # numpy array
            t = T_test.detach().cpu().numpy()  # numpy array
            total_accu_test += np.sum((np.round(z, 0) == t).astype(np.uint8))
            total_samp_test += batch_size

            # loss: weighted by batch size; presumably loss_fn_wrap returns
            # a per-batch mean and the caller divides by total_samp_test
            E_test = loss_fn_wrap(Z_test, T_test, use_gpu, device)
            L_test = E_test.detach().cpu().numpy()  # numpy array
            total_loss_val += L_test * batch_size

    return total_accu_test, total_samp_test, total_loss_val