def run()

in train.py [0:0]


def run(model, Dataset, log_path, plotter, CHECKPOINT_tempfile):
    """Train ``model`` epoch by epoch, validating and checkpointing after each.

    Reads configuration from the module-level ``opts`` object and mutates it
    (``best_epoch``, ``best_loss``, ``continue_epoch`` and, when resuming,
    ``current_episode_train`` / ``current_episode_val``).

    Args:
        model: Network exposing ``load_state_dict``, ``optimizer_D``,
            ``optimizer_G``, ``train()`` / ``eval()`` and an ``epoch``
            attribute.
        Dataset: Dataset class, constructed as ``Dataset("train", opts)``; the
            instance must provide ``toval(epoch=...)`` and
            ``totrain(epoch=...)``.
        log_path: Passed through to ``train`` and ``val``.
        plotter: Logger with an ``add_scalars(tag, values, step)`` method
            (presumably a TensorBoard ``SummaryWriter`` -- confirm).
        CHECKPOINT_tempfile: Passed through to ``checkpoint``.
    """
    print("Starting run...", flush=True)

    opts.best_epoch = 0
    # The model-selection metric below is val PSNR (higher is better), so
    # start from a value any real score beats.
    opts.best_loss = -1000
    if os.path.exists(opts.model_epoch_path) and opts.resume:
        # Deserialize the checkpoint once and reuse it for every piece of
        # state, instead of re-reading the same file four times.
        past_state = torch.load(opts.model_epoch_path)
        past_opts = past_state["opts"]
        print(
            "Continuing epoch ... %d" % (past_opts.continue_epoch + 1),
            flush=True,
        )
        model.load_state_dict(past_state["state_dict"])
        model.optimizer_D.load_state_dict(past_state["optimizerD"])
        model.optimizer_G.load_state_dict(past_state["optimizerG"])

        # The saved continue_epoch is the epoch that had just finished when
        # the checkpoint was written; resume at the following one.
        opts.continue_epoch = past_opts.continue_epoch + 1
        opts.current_episode_train = past_opts.current_episode_train
        opts.current_episode_val = past_opts.current_episode_val
        opts.best_epoch = past_opts.best_epoch
        opts.best_loss = past_opts.best_loss
        # Release the deserialized checkpoint so it is not held for the
        # whole training run.
        del past_state
    elif opts.resume:
        print("WARNING: Model path does not exist?? ")
        print(opts.model_epoch_path)

    print("Loading train dataset ....", flush=True)
    train_set = Dataset("train", opts)

    train_data_loader = DataLoader(
        dataset=train_set,
        num_workers=opts.num_workers,
        batch_size=opts.batch_size,
        shuffle=False,
        drop_last=True,
        pin_memory=True,
    )

    print("Loaded train dataset ...", flush=True)

    # Last completed epoch. Pre-seeded so the halt check after the loop does
    # not raise UnboundLocalError when a resumed run is already at or past
    # max_epoch and the loop body never executes.
    epoch = opts.continue_epoch - 1
    for epoch in range(opts.continue_epoch, opts.max_epoch):
        print("Starting epoch %d" % epoch, flush=True)
        opts.continue_epoch = epoch
        model.epoch = epoch
        model.train()

        train_loss = train(
            epoch, train_data_loader, model, log_path, plotter, opts
        )

        model.eval()
        with torch.no_grad():
            # HACK: validation reuses the training dataset/loader, switching
            # the dataset into validation mode rather than reloading the
            # environments.
            train_set.toval(epoch=0)
            loss = val(epoch, train_data_loader, model, log_path, plotter)
            train_set.totrain(epoch=epoch + 1 + opts.seed)

        # One chart per training metric; overlay the val value when val
        # reports the same metric.
        for name, train_value in train_loss.items():
            scalars = {"train": train_value}
            if name in loss:
                scalars["val"] = loss[name]
            plotter.add_scalars("%s_epoch" % name, scalars, epoch)

        if loss["psnr"] > opts.best_loss:
            # Record the new best before saving, so that any copy of opts the
            # "best" checkpoint stores already reflects this epoch.
            opts.best_epoch = epoch
            opts.best_loss = loss["psnr"]
            checkpoint(
                model, opts.model_epoch_path + "best", CHECKPOINT_tempfile
            )

        # Rolling "latest" checkpoint; this path is what the resume branch
        # above reads.
        checkpoint(model, opts.model_epoch_path, CHECKPOINT_tempfile)

        # Keep a permanent snapshot every 50 epochs.
        if epoch % 50 == 0:
            checkpoint(
                model,
                opts.model_epoch_path + "ep%d" % epoch,
                CHECKPOINT_tempfile,
            )

    # NOTE(review): hard-coded 500 rather than opts.max_epoch; presumably this
    # creates a marker file telling an external job runner that training is
    # finished -- confirm before changing.
    if epoch == 500 - 1:
        with open(HALT_filename, "a"):
            pass