in train.py [0:0]
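# Standard-library and third-party imports used by train() below (a sketch of
# what this excerpt assumes; the repo-local names -- algo, utils, FileWriter,
# HumanOutputFormat, make_lr_venv, model_for_env_name, RolloutStorage,
# evaluate -- come from elsewhere in this codebase).
import logging
import os
import sys
import time
import timeit
from collections import deque

import numpy as np
import torch

# Module-level state assumed by the `global last_checkpoint_time` declaration
# inside train(): checkpoint timing persists across calls.
last_checkpoint_time = None

# train() runs the PPO training loop with an optional level sampler that
# prioritizes which level seeds the actors replay.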
def train(args, seeds):
    global last_checkpoint_time
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    if 'cuda' in device.type:
        print('Using CUDA\n')
    torch.set_num_threads(1)

    utils.seed(args.seed)
    # Configure logging
    if args.xpid is None:
        args.xpid = "lr-%s" % time.strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.expandvars(os.path.expanduser(args.log_dir))
    plogger = FileWriter(
        xpid=args.xpid, xp_args=args.__dict__, rootdir=log_dir,
        seeds=seeds,
    )
    stdout_logger = HumanOutputFormat(sys.stdout)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid, "model.tar"))
    )
    # Configure actor envs
    start_level = 0
    if args.full_train_distribution:
        num_levels = 0
        level_sampler_args = None
        seeds = None
    else:
        num_levels = 1
        level_sampler_args = dict(
            num_actors=args.num_processes,
            strategy=args.level_replay_strategy,
            replay_schedule=args.level_replay_schedule,
            score_transform=args.level_replay_score_transform,
            temperature=args.level_replay_temperature,
            eps=args.level_replay_eps,
            rho=args.level_replay_rho,
            nu=args.level_replay_nu,
            alpha=args.level_replay_alpha,
            staleness_coef=args.staleness_coef,
            staleness_transform=args.staleness_transform,
            staleness_temperature=args.staleness_temperature,
        )
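    # make_lr_venv is expected to return the vectorized envs plus a level
    # sampler (None when level_sampler_args is None, i.e. when training on
    # the full level distribution).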
    envs, level_sampler = make_lr_venv(
        num_envs=args.num_processes, env_name=args.env_name,
        seeds=seeds, device=device,
        num_levels=num_levels, start_level=start_level,
        no_ret_normalization=args.no_ret_normalization,
        distribution_mode=args.distribution_mode,
        paint_vel_info=args.paint_vel_info,
        level_sampler_args=level_sampler_args)

    is_minigrid = args.env_name.startswith('MiniGrid')

    actor_critic = model_for_env_name(args, envs)
    actor_critic.to(device)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    batch_size = int(args.num_processes * args.num_steps / args.num_mini_batch)
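    # batch_size is the PPO minibatch size implied by the rollout dimensions;
    # it is not referenced again in this excerpt.

    # checkpoint() closes over `agent`, which is defined just below; the name
    # is resolved at call time, so the forward reference is safe.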
    def checkpoint():
        if args.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": actor_critic.state_dict(),
                "optimizer_state_dict": agent.optimizer.state_dict(),
                "args": vars(args),
            },
            checkpointpath,
        )
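    # PPO updater (clipped surrogate objective) wrapping the actor-critic.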
    agent = algo.PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm,
        env_name=args.env_name)
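    # With a level sampler, reset() also returns the seed of the level each of
    # the num_processes actors starts on; the unsqueeze gives level_seeds
    # shape (num_processes, 1) so seeds can be stored per step in the rollouts.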
    level_seeds = torch.zeros(args.num_processes)
    if level_sampler:
        obs, level_seeds = envs.reset()
    else:
        obs = envs.reset()
    level_seeds = level_seeds.unsqueeze(-1)

    rollouts.obs[0].copy_(obs)
    rollouts.to(device)
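    # Rolling window of the most recent 10 training episode returns.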
    episode_rewards = deque(maxlen=10)
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    timer = timeit.default_timer
    update_start_time = timer()
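    # Main loop: each update collects num_steps transitions from each of the
    # num_processes actors, then runs PPO on that batch.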
    for j in range(num_updates):
        actor_critic.train()
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                obs_id = rollouts.obs[step]
                value, action, action_log_dist, recurrent_hidden_states = actor_critic.act(
                    obs_id, rollouts.recurrent_hidden_states[step], rollouts.masks[step])
                action_log_prob = action_log_dist.gather(-1, action)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            # On episode end the env auto-resets (sampling the next level from
            # the level sampler); record episode stats and the new level seed.
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                if level_sampler:
                    level_seeds[i][0] = info['level_seed']

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])

            rollouts.insert(
                obs, recurrent_hidden_states,
                action, action_log_prob, action_log_dist,
                value, reward, masks, bad_masks, level_seeds)
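        # Bootstrap the value of the last observation, then compute GAE returns.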
        with torch.no_grad():
            obs_id = rollouts.obs[-1]
            next_value = actor_critic.get_value(
                obs_id, rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.gamma, args.gae_lambda)
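        # Re-score the levels seen in this rollout for replay priority (the
        # scoring rule depends on args.level_replay_strategy).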
        # Update level sampler
        if level_sampler:
            level_sampler.update_with_rollouts(rollouts)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()
        if level_sampler:
            level_sampler.after_update()
        # Log stats every log_interval updates or if it is the last update
        if (j % args.log_interval == 0 and len(episode_rewards) > 1) or j == num_updates - 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps

            update_end_time = timer()
            num_interval_updates = 1 if j == 0 else args.log_interval
            sps = num_interval_updates * (args.num_processes * args.num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time

            logging.info(f"\nUpdate {j} done, {total_num_steps} steps\n ")
            logging.info(f"\nEvaluating on {args.num_test_seeds} test levels...\n ")
            eval_episode_rewards = evaluate(args, actor_critic, args.num_test_seeds, device)

            logging.info(f"\nEvaluating on {args.num_test_seeds} train levels...\n ")
            train_eval_episode_rewards = evaluate(
                args, actor_critic, args.num_test_seeds, device,
                start_level=0, num_levels=args.num_train_seeds, seeds=seeds)
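            # "train" stats come from the actors' own episodes; "test" and
            # "train_eval" come from separate evaluation rollouts on held-out
            # and training seeds, respectively.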
            stats = {
                "step": total_num_steps,
                "pg_loss": action_loss,
                "value_loss": value_loss,
                "dist_entropy": dist_entropy,
                "train:mean_episode_return": np.mean(episode_rewards),
                "train:median_episode_return": np.median(episode_rewards),
                "test:mean_episode_return": np.mean(eval_episode_rewards),
                "test:median_episode_return": np.median(eval_episode_rewards),
                "train_eval:mean_episode_return": np.mean(train_eval_episode_rewards),
                "train_eval:median_episode_return": np.median(train_eval_episode_rewards),
                "sps": sps,
            }
            if is_minigrid:
                stats["train:success_rate"] = np.mean(np.array(episode_rewards) > 0)
                stats["train_eval:success_rate"] = np.mean(np.array(train_eval_episode_rewards) > 0)
                stats["test:success_rate"] = np.mean(np.array(eval_episode_rewards) > 0)
            if j == num_updates - 1:
                logging.info(f"\nLast update: Evaluating on {args.final_num_test_seeds} test levels...\n ")
                final_eval_episode_rewards = evaluate(args, actor_critic, args.final_num_test_seeds, device)

                mean_final_eval_episode_rewards = np.mean(final_eval_episode_rewards)
                median_final_eval_episode_rewards = np.median(final_eval_episode_rewards)

                plogger.log_final_test_eval({
                    'num_test_seeds': args.final_num_test_seeds,
                    'mean_episode_return': mean_final_eval_episode_rewards,
                    'median_episode_return': median_final_eval_episode_rewards
                })

            plogger.log(stats)
            if args.verbose:
                stdout_logger.writekvs(stats)
        # Log level weights
        if level_sampler and j % args.weight_log_interval == 0:
            plogger.log_level_weights(level_sampler.sample_weights())
        # Checkpoint every args.save_interval minutes (and on the final update).
        timer = timeit.default_timer
        if last_checkpoint_time is None:
            last_checkpoint_time = timer()
        try:
            if j == num_updates - 1 or \
                (args.save_interval > 0 and timer() - last_checkpoint_time > args.save_interval * 60):
                checkpoint()
                last_checkpoint_time = timer()
        except KeyboardInterrupt:
            return