def train()

in example/reinforcement-learning/a3c/a3c.py

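A3C training loop for the Atari Breakout example: it sets up a kvstore, a Gym-backed data iterator, and an MXNet Module with policy and value heads, then runs n-step advantage actor-critic updates with manually accumulated gradients.

The body assumes the module-level setup of a3c.py. A minimal sketch of those assumptions (names inferred from the code below, not verified against the file header):

import time
import logging

import numpy as np
import mxnet as mx

import rl_data  # Gym-based data iterator used by this example
import sym      # network definitions; provides get_symbol_atari

# `args` is the script's parsed argparse namespace (batch_size, t_max, gamma,
# beta, lr, wd, num_epochs, ...); log_config is a helper defined in the script.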

def train():
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix

    log_config(args.log_dir, args.log_file, save_model_prefix, kv.rank)

    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    epoch_size = args.num_examples // args.batch_size

    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers

    # disable kvstore for single device
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    # module
    dataiter = rl_data.GymDataIter('Breakout-v0', args.batch_size, args.input_length, web_viz=True)
    net = sym.get_symbol_atari(dataiter.act_dim)
    module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data],
                           label_names=('policy_label', 'value_label'), context=devs)
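    # grad_req='add' makes backward accumulate gradients instead of
    # overwriting them; they are zeroed manually at the start of each update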
    module.bind(data_shapes=dataiter.provide_data,
                label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))],
                grad_req='add')

    # load model

    if args.load_epoch is not None:
        assert model_prefix is not None
        _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch)
    else:
        arg_params = aux_params = None

    # save model
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

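    # tiny uniform init for the policy/value output layers, Xavier elsewhere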
    init = mx.init.Mixed(['fc_value_weight|fc_policy_weight', '.*'],
                         [mx.init.Uniform(0.001), mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)])
    module.init_params(initializer=init,
                       arg_params=arg_params, aux_params=aux_params)

    # optimizer
    module.init_optimizer(kvstore=kv, optimizer='adam',
                          optimizer_params={'learning_rate': args.lr, 'wd': args.wd, 'epsilon': 1e-3})

    # logging
    np.set_printoptions(precision=3, suppress=True)

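    # T counts finished episodes across all environments; score holds each
    # environment's running episode reward, final_score its last completed one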
    T = 0
    dataiter.reset()
    score = np.zeros((args.batch_size, 1))
    final_score = np.zeros((args.batch_size, 1))
    for epoch in range(args.num_epochs):
        if save_model_prefix:
            module.save_params('%s-%04d.params' % (save_model_prefix, epoch))

        for _ in range(epoch_size // args.t_max):
            tic = time.time()
            # clear the gradients accumulated by grad_req='add'
            for exe in module._exec_group.grad_arrays:
                for g in exe:
                    g[:] = 0

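            # rollout buffers: states, actions, value estimates, rewards, done flags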
            S, A, V, r, D = [], [], [], [], []
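            # roll out t_max environment steps; the extra iteration only
            # records the value estimate used to bootstrap the return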
            for t in range(args.t_max + 1):
                data = dataiter.data()
                module.forward(mx.io.DataBatch(data=data, label=None), is_train=False)
                act, _, val = module.get_outputs()
                V.append(val.asnumpy())
                if t < args.t_max:
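                    # sample one action per environment from the policy distribution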
                    act = act.asnumpy()
                    act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])]
                    reward, done = dataiter.act(act)
                    S.append(data)
                    A.append(act)
                    r.append(reward.reshape((-1, 1)))
                    D.append(done.reshape((-1, 1)))

            err = 0
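            # bootstrap the return from the value estimate at step t_max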
            R = V[args.t_max]
            for i in reversed(range(args.t_max)):
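                # n-step return: R_t = r_t + gamma * (1 - done_t) * R_{t+1}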
                R = r[i] + args.gamma * (1 - D[i]) * R
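                # advantage, tiled across actions to match the policy output shape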
                adv = np.tile(R - V[i], (1, dataiter.act_dim))

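                # training-mode forward with sampled actions and n-step returns as labels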
                batch = mx.io.DataBatch(data=S[i], label=[mx.nd.array(A[i]), mx.nd.array(R)])
                module.forward(batch, is_train=True)

                pi = module.get_outputs()[1]
                # entropy regularization term, fed to backward as the head
                # gradient for the policy output
                h = -args.beta * (mx.nd.log(pi + 1e-7) * pi)
                # policy-gradient signal built from the maximum action
                # probability of each sample, tiled across the action dimension
                out_acts = np.reshape(np.amax(pi.asnumpy(), 1), (-1, 1))
                out_acts_tile = np.tile(-np.log(out_acts + 1e-7), (1, dataiter.act_dim))
                # inject the custom policy and entropy gradients through backward
                module.backward([mx.nd.array(out_acts_tile * adv), h])

                print('pi', pi[0].asnumpy())
                print('h', h[0].asnumpy())
                err += (adv**2).mean()
                # episode bookkeeping: accumulate reward, record the score of
                # episodes that just finished, then reset those environments
                score += r[i]
                final_score *= 1 - D[i]
                final_score += score * D[i]
                score *= 1 - D[i]
                T += D[i].sum()

            module.update()
            logging.info('fps: %f err: %f score: %f final: %f T: %f' % (
                args.batch_size / (time.time() - tic), err / args.t_max,
                score.mean(), final_score.mean(), T))
            print(score.squeeze())
            print(final_score.squeeze())
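For context, a3c.py invokes this function from its entry point after parsing the command-line flags. A minimal, hypothetical sketch of that wiring (the actual flag definitions live in the script's argparse block; `parser` here is assumed, not quoted):

if __name__ == '__main__':
    args = parser.parse_args()  # hypothetical: `parser` defined at module level
    train()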