in train.py [0:0]
def run(model, Dataset, log_path, plotter, CHECKPOINT_tempfile):
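    # Drives the full training loop: optionally resume from a saved checkpoint,
    # build the training DataLoader, then alternate train/validation passes and
    # checkpointing per epoch. os, torch, DataLoader, train(), val(),
    # checkpoint(), and the global opts namespace are assumed to be defined or
    # imported elsewhere in train.py.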
print("Starting run...", flush=True)
opts.best_epoch = 0
opts.best_loss = -1000
    if os.path.exists(opts.model_epoch_path) and opts.resume:
        # Restore model, optimizer, and bookkeeping state from the checkpoint.
        past_state = torch.load(opts.model_epoch_path)
        print(
            "Continuing epoch ... %d" % (past_state["opts"].continue_epoch + 1),
            flush=True,
        )

        model.load_state_dict(past_state["state_dict"])
        model.optimizer_D.load_state_dict(past_state["optimizerD"])
        model.optimizer_G.load_state_dict(past_state["optimizerG"])

        opts.continue_epoch = past_state["opts"].continue_epoch + 1
        opts.current_episode_train = past_state["opts"].current_episode_train
        opts.current_episode_val = past_state["opts"].current_episode_val
        opts.best_epoch = past_state["opts"].best_epoch
        opts.best_loss = past_state["opts"].best_loss
    elif opts.resume:
        print("WARNING: opts.resume is set but no checkpoint was found at:")
        print(opts.model_epoch_path)
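
    # Build the training DataLoader. Shuffling is left off here; the dataset is
    # assumed to manage its own episode ordering via totrain()/toval() below.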
print("Loading train dataset ....", flush=True)
train_set = Dataset("train", opts)
train_data_loader = DataLoader(
dataset=train_set,
num_workers=opts.num_workers,
batch_size=opts.batch_size,
shuffle=False,
drop_last=True,
pin_memory=True,
)
print("Loaded train dataset ...", flush=True)
    for epoch in range(opts.continue_epoch, opts.max_epoch):
        print("Starting epoch %d" % epoch, flush=True)
        opts.continue_epoch = epoch

        model.epoch = epoch
        model.train()
        train_loss = train(
            epoch, train_data_loader, model, log_path, plotter, opts
        )

        model.eval()
        with torch.no_grad():
            # Hack: switch the same dataset into validation mode so the
            # environments do not have to be reloaded.
            train_set.toval(epoch=0)
            loss = val(epoch, train_data_loader, model, log_path, plotter)
            train_set.totrain(epoch=epoch + 1 + opts.seed)
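
        # Plot per-epoch curves; metrics that have no validation counterpart
        # are logged with their train value only.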
        for l in train_loss.keys():
            if l in loss.keys():
                plotter.add_scalars(
                    "%s_epoch" % l,
                    {"train": train_loss[l], "val": loss[l]},
                    epoch,
                )
            else:
                plotter.add_scalars(
                    "%s_epoch" % l, {"train": train_loss[l]}, epoch
                )
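
        # Keep a separate checkpoint for the best validation PSNR seen so far,
        # and always refresh the rolling checkpoint used for resuming.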
if loss["psnr"] > opts.best_loss:
checkpoint(
model, opts.model_epoch_path + "best", CHECKPOINT_tempfile
)
opts.best_epoch = epoch
opts.best_loss = loss["psnr"]

        checkpoint(model, opts.model_epoch_path, CHECKPOINT_tempfile)
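
        # Additionally keep a numbered snapshot every 50 epochs.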
        if epoch % 50 == 0:
            checkpoint(
                model,
                opts.model_epoch_path + "ep%d" % epoch,
                CHECKPOINT_tempfile,
            )
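
        # At epoch 499, touch the HALT file; HALT_filename is assumed to be
        # defined at module level, presumably as a signal for the launcher
        # script to stop resubmitting the job.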
        if epoch == 500 - 1:
            open(HALT_filename, "a").close()