# main_function(args), excerpted from train.py


def _compute_test_image_ids(args, dataset_extras):
    """Return the ids of images whose timestep falls into a held-out test block.

    Timesteps are split into repeating blocks of ``args.train_block_size``
    training timesteps followed by ``args.test_block_size`` test timesteps.
    Every image whose timestep (``dataset_extras["imageid_to_timestepid"]``)
    lies in a test block is a test image. If either block size is not
    positive, no image is held out.

    Returns:
        An int64 numpy array of image ids (possibly empty). The explicit integer
        dtype matters: an empty ``np.array([])`` is float64, and numpy refuses
        to index with float arrays (e.g. ``poses[i_test]``).
    """
    if args.test_block_size <= 0 or args.train_block_size <= 0:
        return np.array([], dtype=np.int64)

    print(
        "splitting timesteps into training ("
        + str(args.train_block_size)
        + ") and test ("
        + str(args.test_block_size)
        + ") blocks"
    )
    num_timesteps = len(dataset_extras["raw_timesteps"])
    period = args.train_block_size + args.test_block_size
    test_timesteps = np.concatenate(
        [
            # the test part of each block: [start + train, start + period), clipped
            np.arange(
                min(num_timesteps, block_start + args.train_block_size),
                min(num_timesteps, block_start + period),
            )
            for block_start in np.arange(0, num_timesteps, period)
        ]
    )
    test_timestep_set = set(test_timesteps.tolist())  # O(1) membership tests
    return np.array(
        [
            imageid
            for imageid, timestep in enumerate(dataset_extras["imageid_to_timestepid"])
            if timestep in test_timestep_set
        ],
        dtype=np.int64,
    )


def _write_rendering_videos(moviebase, rgbs, disps):
    """Best-effort write of RGB and disparity videos for one rendered sequence.

    Writes ``<moviebase>rgb.mp4``, ``disp.mp4`` (normalized by the maximum over
    all frames), and ``disp_jet.mp4`` / ``disp_phong.mp4`` (each frame
    normalized by its own maximum). Encoding can fail for environmental
    reasons (e.g. a broken ffmpeg install), so failures are reported and
    swallowed instead of aborting training. KeyboardInterrupt still propagates.
    """
    try:
        imageio.mimwrite(moviebase + "rgb.mp4", to8b(rgbs), fps=30, quality=8)
        imageio.mimwrite(
            moviebase + "disp.mp4",
            to8b(disps / np.max(disps)),
            fps=30,
            quality=8,
        )
        imageio.mimwrite(
            moviebase + "disp_jet.mp4",
            to8b(
                np.stack(
                    [
                        visualize_disparity_with_jet_color_scheme(disp / np.max(disp))
                        for disp in disps
                    ],
                    axis=0,
                )
            ),
            fps=30,
            quality=8,
        )
        imageio.mimwrite(
            moviebase + "disp_phong.mp4",
            to8b(
                np.stack(
                    [
                        visualize_disparity_with_blinn_phong(disp / np.max(disp))
                        for disp in disps
                    ],
                    axis=0,
                )
            ),
            fps=30,
            quality=8,
        )
    except Exception:
        print("imageio.mimwrite() failed. maybe ffmpeg is not installed properly?")


