# generate_data — extracted from src/evaluate.py



def generate_data(train_config, flags, reference_video=None):
    """Render every view of a dataset and record image-quality data.

    Runs inference over either the test dataset (image mode) or a predefined
    camera path (video mode, when ``reference_video`` is given), accumulates
    per-frame diff / squared-diff / FLIP data in a ``QualityContainer``,
    writes the resulting images or videos under ``train_config.outDir``, and
    dumps MSE/PSNR/SSIM/FLIP statistics to ``image_quality_*.txt``/``.csv``.

    Args:
        train_config: project training configuration; provides the models,
            device, dataset info, inference entry point and output directory.
        flags: iterable of feature flags ("complexity", "output_images",
            "psnr", "ssim", "flip"); ``None`` is treated as no flags.
        reference_video: optional sequence of HxWx3 uint8 reference frames.
            When given, the function switches to video mode and compares the
            rendered camera path against these frames.
    """
    # Normalize once so every membership test below is safe. Previously only
    # the "output_images" check guarded against None; "complexity"/"flip"/
    # "psnr"/"ssim" lookups would have raised TypeError for flags=None.
    if flags is None:
        flags = []

    count_flops = False
    image_macs = []      # total MACs per full image
    image_macs_pp = []   # MACs per pixel

    dim_h = train_config.dataset_info.h
    dim_w = train_config.dataset_info.w

    if reference_video is not None:
        # Rescale the reference frames if the config renders at a scaled resolution.
        if train_config.dataset_info.scale != 1:
            video_scaled = np.zeros((len(reference_video), dim_h, dim_w, 3))
            for i, img in enumerate(reference_video):
                # cv2.resize takes dsize as (width, height) — not (height, width).
                video_scaled[i] = cv2.resize(img, (dim_w, dim_h), interpolation=cv2.INTER_AREA)

            reference_video = video_scaled

        # Temporarily switch the camera to the predefined path; restored after rendering.
        train_config.store_camera_options()
        train_config.config_file.camPath = "cam_path"
        train_config.config_file.camType = "PredefinedCamera"
        train_config.config_file.videoFrames = -1

        dataset = CameraViewCellDataset(train_config.config_file, train_config, train_config.dataset_info)

        desc = "Generating diff and flip videos"
        suffix = "video"
    else:
        os.makedirs(f"{train_config.outDir}/eval/", exist_ok=True)

        if "complexity" in flags:
            count_flops = True
            # Instrument every model so MACs can be counted during inference.
            for i in range(len(train_config.models)):
                train_config.models[i] = flops_counter.add_flops_counting_methods(train_config.models[i])

        dataset = train_config.test_dataset

        desc = "Generating diff and flip images"
        suffix = "images"

    q_cont = QualityContainer()

    # generate test image here to reduce time
    for i in tqdm(range(len(dataset)), desc=desc, position=0, leave=True):
        # acquire test and reference data
        img_data = create_sample_wrapper(dataset[i], train_config, True)

        if count_flops:
            for k in range(len(train_config.models)):
                train_config.models[k].start_flops_count(ost=None, verbose=False, ignore_list=[])

        # Flat (H*W, 3) buffer that the chunked inference fills in order.
        test = torch.zeros((dim_h * dim_w, 3), device=train_config.device, dtype=torch.float32)
        start_index = 0
        for batch in img_data.batches(train_config.config_file.inferenceChunkSize):
            img_part, _ = train_config.inference(batch, gradient=False, is_inference=True)

            end_index = min(start_index + train_config.config_file.inferenceChunkSize, dim_w * dim_h)

            # The final chunk may hold fewer than inferenceChunkSize pixels;
            # slice by the actual number of pixels being written so the
            # left- and right-hand sides always match in size.
            test[start_index:end_index, :3] = img_part[-1][:end_index - start_index, :3]
            start_index = end_index

        if count_flops:
            total_macs = 0
            for k in range(len(train_config.models)):
                macs, params = train_config.models[k].compute_average_flops_cost()
                total_macs += macs
                train_config.models[k].stop_flops_count()

            # Per-image total is the per-pixel average scaled by the pixel count.
            image_macs.append(total_macs * train_config.dataset_info.w * train_config.dataset_info.h)
            image_macs_pp.append(total_macs)

        # just so we can save the created images for debug purposes
        if "output_images" in flags:
            test_clone = torch.clone(test)
            test_clone = t2np(test_clone)

            # Reshape flat buffer back to (1, H, W, C) and clamp to displayable range.
            test_clone = np.clip(test_clone.reshape(train_config.dataset_info.h,
                                                    train_config.dataset_info.w, -1), 0., 1.)[None]

            if q_cont.out_data is None:
                q_cont.out_data = test_clone
            else:
                q_cont.out_data = np.concatenate((q_cont.out_data, test_clone), axis=0)
            del test_clone

        if reference_video is None:
            reference = img_data.get_train_target(-1)
        else:
            # Reference frames are uint8 [0, 255]; convert to float [0, 1] on device.
            reference = (reference_video[i]).astype(np.float32)
            reference = reference / 255
            reference = torch.from_numpy(reference).to(train_config.device)

        # generate data in separate functions to get rid of intermediate values
        generate_diff_data(train_config, test, reference, q_cont, reference_video is None, flags)
        if "flip" in flags:
            generate_flip_data(train_config, test, reference, q_cont, reference_video is None)

    # save data to disk
    if reference_video is not None:
        train_config.restore_camera_options()

        print("Saving diff and flip videos")
        imageio.mimwrite(os.path.join(train_config.outDir, "_diff.mp4"), q_cont.diff_data, fps=30, quality=8)
        imageio.mimwrite(os.path.join(train_config.outDir, "_square_diff.mp4"), q_cont.square_diff_data, fps=30, quality=8)
        if "flip" in flags:
            imageio.mimwrite(os.path.join(train_config.outDir, "_flip.mp4"), q_cont.flip_data, fps=30, quality=8)
        if q_cont.out_data is not None:
            imageio.mimwrite(os.path.join(train_config.outDir, "_out.mp4"), q_cont.out_data, fps=30, quality=8)
    else:
        for i in tqdm(range(len(q_cont.diff_data)), desc="Saving diff and flip images", position=0, leave=True):
            save_img(q_cont.diff_data[i], train_config.dataset_info, f"{train_config.outDir}/eval/{i}_diff_{q_cont.diff_data[i].mean()}.png")
            save_img(q_cont.square_diff_data[i], train_config.dataset_info, f"{train_config.outDir}/eval/{i}_square_diff_{q_cont.square_diff_data[i].mean()}.png")
            if "flip" in flags:
                save_img(q_cont.flip_data[i], train_config.dataset_info, f"{train_config.outDir}/eval/{i}_flip_{q_cont.flip_data[i].mean()}.png")
            if q_cont.out_data is not None:
                save_img(q_cont.out_data[i], train_config.dataset_info, f"{train_config.outDir}/eval/{i}_out.png")

    if count_flops:
        with open(os.path.join(train_config.outDir, "complexity.txt"), "w") as f:
            cma_macs = 0
            cma_macs_pp = 0

            for idx in range(len(image_macs)):
                macs = image_macs[idx]
                macs_pp = image_macs_pp[idx]

                f.write(f"{idx} - {macs} - {macs_pp}\n")

                # to possibly avoid overflows, we calculate the cumulative moving averages only
                cma_macs = cma_macs + (macs - cma_macs) / (idx + 1)
                cma_macs_pp = cma_macs_pp + (macs_pp - cma_macs_pp) / (idx + 1)

            f.write(f"{cma_macs} : {cma_macs_pp}\n")

    # write quality info to .txt and .csv file
    # NOTE(review): lines are terminated with \r (not \n) — preserved as-is;
    # confirm downstream consumers actually expect carriage returns.
    with open(f"{train_config.outDir}/image_quality_{suffix}.txt", "w") as f:
        for idx, mse in enumerate(q_cont.mse):
            f.write(f"image={idx} mse={mse:.4f} psnr="
                    f"{q_cont.psnr[idx] if 'psnr' in flags else -1.:.4f} "
                    f"ssim="
                    f"{q_cont.ssim[idx] if 'ssim' in flags else -1.:.4f} "
                    f"flip_loss="
                    f"{q_cont.flip[idx] if 'flip' in flags else -1.:.4f}\r")

    with open(f"{train_config.outDir}/image_quality_{suffix}.csv", "w") as c:
        c.write("mse,psnr,ssim,flip\r")
        for idx, mse in enumerate(q_cont.mse):
            c.write(f"{mse},{q_cont.psnr[idx] if 'psnr' in flags else -1.},"
                    f"{q_cont.ssim[idx] if 'ssim' in flags else -1.},"
                    f"{q_cont.flip[idx] if 'flip' in flags else -1.}\r")