def validate_batch()

in src/train.py [0:0]


def validate_batch(train_config: TrainConfig, epoch, train_loss, model_idx: int = -1) -> float:
    """Evaluate on the validation loader, log the statistics and return the mean loss.

    For every batch of ``train_config.valid_data_loader`` the prediction is
    compared against ``get_train_target(model_idx)`` using
    ``train_config.losses[model_idx]``. "Accuracy" is the fraction of output
    elements whose absolute error is below ``0.001``.

    Side effects:
      * appends one line to ``{logDir}logs.txt``;
      * appends one row to the CSV ``{logDir}{trainStatsName}`` (writing the
        header first when the file does not exist yet);
      * calls ``plot_training_stats`` five times to refresh the plots.

    Args:
        train_config: Training configuration providing the data loader,
            models, feature inputs, losses, log directory and config file.
        epoch: Epoch identifier, only used for logging.
        train_loss: Training loss of this epoch, only used for logging.
        model_idx: ``-1`` runs ``train_config.inference`` and scores its last
            output. Any other value evaluates ``train_config.models[model_idx]``
            on its own, feeding it the train targets of all earlier models as
            ``prev_outs``.

    Returns:
        The mean validation loss over all batches (``nan`` if the loader
        yields no batches).
    """
    c_file = train_config.config_file
    batch_losses = []
    batch_accuracies = []

    for sample_data in tqdm(train_config.valid_data_loader, desc='validating batch', position=0, leave=True):
        # get sample input data
        sample_data_wrapper = create_sample_wrapper(sample_data, train_config)

        # inference
        if model_idx == -1:
            outs, inference_dicts = train_config.inference(sample_data_wrapper, gradient=False)
            out = outs[-1]
            inference_dict = inference_dicts[-1]
        else:
            f_in = train_config.f_in[model_idx]
            model = train_config.models[model_idx]

            # Earlier stages are replaced by their train targets rather than
            # by actual model predictions.
            prev_outs = [sample_data_wrapper.get_train_target(prev_model_idx)
                         for prev_model_idx in range(model_idx)]

            # Validation never back-propagates, so do not record an autograd
            # graph here (mirrors ``gradient=False`` in the branch above --
            # presumably that flag disables gradients as well; confirm).
            with torch.no_grad():
                inference_dict = f_in.batch(sample_data_wrapper.get_batch_input(model_idx), prev_outs=prev_outs)
                x_batch = inference_dict[FeatureSetKeyConstants.input_feature_batch]
                out = model(x_batch)

        y_batch = sample_data_wrapper.get_train_target(model_idx)
        y_batch = y_batch.reshape(c_file.batchImages * c_file.samples, -1)

        loss_batch = train_config.losses[model_idx](out, y_batch, inference_dict=inference_dict)
        # Keep a plain number: holding on to the loss tensor would keep any
        # attached graph (and its activations) alive for the whole loop.
        batch_losses.append(float(loss_batch))

        diff = abs(out - y_batch)
        # numel() equals shape[0] * shape[1] for 2-D outputs and stays correct
        # for any other rank.
        batch_accuracies.append(float((diff < 0.001).sum()) / float(diff.numel()))

    loss = torch.mean(torch.tensor(batch_losses)).item()
    accuracy = torch.mean(torch.tensor(batch_accuracies)).item()

    print(f"\nvalidation epoch={epoch:<10} loss={loss:.8f} acc={accuracy:.8f}")
    with open(f"{train_config.logDir}logs.txt", "a") as f:
        # NOTE(review): the record is terminated with "\r", not "\n". Left
        # unchanged because existing log consumers may rely on it -- confirm.
        f.write(f"epoch={epoch} loss={loss:.4f}  acc={accuracy:.8f} train_loss={train_loss:.8f}\r")

    stats_path = f"{train_config.logDir}{c_file.trainStatsName}"
    # Must be decided before open(..., 'a') creates the file.
    add_header = not os.path.isfile(stats_path)
    with open(stats_path, 'a', newline='') as csv_file:
        fieldnames = ['epoch', 'loss', 'accuracy', 'train_loss']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

        if add_header:
            writer.writeheader()

        writer.writerow({'epoch': f'{epoch}', 'loss': f'{loss}', 'accuracy': f'{accuracy}',
                         'train_loss': f'{train_loss}'})

    # Refresh the individual and the combined plots from the updated CSV.
    for y_columns in ('loss', 'train_loss', 'accuracy',
                      ['loss', 'train_loss', 'accuracy'], ['loss', 'train_loss']):
        plot_training_stats(train_config.logDir, c_file.trainStatsName, 'epoch', y_columns)
    return loss