def main_function(args):
    """Train a ray-bending NeRF with one learnable latent code per timestep.

    Overall flow:

    1. Load an LLFF dataset and rescale the per-view intrinsics to the loaded
       image resolution.
    2. Split timesteps into train/test blocks.
    3. Write the run arguments and config to the log directory.
    4. Build the models and run the optimization loop.

    Training-loop details:

    - Learning rate: decays by 0.1 every ``args.lrate_decay`` steps. During the
      first 1000 steps it is additionally divided by a factor that falls
      linearly from 21 to 1.
    - Test-image rays only update the latent codes. When a ray bender exists,
      their loss is back-propagated and the network gradients are cleared
      before the training-image loss is back-propagated.
    - Checkpoints are saved every ``args.i_weights`` iterations.
    - Videos are rendered every ``args.i_video`` iterations.
    - Train and test image sets are rendered every ``args.i_testset``
      iterations.

    Args:
        args: parsed configuration namespace. Fields read here include
            ``debug, seed, dataset_type, datadir, factor, bd_factor, spherify,
            train_block_size, test_block_size, render_test, rootdir, expname,
            config, ray_bending_latent_size, N_iters, N_rand, lrate,
            lrate_decay, i_weights, i_video, chunk, i_testset, i_print``.

    Returns:
        None. Returns early, after printing a message, for any dataset type
        other than ``"llff"``.
    """
    # miscellaneous initial stuff
    global DEBUG
    DEBUG = args.debug
    torch.autograd.set_detect_anomaly(args.debug)
    if args.seed >= 0:
        np.random.seed(args.seed)

    # Load data
    if args.dataset_type != "llff":
        print("Unknown dataset type", args.dataset_type, "exiting")
        return

    images, poses, bds, render_poses, i_test = load_llff_data(
        args.datadir,
        factor=args.factor,
        recenter=True,
        bd_factor=args.bd_factor,
        spherify=args.spherify,
    )
    dataset_extras = _get_multi_view_helper_mappings(images.shape[0], args.datadir)
    intrinsics, image_folder = get_full_resolution_intrinsics(args, dataset_extras)

    hwf = poses[0, :3, -1]
    poses = poses[:, :3, :4]
    print("Loaded llff", images.shape, render_poses.shape, hwf, args.datadir)

    # Adapt the full-resolution intrinsics to the loaded (downscaled) images.
    # Height/width are always taken from the images. A missing focal length
    # falls back to the loader's hwf[2]; everything else is divided by
    # args.factor. Do not use this loop or the next in smallscripts; rely on
    # the stored/saved version of "intrinsics" instead.
    for camera in intrinsics.values():
        camera["height"] = images.shape[1]
        camera["width"] = images.shape[2]
        for focal_key in ("focal_x", "focal_y"):
            if camera[focal_key] is None:
                camera[focal_key] = hwf[2]
            else:
                camera[focal_key] /= args.factor
        camera["center_x"] /= args.factor
        camera["center_y"] /= args.factor
    # Re-key "intrinsics" by viewid instead of raw_view. Go through a
    # temporary mapping so overlapping old/new keys cannot clobber each other,
    # and update in place to keep the dict object's identity.
    rekeyed_intrinsics = {
        dataset_extras["rawview_to_viewid"][raw_view]: camera
        for raw_view, camera in intrinsics.items()
    }
    intrinsics.clear()
    intrinsics.update(rekeyed_intrinsics)

    # Train/test split by timestep blocks (args.train_block_size,
    # args.test_block_size). The i_test returned by load_llff_data is
    # intentionally discarded.
    i_test = _compute_test_image_ids(args, dataset_extras)
    i_val = i_test
    i_train = np.setdiff1d(np.arange(int(images.shape[0])), i_test)

    print("DEFINING BOUNDS")
    # near/far always come from the scene bounds (the NDC variant with
    # near=0, far=1 is disabled).
    near = np.ndarray.min(bds) * 0.9
    far = np.ndarray.max(bds) * 1.0
    print("NEAR FAR", near, far)

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the arguments and config file into it
    logdir = os.path.join(args.rootdir, args.expname, "logs/")
    expname = args.expname
    os.makedirs(logdir, exist_ok=True)
    with open(os.path.join(logdir, "args.txt"), "w") as file:
        for arg in sorted(vars(args)):
            file.write("{} = {}\n".format(arg, getattr(args, arg)))
    if args.config is not None:
        with open(args.config, "r") as config_file:
            config_text = config_file.read()
        with open(os.path.join(logdir, "config.txt"), "w") as file:
            file.write(config_text)

    # Autodecoder variables: one zero-initialized, trainable latent per timestep
    ray_bending_latents_list = [
        torch.zeros(args.ray_bending_latent_size).cuda()
        for _ in range(len(dataset_extras["raw_timesteps"]))
    ]
    for latent in ray_bending_latents_list:
        latent.requires_grad = True

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args, autodecoder_variables=ray_bending_latents_list
    )
    print("start: " + str(start) + " args.N_iters: " + str(args.N_iters), flush=True)

    global_step = start

    bds_dict = {
        "near": near,
        "far": far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    scripts_dict = {"near": near, "far": far, "image_folder": image_folder}

    coarse_model = render_kwargs_train["network_fn"]
    fine_model = render_kwargs_train["network_fine"]
    ray_bender = render_kwargs_train["ray_bender"]
    parallel_training = get_parallelized_training_function(
        coarse_model=coarse_model,
        latents=ray_bending_latents_list,
        fine_model=fine_model,
        ray_bender=ray_bender,
    )
    parallel_render = get_parallelized_render_function(
        coarse_model=coarse_model, fine_model=fine_model, ray_bender=ray_bender
    )  # only used by render_path() at test time, not for training/optimization

    # Network weights (fine model and ray bender are optional), excluding the
    # latents. Used to clear gradients after the test-image backward pass and
    # for the DEBUG gradient norm.
    network_parameters = (
        list(coarse_model.parameters())
        + list([] if fine_model is None else fine_model.parameters())
        + list([] if ray_bender is None else ray_bender.parameters())
    )

    min_point, max_point = determine_nerf_volume_extent(
        parallel_render,
        poses,
        [
            intrinsics[dataset_extras["imageid_to_viewid"][imageid]]
            for imageid in range(poses.shape[0])
        ],
        render_kwargs_train,
        args,
    )
    scripts_dict["min_nerf_volume_point"] = min_point.detach().cpu().numpy().tolist()
    scripts_dict["max_nerf_volume_point"] = max_point.detach().cpu().numpy().tolist()

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).cuda()

    # Precompute all rays for random ray batching
    N_rand = args.N_rand
    print("get rays")
    rays = np.stack(
        [
            get_rays_np(p, intrinsics[dataset_extras["imageid_to_viewid"][imageid]])
            for imageid, p in enumerate(poses[:, :3, :4])
        ],
        0,
    )  # [N, ro+rd, H, W, 3]
    print("done, concats")

    # attach index information (index among all images in dataset, x and y coordinate)
    image_indices, y_coordinates, x_coordinates = np.meshgrid(
        np.arange(images.shape[0]),
        np.arange(intrinsics[0]["height"]),
        np.arange(intrinsics[0]["width"]),
        indexing="ij",
    )  # keep consistent with code in get_rays and get_rays_np. (0,0,0) is coordinate of the top-left corner of the first image, i.e. of [0,0,0]. each array has shape images x height x width
    additional_indices = np.stack(
        [image_indices, x_coordinates, y_coordinates], axis=-1
    )  # N x height x width x 3 (image, x, y)

    rays_rgb = np.concatenate(
        [rays, images[:, None], additional_indices[:, None]], 1
    )  # [N, ro+rd+rgb+ind, H, W, 3]
    rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb+ind, 3]
    rays_rgb = rays_rgb.astype(np.float32)  # all images are used
    print(rays_rgb.shape)

    # Move training data to GPU
    poses = torch.Tensor(poses).cuda()

    # Per-image 0/1 indicators that split each ray batch's per-ray loss into
    # test-image and training-image parts. Loop-invariant, so built once.
    num_images = images.shape[0]
    all_test_images_indicator = torch.zeros(num_images, dtype=torch.long).cuda()
    all_test_images_indicator[i_test] = 1
    all_training_images_indicator = torch.zeros(num_images, dtype=torch.long).cuda()
    all_training_images_indicator[i_train] = 1

    N_iters = args.N_iters + 1
    print("TRAIN views are", i_train)
    print("TEST views are", i_test)
    print("VAL views are", i_val)
    print("Begin", flush=True)

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        optimizer.zero_grad()

        # reset autodecoder gradients to avoid wrong DeepSDF-style optimization. Note: this is only guaranteed to work if the optimizer is Adam
        for latent in ray_bending_latents_list:
            latent.grad = None

        # Sample N_rand random rays over all images (image id, x, y)
        image_indices = np.random.randint(images.shape[0], size=N_rand)
        x_coordinates = np.random.randint(intrinsics[0]["width"], size=N_rand)
        y_coordinates = np.random.randint(intrinsics[0]["height"], size=N_rand)

        batch = rays_rgb[
            image_indices, y_coordinates, x_coordinates
        ]  # batch x ro+rd+rgb+ind x 3

        # push to cuda, create batch_rays, target_s, batch_pixel_indices
        batch_pixel_indices = (
            torch.Tensor(
                np.stack([image_indices, x_coordinates, y_coordinates], axis=-1)
            )
            .cuda()
            .long()
        )  # batch x 3
        batch = torch.transpose(torch.Tensor(batch).cuda(), 0, 1)  # 4 x batch x 3
        batch_rays, target_s = batch[:2], batch[2]

        losses = parallel_training(
            args,
            batch_rays[0],
            batch_rays[1],
            i,
            render_kwargs_train,
            target_s,
            global_step,
            start,
            dataset_extras,
            batch_pixel_indices,
        )  # one loss value per ray

        # per-ray indicators, looked up by each ray's image id
        current_test_images_indicator = all_test_images_indicator[image_indices]
        current_training_images_indicator = all_training_images_indicator[image_indices]

        # Test images first (only if any were sampled): backprop their masked
        # loss, then drop the network gradients so only the latents learn from them.
        if ray_bender is not None and torch.sum(current_test_images_indicator) > 0:
            masked_loss = torch.mean(current_test_images_indicator * losses)
            masked_loss.backward(retain_graph=True)
            for weights in network_parameters:
                weights.grad = None
        # Training images (always); the graph is released afterwards.
        masked_loss = torch.mean(current_training_images_indicator * losses)
        masked_loss.backward(retain_graph=False)

        optimizer.step()

        if DEBUG:
            if torch.isnan(losses).any() or torch.isinf(losses).any():
                raise RuntimeError(str(losses))
            if torch.isnan(target_s).any() or torch.isinf(target_s).any():
                raise RuntimeError(str(torch.sum(target_s)) + " " + str(target_s))
            norm_type = 2.0
            total_gradient_norm = 0
            for p in network_parameters + list(ray_bending_latents_list):
                if p.requires_grad and p.grad is not None:
                    param_norm = p.grad.data.norm(norm_type)
                    total_gradient_norm += param_norm.item() ** norm_type
            total_gradient_norm = total_gradient_norm ** (1.0 / norm_type)
            print(total_gradient_norm, flush=True)

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        warming_up = 1000
        if global_step < warming_up:
            # in case images are very dark or very bright, need to keep network from initially building up so much momentum that it kills the gradient
            new_lrate /= 20.0 * (-(global_step - warming_up) / warming_up) + 1.0
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lrate
        ################################

        dt = time.time() - time0
        print(
            "Step: "
            + str(global_step)
            + ", total loss: "
            + str(losses.mean().cpu().detach().numpy())
            + ", time: "
            + str(dt),
            flush=True,
        )

        # Rest is logging
        if i % args.i_weights == 0:
            # frames x latent_size
            if ray_bending_latents_list:
                all_latents = torch.stack(
                    [latent.cpu() for latent in ray_bending_latents_list], 0
                )
            else:
                all_latents = torch.zeros(0)

            # every 50000 iterations keep a numbered snapshot (also copied to latest.tar)
            store_extra = i % 50000 == 0
            if store_extra:
                path = os.path.join(logdir, "{:06d}.tar".format(i))
            else:
                path = os.path.join(logdir, "latest.tar")
            torch.save(
                {
                    "global_step": global_step,
                    "network_fn_state_dict": render_kwargs_train[
                        "network_fn"
                    ].state_dict(),
                    "network_fine_state_dict": None
                    if render_kwargs_train["network_fine"] is None
                    else render_kwargs_train["network_fine"].state_dict(),
                    "ray_bender_state_dict": None
                    if render_kwargs_train["ray_bender"] is None
                    else render_kwargs_train["ray_bender"].state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "ray_bending_latent_codes": all_latents,  # shape: frames x latent_size
                    "intrinsics": intrinsics,
                    "scripts_dict": scripts_dict,
                    "dataset_extras": dataset_extras,
                },
                path,
            )
            del all_latents

            if store_extra:
                shutil.copyfile(path, os.path.join(logdir, "latest.tar"))

            print("Saved checkpoints at", path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            print("rendering test set...", flush=True)
            if (
                len(render_poses) > 0
                and len(i_test) > 0
                and not dataset_extras["is_multiview"]
            ):
                if args.render_test:
                    rendering_latents = [
                        ray_bending_latents_list[
                            dataset_extras["imageid_to_timestepid"][imageid]
                        ]
                        for imageid in i_test
                    ]
                else:
                    # every spiral pose uses the latent of the first test image's timestep
                    first_test_latent = ray_bending_latents_list[
                        dataset_extras["imageid_to_timestepid"][i_test[0]]
                    ]
                    rendering_latents = [
                        first_test_latent for _ in range(len(render_poses))
                    ]
                with torch.no_grad():
                    rgbs, disps = render_path(
                        render_poses,
                        [intrinsics[0] for _ in range(len(render_poses))],
                        args.chunk,
                        render_kwargs_test,
                        ray_bending_latents=rendering_latents,
                        parallelized_render_function=parallel_render,
                    )
                print("Done, saving", rgbs.shape, disps.shape)
                moviebase = os.path.join(logdir, "{}_spiral_{:06d}_".format(expname, i))
                _write_rendering_videos(moviebase, rgbs, disps)

            if i >= N_iters + 1 - args.i_video:
                print("rendering full training set...", flush=True)
                with torch.no_grad():
                    rgbs, disps = render_path(
                        poses[i_train],
                        [
                            intrinsics[dataset_extras["imageid_to_viewid"][imageid]]
                            for imageid in i_train
                        ],
                        args.chunk,
                        render_kwargs_test,
                        ray_bending_latents=[
                            ray_bending_latents_list[
                                dataset_extras["imageid_to_timestepid"][imageid]
                            ]
                            for imageid in i_train
                        ],
                        parallelized_render_function=parallel_render,
                    )
                print("Done, saving", rgbs.shape, disps.shape)
                moviebase = os.path.join(
                    logdir, "{}_training_{:06d}_".format(expname, i)
                )
                _write_rendering_videos(moviebase, rgbs, disps)

        if i % args.i_testset == 0 and i > 0:
            trainsubsavedir = os.path.join(logdir, "trainsubset_{:06d}".format(i))
            os.makedirs(trainsubsavedir, exist_ok=True)
            if i >= N_iters + 1 - args.i_video:
                i_train_sub = i_train
            else:
                # Subsample to roughly as many training images as there are test
                # images. Without test images, fall back to a stride of
                # len(i_train) (only the first training image) instead of
                # dividing by zero.
                stride = max(1, int((len(i_train) / max(1, len(i_test))) + 0.5))
                i_train_sub = i_train[::stride]
            print("i_train_sub poses shape", poses[i_train_sub].shape)
            with torch.no_grad():
                render_path(
                    poses[i_train_sub],
                    [
                        intrinsics[dataset_extras["imageid_to_viewid"][imageid]]
                        for imageid in i_train_sub
                    ],
                    args.chunk,
                    render_kwargs_test,
                    gt_imgs=images[i_train_sub],
                    savedir=trainsubsavedir,
                    detailed_output=True,
                    ray_bending_latents=[
                        ray_bending_latents_list[
                            dataset_extras["imageid_to_timestepid"][imageid]
                        ]
                        for imageid in i_train_sub
                    ],
                    parallelized_render_function=parallel_render,
                )
            print("Saved some training images")

            if len(i_test) > 0:
                testsavedir = os.path.join(logdir, "testset_{:06d}".format(i))
                os.makedirs(testsavedir, exist_ok=True)
                print("test poses shape", poses[i_test].shape)
                with torch.no_grad():
                    render_path(
                        poses[i_test],
                        [
                            intrinsics[dataset_extras["imageid_to_viewid"][imageid]]
                            for imageid in i_test
                        ],
                        args.chunk,
                        render_kwargs_test,
                        gt_imgs=images[i_test],
                        savedir=testsavedir,
                        detailed_output=True,
                        ray_bending_latents=[
                            ray_bending_latents_list[
                                dataset_extras["imageid_to_timestepid"][imageid]
                            ]
                            for imageid in i_test
                        ],
                        parallelized_render_function=parallel_render,
                    )
                print("Saved test set")

        if i % args.i_print == 0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {losses.mean().item()}")

        global_step += 1
        print("", end="", flush=True)