in example/mxnet_adversarial_vae/vaegan_mxnet.py [0:0]
def train(dataset, nef, ndf, ngf, nc, batch_size, Z, lr, beta1, epsilon, ctx, check_point, g_dl_weight, output_path, checkpoint_path, data_path, activation, num_epoch, save_after_every, visualize_after_every, show_after_every):
'''adversarial training of the VAE-GAN: builds the encoder, generator (decoder) and
discriminator modules and updates them jointly, batch by batch
'''
#encoder
z_mu, z_lv, z = encoder(nef, Z, batch_size)
symE = mx.sym.Group([z_mu, z_lv, z])
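# symE groups the encoder outputs: latent mean, log-variance and a reparameterised sample z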
#generator
symG = generator(ngf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12, z_dim=Z, activation=activation)
#discriminator
h = discriminator1(ndf)
dloss = discriminator2(ndf)
symD1 = h
symD2 = dloss
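# the discriminator is split so the layer loss can tap its intermediate features:
# D1 maps an image to a feature map, D2 maps that feature map to a real/fake score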
# ==============data==============
X_train, _ = get_data(data_path, activation)
train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size, shuffle=True)
rand_iter = RandIter(batch_size, Z)
label = mx.nd.zeros((batch_size,), ctx=ctx)
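# reusable label buffer: filled with 0 for fake/decoded batches and 1 for real batches during training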
# =============module E=============
modE = mx.mod.Module(symbol=symE, data_names=('data',), label_names=None, context=ctx)
modE.bind(data_shapes=train_iter.provide_data)
modE.init_params(initializer=mx.init.Normal(0.02))
modE.init_optimizer(
optimizer='adam',
optimizer_params={
'learning_rate': lr,
'wd': 1e-6,
'beta1': beta1,
'epsilon': epsilon,
'rescale_grad': (1.0/batch_size)
})
mods = [modE]
# =============module G=============
modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)
modG.bind(data_shapes=rand_iter.provide_data, inputs_need_grad=True)
modG.init_params(initializer=mx.init.Normal(0.02))
modG.init_optimizer(
optimizer='adam',
optimizer_params={
'learning_rate': lr,
'wd': 1e-6,
'beta1': beta1,
'epsilon': epsilon,
})
mods.append(modG)
# =============module D=============
modD1 = mx.mod.Module(symD1, label_names=[], context=ctx)
modD2 = mx.mod.Module(symD2, label_names=('label',), context=ctx)
modD = mx.mod.SequentialModule()
modD.add(modD1).add(modD2, take_labels=True, auto_wiring=True)
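# the SequentialModule chains D1 and D2: auto_wiring feeds D1's output into D2's data,
# and take_labels routes the real/fake label to D2's loss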
modD.bind(data_shapes=train_iter.provide_data,
label_shapes=[('label', (batch_size,))],
inputs_need_grad=True)
modD.init_params(initializer=mx.init.Normal(0.02))
modD.init_optimizer(
optimizer='adam',
optimizer_params={
'learning_rate': lr,
'wd': 1e-3,
'beta1': beta1,
'epsilon': epsilon,
'rescale_grad': (1.0/batch_size)
})
mods.append(modD)
# =============module DL=============
symDL = DiscriminatorLayerLoss()
modDL = mx.mod.Module(symbol=symDL, data_names=('data',), label_names=('label',), context=ctx)
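# modDL implements the VAE-GAN "learned similarity" term: a feature-matching (layer) loss
# between D1's feature maps for real images and for generated/decoded images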
modDL.bind(data_shapes=[('data', (batch_size, nef * 4, 4, 4))],  # NOTE: this shape must match the channel depth of discriminator1's output feature map
label_shapes=[('label', (batch_size, nef * 4, 4, 4))],
inputs_need_grad=True)
modDL.init_params(initializer=mx.init.Normal(0.02))
modDL.init_optimizer(
optimizer='adam',
optimizer_params={
'learning_rate': lr,
'wd': 0.,
'beta1': beta1,
'epsilon': epsilon,
'rescale_grad': (1.0/batch_size)
})
# =============module KL=============
symKL = KLDivergenceLoss()
modKL = mx.mod.Module(symbol=symKL, data_names=('data',), label_names=None, context=ctx)
modKL.bind(data_shapes=[('data', (batch_size*2,Z))],
inputs_need_grad=True)
modKL.init_params(initializer=mx.init.Normal(0.02))
modKL.init_optimizer(
optimizer='adam',
optimizer_params={
'learning_rate': lr,
'wd': 0.,
'beta1': beta1,
'epsilon': epsilon,
'rescale_grad': (1.0/batch_size)
})
mods.append(modKL)
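# modKL consumes the concatenated (mu, lv), hence the (2*batch_size, Z) shape, and measures
# how far q(z|x) = N(mu, exp(lv)) is from the unit-Gaussian prior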
def norm_stat(d):
return mx.nd.norm(d)/np.sqrt(d.size)
# set mon to a Monitor such as mx.mon.Monitor(10, norm_stat, pattern=".*output|d1_backward_data", sort=True)
# to print gradient/output norms every 10 batches while debugging
mon = None
if mon is not None:
for mod in mods:
mod.install_monitor(mon)
def facc(label, pred):
'''calculating prediction accuracy
'''
pred = pred.ravel()
label = label.ravel()
return ((pred > 0.5) == label).mean()
def fentropy(label, pred):
'''calculating binary cross-entropy loss
'''
pred = pred.ravel()
label = label.ravel()
return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()
def kldivergence(label, pred):
'''calculating the KL divergence between q(z|x) = N(mean, exp(log_var)) and the unit-Gaussian prior
'''
mean, log_var = np.split(pred, 2, axis=0)
var = np.exp(log_var)
KLLoss = -0.5 * np.sum(1 + log_var - np.power(mean, 2) - var)
KLLoss = KLLoss / batch_size  # normalise by the batch size
return KLLoss
mG = mx.metric.CustomMetric(fentropy)
mD = mx.metric.CustomMetric(fentropy)
mE = mx.metric.CustomMetric(kldivergence)
mACC = mx.metric.CustomMetric(facc)
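# mG and mD track the adversarial cross-entropy of the generator and discriminator passes,
# mACC the discriminator accuracy, and mE the KL term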
print('Training...')
stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
# =============train===============
for epoch in range(num_epoch):
train_iter.reset()
for t, batch in enumerate(train_iter):
rbatch = rand_iter.next()
if mon is not None:
mon.tic()
modG.forward(rbatch, is_train=True)
outG = modG.get_outputs()
# update discriminator on fake
label[:] = 0
modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
modD.backward()
gradD11 = [[grad.copyto(grad.context) for grad in grads] for grads in modD1._exec_group.grad_arrays]
gradD12 = [[grad.copyto(grad.context) for grad in grads] for grads in modD2._exec_group.grad_arrays]
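# the fake-batch gradients are copied out and applied later, together with the
# decoded- and real-batch gradients, in a single modD.update()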
modD.update_metric(mD, [label])
modD.update_metric(mACC, [label])
#update discriminator on decoded
modE.forward(batch, is_train=True)
mu, lv, z = modE.get_outputs()
z = z.reshape((batch_size, Z, 1, 1))
sample = mx.io.DataBatch([z], label=None, provide_data = [('rand', (batch_size, Z, 1, 1))])
modG.forward(sample, is_train=True)
xz = modG.get_outputs()
label[:] = 0
modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
modD.backward()
# the update is deferred: the decoded-batch gradients are saved below and accumulated
# with the fake- and real-batch gradients before the single modD.update()
gradD21 = [[grad.copyto(grad.context) for grad in grads] for grads in modD1._exec_group.grad_arrays]
gradD22 = [[grad.copyto(grad.context) for grad in grads] for grads in modD2._exec_group.grad_arrays]
modD.update_metric(mD, [label])
modD.update_metric(mACC, [label])
# update discriminator on real
label[:] = 1
batch.label = [label]
modD.forward(batch, is_train=True)
lx = [out.copyto(out.context) for out in modD1.get_outputs()]
modD.backward()
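# accumulate: real-batch gradients plus the average of the saved fake- and decoded-batch gradients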
for gradsr, gradsf, gradsd in zip(modD1._exec_group.grad_arrays, gradD11, gradD21):
for gradr, gradf, gradd in zip(gradsr, gradsf, gradsd):
gradr += 0.5 * (gradf + gradd)
for gradsr, gradsf, gradsd in zip(modD2._exec_group.grad_arrays, gradD12, gradD22):
for gradr, gradf, gradd in zip(gradsr, gradsf, gradsd):
gradr += 0.5 * (gradf + gradd)
modD.update()
modD.update_metric(mD, [label])
modD.update_metric(mACC, [label])
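# update the generator: adversarial loss on noise samples and on decoded samples, plus a
# feature-matching (layer) loss against D1's features of the real batch (lx)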
modG.forward(rbatch, is_train=True)
outG = modG.get_outputs()
label[:] = 1
modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
modD.backward()
diffD = modD1.get_input_grads()
modG.backward(diffD)
gradG1 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
mG.update([label], modD.get_outputs())
modG.forward(sample, is_train=True)
xz = modG.get_outputs()
label[:] = 1
modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
modD.backward()
diffD = modD1.get_input_grads()
modG.backward(diffD)
gradG2 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
mG.update([label], modD.get_outputs())
modG.forward(sample, is_train=True)
xz = modG.get_outputs()
modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
outD1 = modD1.get_outputs()
modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
modDL.backward()
dlGrad = modDL.get_input_grads()
modD1.backward(dlGrad)
diffD = modD1.get_input_grads()
modG.backward(diffD)
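# modG's gradient buffers now hold the layer-loss gradient; scale it by g_dl_weight and
# add the averaged adversarial gradients saved above before updating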
for grads, gradsG1, gradsG2 in zip(modG._exec_group.grad_arrays, gradG1, gradG2):
for grad, gradg1, gradg2 in zip(grads, gradsG1, gradsG2):
grad[:] = g_dl_weight * grad + 0.5 * (gradg1 + gradg2)  # in-place so the stored gradient array is actually modified
modG.update()
mG.update([label], modD.get_outputs())
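# second generator pass: the same three-part update is applied again, so the generator
# is updated twice per discriminator update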
modG.forward(rbatch, is_train=True)
outG = modG.get_outputs()
label[:] = 1
modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
modD.backward()
diffD = modD1.get_input_grads()
modG.backward(diffD)
gradG1 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
mG.update([label], modD.get_outputs())
modG.forward(sample, is_train=True)
xz = modG.get_outputs()
label[:] = 1
modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
modD.backward()
diffD = modD1.get_input_grads()
modG.backward(diffD)
gradG2 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
mG.update([label], modD.get_outputs())
modG.forward(sample, is_train=True)
xz = modG.get_outputs()
modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
outD1 = modD1.get_outputs()
modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
modDL.backward()
dlGrad = modDL.get_input_grads()
modD1.backward(dlGrad)
diffD = modD1.get_input_grads()
modG.backward(diffD)
for grads, gradsG1, gradsG2 in zip(modG._exec_group.grad_arrays, gradG1, gradG2):
for grad, gradg1, gradg2 in zip(grads, gradsG1, gradsG2):
grad[:] = g_dl_weight * grad + 0.5 * (gradg1 + gradg2)  # in-place, as in the first generator pass
modG.update()
mG.update([label], modD.get_outputs())
modG.forward(sample, is_train=True)
xz = modG.get_outputs()
# recompute the layer loss on decoded samples and backpropagate it through G;
# modG.get_input_grads() below then holds the gradient w.r.t. z, which drives the encoder update
modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
outD1 = modD1.get_outputs()
modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
DLloss = modDL.get_outputs()
modDL.backward()
dlGrad = modDL.get_input_grads()
modD1.backward(dlGrad)
diffD = modD1.get_input_grads()
modG.backward(diffD)
# update encoder: KL divergence against the unit-Gaussian prior plus the layer-loss gradient
# flowing back through the generator
modKL.forward(mx.io.DataBatch([mx.ndarray.concat(mu,lv, dim=0)]), is_train=True)
KLloss = modKL.get_outputs()
modKL.backward()
gradKLLoss = modKL.get_input_grads()
diffG = modG.get_input_grads()
diffG = diffG[0].reshape((batch_size, Z))
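# the encoder receives three head gradients: the KL gradients for mu and lv (split from the
# concatenated output) and the generator-path gradient for z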
modE.backward(mx.ndarray.split(gradKLLoss[0], num_outputs=2, axis=0) + [diffG])
modE.update()
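# kldivergence only reads its prediction argument, so the concatenated (mu, lv) is passed as both label and prediction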
pred = mx.ndarray.concat(mu,lv, dim=0)
mE.update([pred], [pred])
if mon is not None:
mon.toc_print()
t += 1
if t % show_after_every == 0:
print('epoch:', epoch, 'iter:', t, 'metric:', mACC.get(), mG.get(), mD.get(), mE.get(), KLloss[0].asnumpy(), DLloss[0].asnumpy())
mACC.reset()
mG.reset()
mD.reset()
mE.reset()
if epoch % visualize_after_every == 0:
visual(output_path +'gout'+str(epoch), outG[0].asnumpy(), activation)
visual(output_path + 'data'+str(epoch), batch.data[0].asnumpy(), activation)
if check_point and epoch % save_after_every == 0:
print('Saving...')
modG.save_params(checkpoint_path + '/%s_G-%04d.params'%(dataset, epoch))
modD.save_params(checkpoint_path + '/%s_D-%04d.params'%(dataset, epoch))
modE.save_params(checkpoint_path + '/%s_E-%04d.params'%(dataset, epoch))
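# A minimal usage sketch (illustrative values only; the actual entry point parses these from
# command-line arguments, and the dataset, hyperparameters and paths below are assumptions):
#
#     train(dataset='caltech', nef=64, ndf=64, ngf=64, nc=3, batch_size=64, Z=100,
#           lr=0.0002, beta1=0.5, epsilon=1e-8, ctx=mx.gpu(0), check_point=True,
#           g_dl_weight=1e-1, output_path='outputs/', checkpoint_path='checkpoints/',
#           data_path='datasets/caltech101/data', activation='sigmoid', num_epoch=45,
#           save_after_every=10, visualize_after_every=5, show_after_every=10)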