def train()

in train.py [0:0]


def train(epoch, data_loader, model, log_path, plotter, opts):
    """Run the training iterations for one epoch and return the mean losses.

    Each iteration calls ``model(iter(data_loader), isval=False,
    num_steps=opts.num_accumulations)``, which must return a pair
    ``(losses, images)`` of dicts keyed by name.  Every loss value is reduced
    with ``.cpu().mean().detach().item()``, accumulated for the running
    average, printed, and logged through ``plotter.add_scalars``.  On the
    first iteration and on every 250th one, the first 8 entries of every
    output image batch are logged through ``plotter.add_image``.

    Args:
        epoch: Epoch index; used as the image step and to derive the global
            scalar step ``epoch * 500 + iteration``.
        data_loader: Sized iterable of batches.  A single iterator over it is
            handed to ``model`` on every iteration.
        model: Callable performing the forward/optimisation work (see above).
        log_path: Unused by this function; kept for interface compatibility.
        plotter: Logger exposing ``add_image(tag, image, step)`` and
            ``add_scalars(tag, values, step)`` (presumably a TensorBoard
            ``SummaryWriter``; confirm against the caller).
        opts: Options object; only ``opts.num_accumulations`` is read here.

    Returns:
        Dict mapping each loss name to its sum over the epoch divided by the
        number of iterations run.  Empty if no iteration was executed.

    Raises:
        SystemExit: If the module-level ``SIGNAL_RECEIVED`` flag is set, after
            ``trigger_job_requeue()`` has been called.
    """
    print("At train", flush=True)

    # Per-epoch iteration cap; also the stride used for the global scalar step.
    max_iters = 500
    # Any non-empty DEBUG value enables the on-disk image dump.  Using .get()
    # keeps an unset variable from aborting training with a KeyError.
    debug_dump = bool(os.environ.get("DEBUG"))

    losses = {}
    iter_data_loader = iter(data_loader)
    iteration = 0

    # NOTE(review): this bound runs at most len(data_loader) - 1 iterations;
    # kept as in the original since the model may consume several batches per
    # call (num_accumulations) - confirm whether the last step is intended to
    # be skipped.
    for iteration in range(1, min(max_iters + 1, len(data_loader))):
        t_losses, output_image = model(
            iter_data_loader, isval=False, num_steps=opts.num_accumulations
        )

        # Reduce every loss to a Python float once; the value is reused for
        # the running sum and for the per-iteration scalar plot below.
        step_losses = {
            name: tensor.cpu().mean().detach().item()
            for name, tensor in t_losses.items()
        }
        for name, value in step_losses.items():
            if name in losses:
                losses[name] = value + losses[name]
            else:
                losses[name] = value

        if iteration % 250 == 0 or iteration == 1:
            save_to_disk = iteration == 1 and debug_dump
            if save_to_disk and output_image:
                # save_image does not create missing parent directories.
                os.makedirs("./debug/Image_train", exist_ok=True)

            for add_im in output_image.keys():
                preview = output_image[add_im][0:8, :, :, :].cpu().data
                normalize = "Depth" in add_im

                if save_to_disk:
                    torchvision.utils.save_image(
                        preview,
                        "./debug/Image_train/%d_%s.png" % (iteration, add_im),
                        normalize=normalize,
                    )

                plotter.add_image(
                    "Image_train/%d_%s" % (iteration, add_im),
                    torchvision.utils.make_grid(preview, normalize=normalize),
                    epoch,
                )

        # Progress line with the running average of every loss so far.
        header = "Train: Epoch {}: {}/{} with ".format(
            epoch, iteration, len(data_loader)
        )
        averages = "".join(
            " %s : %0.4f | " % (name, total / float(iteration))
            for name, total in losses.items()
        )
        print(header + averages, flush=True)

        if SIGNAL_RECEIVED:
            # checkpoint(model, opts.model_epoch_path, CHECKPOINT_tempfile)
            trigger_job_requeue()
            raise SystemExit

        for name, value in step_losses.items():
            plotter.add_scalars(
                "%s_iter" % name,
                {"train": value},
                epoch * max_iters + iteration,
            )

    return {name: total / float(iteration) for name, total in losses.items()}