in example/reinforcement-learning/a3c/a3c.py [0:0]
def train():
    """Train an actor-critic (A3C-style) agent on the Gym 'Breakout-v0' env.

    All configuration is read from the module-level ``args`` namespace
    (kvstore type, GPUs, batch size, ``t_max``, ``gamma``, ``beta``, learning
    rate / weight decay, epoch counts, model/log prefixes).

    Each update:
      1. zeroes the accumulated gradients,
      2. rolls out ``args.t_max`` steps in ``args.batch_size`` parallel
         environments (plus one extra forward pass for the bootstrap value),
      3. walks the rollout backwards computing n-step discounted returns,
         running forward/backward per step so gradients accumulate
         (``grad_req='add'``),
      4. applies a single optimizer step via ``module.update()``.

    Parameters are saved at the start of every epoch when a save prefix is
    available.
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    log_config(args.log_dir, args.log_file, save_model_prefix, kv.rank)
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]
    # Floor division: epoch_size feeds range() below, which rejects the float
    # that true division ("/") produces under Python 3.
    epoch_size = args.num_examples // args.batch_size
    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers
    # disable kvstore for single device
    # ("==", not "is": identity comparison with an int literal is unreliable)
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    # module
    dataiter = rl_data.GymDataIter('Breakout-v0', args.batch_size, args.input_length, web_viz=True)
    net = sym.get_symbol_atari(dataiter.act_dim)
    module = mx.mod.Module(net, data_names=[d[0] for d in dataiter.provide_data], label_names=('policy_label', 'value_label'), context=devs)
    # grad_req='add' makes gradients accumulate over every backward() call in
    # a rollout; they are zeroed manually before each update below.
    module.bind(data_shapes=dataiter.provide_data,
                label_shapes=[('policy_label', (args.batch_size,)), ('value_label', (args.batch_size, 1))],
                grad_req='add')

    # load model
    if args.load_epoch is not None:
        assert model_prefix is not None
        _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch)
    else:
        arg_params = aux_params = None

    # Small uniform init for the two weights matched by the first pattern,
    # Xavier for everything else.
    init = mx.init.Mixed(['fc_value_weight|fc_policy_weight', '.*'],
                         [mx.init.Uniform(0.001), mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)])
    module.init_params(initializer=init,
                       arg_params=arg_params, aux_params=aux_params)

    # optimizer
    module.init_optimizer(kvstore=kv, optimizer='adam',
                          optimizer_params={'learning_rate': args.lr, 'wd': args.wd, 'epsilon': 1e-3})

    # logging
    np.set_printoptions(precision=3, suppress=True)
    T = 0  # running count of finished episodes (sum of done flags)
    dataiter.reset()
    score = np.zeros((args.batch_size, 1))        # return of the episode in progress, per env
    final_score = np.zeros((args.batch_size, 1))  # return of the last finished episode, per env
    for epoch in range(args.num_epochs):
        if save_model_prefix:
            module.save_params('%s-%04d.params' % (save_model_prefix, epoch))
        for _ in range(epoch_size // args.t_max):
            tic = time.time()
            # clear gradients (they accumulate because of grad_req='add')
            for exe in module._exec_group.grad_arrays:
                for g in exe:
                    g[:] = 0

            # Rollout: t_max acting steps plus one extra forward pass whose
            # value output V[t_max] bootstraps the n-step return.
            S, A, V, r, D = [], [], [], [], []
            for t in range(args.t_max + 1):
                data = dataiter.data()
                module.forward(mx.io.DataBatch(data=data, label=None), is_train=False)
                act, _, val = module.get_outputs()
                V.append(val.asnumpy())
                if t < args.t_max:
                    # Each row of the first output is used as a sampling
                    # distribution over dataiter.act_dim actions.
                    act = act.asnumpy()
                    act = [np.random.choice(dataiter.act_dim, p=act[i]) for i in range(act.shape[0])]
                    reward, done = dataiter.act(act)
                    S.append(data)
                    A.append(act)
                    r.append(reward.reshape((-1, 1)))
                    D.append(done.reshape((-1, 1)))

            # Backward pass over the rollout, newest step first, so R holds
            # the discounted return; (1 - D[i]) cuts it at episode ends.
            err = 0
            R = V[args.t_max]
            for i in reversed(range(args.t_max)):
                R = r[i] + args.gamma * (1 - D[i]) * R
                adv = np.tile(R - V[i], (1, dataiter.act_dim))
                batch = mx.io.DataBatch(data=S[i], label=[mx.nd.array(A[i]), mx.nd.array(R)])
                module.forward(batch, is_train=True)
                pi = module.get_outputs()[1]
                # Entropy-style term scaled by beta, fed as the head gradient
                # of the second output.
                h = -args.beta*(mx.nd.log(pi+1e-7)*pi)
                # NOTE(review): uses the max probability per row rather than
                # the probability of the sampled action A[i]; confirm against
                # the loss defined in sym.get_symbol_atari.
                out_acts = np.amax(pi.asnumpy(), 1)
                out_acts = np.reshape(out_acts, (-1, 1))
                out_acts_tile = np.tile(-np.log(out_acts + 1e-7), (1, dataiter.act_dim))
                module.backward([mx.nd.array(out_acts_tile*adv), h])
                print('pi', pi[0].asnumpy())
                print('h', h[0].asnumpy())
                err += (adv**2).mean()
                # Episode bookkeeping: accumulate reward; on done, record the
                # finished episode's score and reset the running score.
                score += r[i]
                final_score *= (1-D[i])
                final_score += score * D[i]
                score *= 1-D[i]
                T += D[i].sum()

            module.update()
            logging.info('fps: %f err: %f score: %f final: %f T: %f',
                         args.batch_size/(time.time()-tic), err/args.t_max,
                         score.mean(), final_score.mean(), T)
            print(score.squeeze())
            print(final_score.squeeze())