# train_mnist.py
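# Imports assumed by main() below. The standard-library and torch imports are
# unambiguous; the project-local helpers (Generator, Discriminator, weights_init,
# load_mnist, define_model_loss, ExtraAdam, compute_eigenvalues,
# compute_path_stats, plot_path_stats, plot_eigenvalues, mnist_inception_score)
# live elsewhere in this repo, so the module paths sketched in the commented
# lines are assumptions, not the repo's actual layout.
import json
import os
import pickle
import random
import time

import numpy as np
import torch
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter  # the repo may use tensorboardX instead

# from models import Generator, Discriminator, weights_init               # assumed path
# from utils import load_mnist, define_model_loss, mnist_inception_score  # assumed path
# from extragradient import ExtraAdam                                     # assumed path
# from eigen import compute_eigenvalues, compute_path_stats               # assumed path
# from plotting import plot_path_stats, plot_eigenvalues                  # assumed path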
def main(config):
    print("Hyper-params:")
    print(config)

    # create exp folder and save config
    exp_dir = os.path.join(config.exp_dir, config.exp_name)
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)
    plots_dir = os.path.join(exp_dir, 'extra_plots')
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)
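    # Seed every RNG in play (python, numpy, torch CPU and CUDA) so runs are
    # reproducible; a fresh seed is drawn and printed when none is given.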
    if config.manualSeed is None:
        config.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", config.manualSeed)
    random.seed(config.manualSeed)
    torch.manual_seed(config.manualSeed)
    np.random.seed(config.manualSeed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.manualSeed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device {0!s}".format(device))
    dataloader = load_mnist(config.batchSize)
    eval_dataloader = load_mnist(config.batchSize, subset=5000)
    eig_dataloader = load_mnist(1000, train=True, subset=1000)
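    # A fixed latent batch, reused at the end of every epoch so the sample
    # grids logged to TensorBoard are comparable across training.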
    fixed_noise = torch.randn(64, config.nz, 1, 1, device=device)
    # define the model
    netG = Generator(config.ngpu, config.nc, config.ngf, config.nz).to(device)
    netG.apply(weights_init)
    if config.netG != '':
        print('loading generator from %s' % config.netG)
        netG.load_state_dict(torch.load(config.netG)['state_gen'])
    print(netG)
    # sigmoid = config.model == 'dcgan'
    sigmoid = False
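    # With sigmoid=False the discriminator returns raw scores rather than
    # probabilities; presumably the DCGAN loss applies the sigmoid itself
    # (e.g. a logits-based BCE), while the WGAN losses use the raw critic
    # output directly.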
    netD = Discriminator(config.ngpu, config.nc, config.ndf, config.dnorm, sigmoid).to(device)
    netD.apply(weights_init)
    if config.netD != '':
        print('loading discriminator from %s' % config.netD)
        netD.load_state_dict(torch.load(config.netD)['state_dis'])
    print(netD)
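    # Separate "evaluation" copies of G and D: checkpoints are loaded into
    # these for the eigenvalue and path-statistics computations, so the
    # training networks and their optimizer state are never disturbed.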
    # evaluation copies of G and D
    evalG = Generator(config.ngpu, config.nc, config.ngf, config.nz).to(device)
    evalG.apply(weights_init)
    evalD = Discriminator(config.ngpu, config.nc, config.ndf, config.dnorm, sigmoid).to(device)
    evalD.apply(weights_init)
    # define the loss functions
    model_loss_dis, model_loss_gen = define_model_loss(config)
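    # define_model_loss returns one loss closure per player; judging from the
    # call sites in the training loop below, their signatures are
    #   model_loss_dis(x_real, x_fake, netD, device) -> (loss, D(x), D(G(z)))
    #   model_loss_gen(x_fake, netD, device)         -> (loss, D(G(z)))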
    # # defining learning rates based on the model
    # if config.model in ['wgan', 'wgan_gp']:
    #     config.lrG = config.lrD / config.n_critic
    #     warnings.warn('modifying learning rates to lrD=%f, lrG=%f' % (config.lrD, config.lrG))
    if config.lrG is None:
        config.lrG = config.lrD
    # set up the optimizers
    if config.optimizer == 'adam':
        optimizerD = optim.Adam(netD.parameters(), lr=config.lrD, betas=(config.beta1, config.beta2))
        optimizerG = optim.Adam(netG.parameters(), lr=config.lrG, betas=(config.beta1, config.beta2))
    elif config.optimizer == 'extraadam':
        optimizerD = ExtraAdam(netD.parameters(), lr=config.lrD)
        optimizerG = ExtraAdam(netG.parameters(), lr=config.lrG)
    elif config.optimizer == 'rmsprop':
        optimizerD = optim.RMSprop(netD.parameters(), lr=config.lrD)
        optimizerG = optim.RMSprop(netG.parameters(), lr=config.lrG)
    elif config.optimizer == 'sgd':
        optimizerD = optim.SGD(netD.parameters(), lr=config.lrD, momentum=config.beta1)
        optimizerG = optim.SGD(netG.parameters(), lr=config.lrG, momentum=config.beta1)
    else:
        raise ValueError('Optimizer %s not supported' % config.optimizer)
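    # ExtraAdam is an extragradient-style variant of Adam (Gidel et al., 2019):
    # each update is split into an extrapolation (lookahead) step and an actual
    # step, which is why the training loop below calls extrapolation() on even
    # iterations and step() on odd ones. Note it is constructed without betas
    # here, so it runs with its default momentum parameters.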
    with open(os.path.join(exp_dir, 'config.json'), 'w') as f:
        json.dump(vars(config), f, indent=4)

    summary_writer = SummaryWriter(log_dir=exp_dir)

    global_step = 0
    torch.save({'state_gen': netG.state_dict(),
                'state_dis': netD.state_dict()},
               '%s/checkpoint_step_%06d.pth' % (exp_dir, global_step))
    # helper: compute the eigenvalues at a given checkpoint and save them
    def comp_and_save_eigs(step, n_eigs=20):
        eig_checkpoint = torch.load('%s/checkpoint_step_%06d.pth' % (exp_dir, step),
                                    map_location=device)
        evalG.load_state_dict(eig_checkpoint['state_gen'])
        evalD.load_state_dict(eig_checkpoint['state_dis'])
        gen_eigs, dis_eigs, game_eigs = \
            compute_eigenvalues(evalG, evalD, eig_dataloader, config,
                                model_loss_gen, model_loss_dis,
                                device, verbose=True, n_eigs=n_eigs)
        np.savez(os.path.join(plots_dir, 'eigenvalues_%d' % step),
                 gen_eigs=gen_eigs, dis_eigs=dis_eigs, game_eigs=game_eigs)
        return gen_eigs, dis_eigs, game_eigs
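    # Presumably compute_eigenvalues estimates the top eigenvalues of each
    # player's Hessian and of the Jacobian of the game's gradient vector field
    # (e.g. via Hessian-vector products), which is why a small fixed
    # dataloader suffices.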
    if config.compute_eig:
        # eigenvalues at initialization
        gen_eigs_init, dis_eigs_init, game_eigs_init = comp_and_save_eigs(0)
    for epoch in range(config.niter):
        for i, data in enumerate(dataloader, 0):
            global_step += 1

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            # train on a real batch and a freshly generated fake batch
            netD.zero_grad()
            x_real = data[0].to(device)
            batch_size = x_real.size(0)
            noise = torch.randn(batch_size, config.nz, 1, 1, device=device)
            x_fake = netG(noise)
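            # x_fake is detached for the discriminator update so that no
            # gradient flows back into the generator here; the same x_fake
            # (undetached) is reused for the generator update below.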
            errD, D_x, D_G_z1 = model_loss_dis(x_real, x_fake.detach(), netD, device)
            # gradient penalty (WGAN-GP only)
            if config.model == 'wgan_gp':
                errD += config.gp_lambda * netD.get_penalty(x_real.detach(), x_fake.detach())
            errD.backward()
            D_x = D_x.mean().item()
            D_G_z1 = D_G_z1.mean().item()
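            # Extragradient scheduling: even iterations take an extrapolation
            # (lookahead) step, odd iterations take the real update from the
            # extrapolated point.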
            if config.optimizer == "extraadam":
                if i % 2 == 0:
                    optimizerD.extrapolation()
                else:
                    optimizerD.step()
            else:
                optimizerD.step()
            # weight clipping (WGAN): crude enforcement of the critic's Lipschitz constraint
            if config.model == 'wgan':
                for p in netD.parameters():
                    p.data.clamp_(-config.clip, config.clip)
            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ############################
            if config.model == 'dcgan' or (config.model in ['wgan', 'wgan_gp'] and i % config.n_critic == 0):
                netG.zero_grad()
                errG, D_G_z2 = model_loss_gen(x_fake, netD, device)
                errG.backward()
                D_G_z2 = D_G_z2.mean().item()
                if config.optimizer == "extraadam":
                    if i % 2 == 0:
                        optimizerG.extrapolation()
                    else:
                        optimizerG.step()
                else:
                    optimizerG.step()
            if global_step % config.printFreq == 0:
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                      % (epoch, config.niter, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                summary_writer.add_scalar("loss/D", errD.item(), global_step)
                summary_writer.add_scalar("loss/G", errG.item(), global_step)
                summary_writer.add_scalar("output/D_real", D_x, global_step)
                summary_writer.add_scalar("output/D_fake", D_G_z1, global_step)
        # at the end of every epoch, log a sample grid from the fixed noise
        fake = netG(fixed_noise)
        # vutils.save_image(fake.detach(),
        #                   '%s/fake_samples_step-%06d.png' % (exp_dir, global_step),
        #                   normalize=True)
        fake_grid = vutils.make_grid(fake.detach(), normalize=True)
        summary_writer.add_image("G_samples", fake_grid, global_step)
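        # Inception-score evaluation on 5,000 generated samples (10 chunks of
        # 500 to bound memory); mnist_inception_score presumably swaps the
        # Inception network for a classifier trained on MNIST.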
        # generate samples for IS evaluation
        IS_fake = []
        with torch.no_grad():  # no gradients are needed for scoring
            for _ in range(10):  # '_' avoids shadowing the batch index `i`
                noise = torch.randn(500, config.nz, 1, 1, device=device)
                IS_fake.append(netG(noise))
        IS_fake = torch.cat(IS_fake)
        IS_mean, IS_std = mnist_inception_score(IS_fake, device)
        print("IS score: mean=%.4f, std=%.4f" % (IS_mean, IS_std))
        summary_writer.add_scalar("IS_mean", IS_mean, global_step)
        # do checkpointing
        last_chkpt = '%s/checkpoint_step_%06d.pth' % (exp_dir, global_step)
        checkpoint = {'state_gen': netG.state_dict(),
                      'state_dis': netD.state_dict()}
        torch.save(checkpoint, last_chkpt)
        if epoch == 0:
            # last_chkpt = '%s/checkpoint_step_%06d.pth' % (exp_dir, 0)  # for now
            checkpoint_1 = torch.load(last_chkpt, map_location=device)
            if config.compute_eig:
                # compute eigenvalues after the first epoch, just in case
                gen_eigs_curr, dis_eigs_curr, game_eigs_curr = comp_and_save_eigs(global_step)
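        # Path statistics: once past 30k steps, every 5th epoch compares the
        # current checkpoint (checkpoint_2) against the end-of-first-epoch
        # checkpoint (checkpoint_1); presumably compute_path_stats evaluates
        # losses and gradients along the line segment between the two points
        # in parameter space.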
        # if (epoch + 1) % 10 == 0:
        if global_step > 30000 and epoch % 5 == 0:
            checkpoint_2 = torch.load(last_chkpt, map_location=device)
            print("Computing path statistics...")
            t = time.time()
            hist = compute_path_stats(evalG, evalD, checkpoint_1, checkpoint_2, eval_dataloader,
                                      config, model_loss_gen, model_loss_dis, device, verbose=True)
            with open("%s/hist_%d.pkl" % (plots_dir, global_step), 'wb') as f:
                pickle.dump(hist, f)
            plot_path_stats(hist, plots_dir, summary_writer, global_step)
            print("Took %.2f minutes" % ((time.time() - t) / 60.))
        if config.compute_eig and global_step > 30000 and epoch % 10 == 0:
            # compute eigenvalues, save them, and plot them against initialization
            gen_eigs_curr, dis_eigs_curr, game_eigs_curr = comp_and_save_eigs(global_step)
            plot_eigenvalues([gen_eigs_init, gen_eigs_curr], [dis_eigs_init, dis_eigs_curr],
                             [game_eigs_init, game_eigs_curr],
                             ['init', 'step_%d' % global_step], plots_dir, summary_writer,
                             step=global_step)
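# NOTE (assumption): the surrounding script presumably builds `config` with an
# argparse parser and invokes main(), along the lines of
#
#     if __name__ == '__main__':
#         main(parse_args())
#
# where parse_args is hypothetical and stands in for the repo's actual CLI setup.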