# baselines/acktr/acktr.py
import os.path as osp
import time
from collections import deque

import tensorflow as tf

from baselines import logger
from baselines.common import set_global_seeds, explained_variance
from baselines.common.policies import build_policy
from baselines.a2c.runner import Runner
from baselines.ppo2.ppo2 import safemean

# `Model` (the policy/value networks plus the K-FAC optimizer) is defined
# earlier in this module.


def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=100, nprocs=32, nsteps=20,
          ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
          kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):
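    """
    Train a policy with ACKTR (actor-critic with K-FAC-preconditioned updates).

    network          policy network type (e.g. 'cnn', 'mlp'), passed to build_policy
    env              vectorized environment (VecEnv) to train on
    seed             RNG seed
    total_timesteps  total number of environment steps to train for
    gamma            discount factor for returns
    log_interval     number of updates between log dumps
    nprocs           passed through to Model (K-FAC parallelism)
    nsteps           steps per environment per update; batch size is nenvs * nsteps
    ent_coef         entropy bonus coefficient
    vf_coef          value-function loss coefficient
    vf_fisher_coef   coefficient of the value-function Fisher loss
    lr               learning rate
    max_grad_norm    gradient clipping threshold
    kfac_clip        K-FAC trust-region (KL) clip
    save_interval    save a checkpoint every save_interval updates (needs a logger dir)
    lrschedule       learning-rate schedule (e.g. 'linear', 'constant')
    load_path        path to load model weights from before training
    is_async         update K-FAC statistics asynchronously in background threads
    """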
    set_global_seeds(seed)

    if network == 'cnn':
        # Match the reference ACKTR CNN: reshape conv biases to one dimension
        # (per channel), the shape the K-FAC Fisher blocks expect.
        network_kwargs['one_dim_bias'] = True
    policy = build_policy(env, network, **network_kwargs)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    make_model = lambda: Model(policy, ob_space, ac_space, nenvs, total_timesteps, nprocs=nprocs,
                               nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                               vf_fisher_coef=vf_fisher_coef, lr=lr, max_grad_norm=max_grad_norm,
                               kfac_clip=kfac_clip, lrschedule=lrschedule, is_async=is_async)
    if save_interval and logger.get_dir():
        import cloudpickle
        # Serialize the model factory next to the checkpoints so a saved model
        # can later be rebuilt without re-specifying its hyperparameters.
        with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
            fh.write(cloudpickle.dumps(make_model))
    model = make_model()

    if load_path is not None:
        model.load(load_path)
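
    # The A2C Runner steps all nenvs environments nsteps at a time and returns
    # flattened batches; its `rewards` output holds n-step discounted returns.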
    runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
    epinfobuf = deque(maxlen=100)  # rolling stats over the last 100 episodes
    nbatch = nenvs * nsteps
    tstart = time.time()
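
    # With is_async=True, K-FAC maintains its Fisher-statistics running averages
    # in background threads fed by a TF queue runner; the Coordinator manages
    # those threads so they can be stopped cleanly after training.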
    coord = tf.train.Coordinator()
    if is_async:
        enqueue_threads = model.q_runner.create_threads(model.sess, coord=coord, start=True)
    else:
        enqueue_threads = []
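
    # Main loop: each update consumes nbatch = nenvs * nsteps transitions and
    # applies one K-FAC-preconditioned gradient step.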
    for update in range(1, total_timesteps // nbatch + 1):
        obs, states, rewards, masks, actions, values, epinfos = runner.run()
        epinfobuf.extend(epinfos)
        policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
        model.old_obs = obs
        nseconds = time.time() - tstart
        fps = int((update * nbatch) / nseconds)

        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("policy_loss", float(policy_loss))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.dump_tabular()

        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
            savepath = osp.join(logger.get_dir(), 'checkpoint%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)

    coord.request_stop()
    coord.join(enqueue_threads)
    return model
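
# Minimal usage sketch (illustration only, not part of this module). It assumes
# the standard baselines Atari helpers; check the make_atari_env / VecFrameStack
# signatures against your baselines version:
#
#     from baselines.common.cmd_util import make_atari_env
#     from baselines.common.vec_env.vec_frame_stack import VecFrameStack
#
#     env = VecFrameStack(make_atari_env('BreakoutNoFrameskip-v4', num_env=8, seed=0), nstack=4)
#     model = learn(network='cnn', env=env, seed=0, total_timesteps=int(1e6))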