def train()

in example/mxnet_adversarial_vae/vaegan_mxnet.py


def train(dataset, nef, ndf, ngf, nc, batch_size, Z, lr, beta1, epsilon, ctx, check_point, g_dl_weight, output_path, checkpoint_path, data_path, activation, num_epoch, save_after_every, visualize_after_every, show_after_every):
    '''Adversarial training of the VAE/GAN: on every mini-batch the discriminator,
    the generator (decoder) and the encoder modules are updated in turn.
    '''

    #encoder
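    # the encoder network emits the posterior mean (z_mu), log-variance (z_lv) and a latent sample z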
    z_mu, z_lv, z = encoder(nef, Z, batch_size)
    symE = mx.sym.Group([z_mu, z_lv, z])

    #generator
    symG = generator(ngf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12, z_dim=Z, activation=activation)

    #discriminator
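    # the discriminator is split in two parts so that the intermediate features
    # produced by discriminator1 can also feed the layer-wise loss module (modDL) below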
    h = discriminator1(ndf)
    dloss = discriminator2(ndf)
    symD1 = h
    symD2 = dloss


    # ==============data==============
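    # real images, a noise iterator supplying the latent codes fed to the generator,
    # and a reusable label buffer holding the real/fake GAN targets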
    X_train, _ = get_data(data_path, activation)
    train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size, shuffle=True)
    rand_iter = RandIter(batch_size, Z)
    label = mx.nd.zeros((batch_size,), ctx=ctx)

    # =============module E=============
    modE = mx.mod.Module(symbol=symE, data_names=('data',), label_names=None, context=ctx)
    modE.bind(data_shapes=train_iter.provide_data)
    modE.init_params(initializer=mx.init.Normal(0.02))
    modE.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-6,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods = [modE]

    # =============module G=============
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)
    modG.bind(data_shapes=rand_iter.provide_data, inputs_need_grad=True)
    modG.init_params(initializer=mx.init.Normal(0.02))
    modG.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-6,
            'beta1': beta1,
            'epsilon': epsilon,
        })
    mods.append(modG)

    # =============module D=============
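    # the two discriminator halves are chained in a SequentialModule so modD acts as a
    # single discriminator while modD1 still exposes its intermediate feature maps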
    modD1 = mx.mod.Module(symD1, label_names=[], context=ctx)
    modD2 = mx.mod.Module(symD2, label_names=('label',), context=ctx)
    modD = mx.mod.SequentialModule()
    modD.add(modD1).add(modD2, take_labels=True, auto_wiring=True)
    modD.bind(data_shapes=train_iter.provide_data,
              label_shapes=[('label', (batch_size,))],
              inputs_need_grad=True)
    modD.init_params(initializer=mx.init.Normal(0.02))
    modD.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 1e-3,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods.append(modD)


    # =============module DL=============
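    # modDL measures the reconstruction error in discriminator feature space,
    # i.e. the layer loss used here as the VAE reconstruction term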
    symDL = DiscriminatorLayerLoss()
    modDL = mx.mod.Module(symbol=symDL, data_names=('data',), label_names=('label',), context=ctx)
    modDL.bind(data_shapes=[('data', (batch_size, nef * 4, 4, 4))],  # TODO: fix 512 here
               label_shapes=[('label', (batch_size, nef * 4, 4, 4))],
               inputs_need_grad=True)
    modDL.init_params(initializer=mx.init.Normal(0.02))
    modDL.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })

    # =============module KL=============
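    # modKL consumes the concatenated [mu; log-variance] and produces the KL-divergence
    # term of the VAE objective together with its gradient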
    symKL = KLDivergenceLoss()
    modKL = mx.mod.Module(symbol=symKL, data_names=('data',), label_names=None, context=ctx)
    modKL.bind(data_shapes=[('data', (batch_size*2,Z))],
               inputs_need_grad=True)
    modKL.init_params(initializer=mx.init.Normal(0.02))
    modKL.init_optimizer(
        optimizer='adam',
        optimizer_params={
            'learning_rate': lr,
            'wd': 0.,
            'beta1': beta1,
            'epsilon': epsilon,
            'rescale_grad': (1.0/batch_size)
        })
    mods.append(modKL)

    def norm_stat(d):
        return mx.nd.norm(d) / np.sqrt(d.size)

    # set mon to a Monitor instance (e.g. the commented line below) to track gradient norms
    # mon = mx.mon.Monitor(10, norm_stat, pattern=".*output|d1_backward_data", sort=True)
    mon = None
    if mon is not None:
        for mod in mods:
            mod.install_monitor(mon)

    def facc(label, pred):
        '''calculating prediction accuracy
        '''
        pred = pred.ravel()
        label = label.ravel()
        return ((pred > 0.5) == label).mean()

    def fentropy(label, pred):
        '''calculating binary cross-entropy loss
        '''
        pred = pred.ravel()
        label = label.ravel()
        return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()

    def kldivergence(label, pred):
        '''calculating KL divergence loss
        '''
        mean, log_var = np.split(pred, 2, axis=0)
        var = np.exp(log_var)
        KLLoss = -0.5 * np.sum(1 + log_var - np.power(mean, 2) - var)
        # nElements is assigned (to batch_size) inside the training loop before this metric is evaluated
        KLLoss = KLLoss / nElements
        return KLLoss

    mG = mx.metric.CustomMetric(fentropy)
    mD = mx.metric.CustomMetric(fentropy)
    mE = mx.metric.CustomMetric(kldivergence)
    mACC = mx.metric.CustomMetric(facc)

    print('Training...')
    stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')

    # =============train===============
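    # per mini-batch schedule: the discriminator is updated on generated, decoded and
    # real images, the generator is then updated twice (GAN loss plus layer loss), and
    # finally the encoder is updated with the KL gradient and the reconstruction gradient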
    for epoch in range(num_epoch):
        train_iter.reset()
        for t, batch in enumerate(train_iter):

            rbatch = rand_iter.next()

            if mon is not None:
                mon.tic()

            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()

            # update discriminator on fake
            label[:] = 0
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
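            # keep a copy of the gradients from this fake pass; they are averaged into
            # the gradients of the real pass before modD.update()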
            gradD11 = [[grad.copyto(grad.context) for grad in grads] for grads in modD1._exec_group.grad_arrays]
            gradD12 = [[grad.copyto(grad.context) for grad in grads] for grads in modD2._exec_group.grad_arrays]

            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])


            #update discriminator on decoded
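            # encode a real batch and decode it with the generator; the reconstruction
            # is treated as a second kind of fake example for the discriminator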
            modE.forward(batch, is_train=True)
            mu, lv, z = modE.get_outputs()
            z = z.reshape((batch_size, Z, 1, 1))
            sample = mx.io.DataBatch([z], label=None, provide_data = [('rand', (batch_size, Z, 1, 1))])
            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 0
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()

            #modD.update()
            gradD21 = [[grad.copyto(grad.context) for grad in grads] for grads in modD1._exec_group.grad_arrays]
            gradD22 = [[grad.copyto(grad.context) for grad in grads] for grads in modD2._exec_group.grad_arrays]
            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

            # update discriminator on real
            label[:] = 1
            batch.label = [label]
            modD.forward(batch, is_train=True)
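            # lx holds the intermediate discriminator features of the real images,
            # later used as the target of the layer loss (modDL)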
            lx = [out.copyto(out.context) for out in modD1.get_outputs()]
            modD.backward()
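            # fold the stored fake/decoded gradients into the real-pass gradients
            # so a single optimizer step sees all three terms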
            for gradsr, gradsf, gradsd in zip(modD1._exec_group.grad_arrays, gradD11, gradD21):
                for gradr, gradf, gradd in zip(gradsr, gradsf, gradsd):
                    gradr += 0.5 * (gradf + gradd)
            for gradsr, gradsf, gradsd in zip(modD2._exec_group.grad_arrays, gradD12, gradD22):
                for gradr, gradf, gradd in zip(gradsr, gradsf, gradsd):
                    gradr += 0.5 * (gradf + gradd)

            modD.update()
            modD.update_metric(mD, [label])
            modD.update_metric(mACC, [label])

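            # generator update: fool the discriminator both on samples drawn from the
            # noise iterator (rbatch) and on decodings of the current real batch (sample)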
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG1 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG2 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
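            # feature-matching pass: push the discriminator features of the decoded
            # images towards the real-image features stored in lx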
            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)

            # blend the layer-loss gradient with the two GAN gradients (in place) before updating G
            for grads, gradsG1, gradsG2 in zip(modG._exec_group.grad_arrays, gradG1, gradG2):
                for grad, gradg1, gradg2 in zip(grads, gradsG1, gradsG2):
                    grad[:] = g_dl_weight * grad + 0.5 * (gradg1 + gradg2)

            modG.update()
            mG.update([label], modD.get_outputs())

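            # the generator update above is carried out a second time, so G takes two
            # optimizer steps for every discriminator step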
            modG.forward(rbatch, is_train=True)
            outG = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG1 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            label[:] = 1
            modD.forward(mx.io.DataBatch(xz, [label]), is_train=True)
            modD.backward()
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            gradG2 = [[grad.copyto(grad.context) for grad in grads] for grads in modG._exec_group.grad_arrays]
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()
            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)

            # blend the layer-loss gradient with the two GAN gradients (in place) before updating G
            for grads, gradsG1, gradsG2 in zip(modG._exec_group.grad_arrays, gradG1, gradG2):
                for grad, gradg1, gradg2 in zip(grads, gradsG1, gradsG2):
                    grad[:] = g_dl_weight * grad + 0.5 * (gradg1 + gradg2)

            modG.update()
            mG.update([label], modD.get_outputs())

            modG.forward(sample, is_train=True)
            xz = modG.get_outputs()

            # propagate the layer loss through D1 and G so that the gradient w.r.t. z
            # is available for the encoder update below
            modD1.forward(mx.io.DataBatch(xz, []), is_train=True)
            outD1 = modD1.get_outputs()
            modDL.forward(mx.io.DataBatch(outD1, lx), is_train=True)
            DLloss = modDL.get_outputs()
            modDL.backward()
            dlGrad = modDL.get_input_grads()
            modD1.backward(dlGrad)
            diffD = modD1.get_input_grads()
            modG.backward(diffD)
            #update encoder
            nElements = batch_size
            modKL.forward(mx.io.DataBatch([mx.ndarray.concat(mu,lv, dim=0)]), is_train=True)
            KLloss = modKL.get_outputs()
            modKL.backward()
            gradKLLoss = modKL.get_input_grads()
            diffG = modG.get_input_grads()
            diffG = diffG[0].reshape((batch_size, Z))
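            # encoder gradients: KL gradients w.r.t. (mu, log-variance) plus the
            # reconstruction gradient that flowed back through G with respect to z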
            modE.backward(mx.ndarray.split(gradKLLoss[0], num_outputs=2, axis=0) + [diffG])
            modE.update()
            pred = mx.ndarray.concat(mu,lv, dim=0)
            mE.update([pred], [pred])
            if mon is not None:
                mon.toc_print()

            t += 1
            if t % show_after_every == 0:
                print('epoch:', epoch, 'iter:', t, 'metric:', mACC.get(), mG.get(), mD.get(), mE.get(), KLloss[0].asnumpy(), DLloss[0].asnumpy())
                mACC.reset()
                mG.reset()
                mD.reset()
                mE.reset()

            if epoch % visualize_after_every == 0:
                visual(output_path +'gout'+str(epoch), outG[0].asnumpy(), activation)
                visual(output_path + 'data'+str(epoch), batch.data[0].asnumpy(), activation)

        if check_point and epoch % save_after_every == 0:
            print('Saving...')
            modG.save_params(checkpoint_path + '/%s_G-%04d.params'%(dataset, epoch))
            modD.save_params(checkpoint_path + '/%s_D-%04d.params'%(dataset, epoch))
            modE.save_params(checkpoint_path + '/%s_E-%04d.params'%(dataset, epoch))