src/algos/curiosity.py
import threading

import numpy as np
import torch
from torch import nn

# Project-local helpers; import paths assumed from the repo layout.
from src import losses
from src.core import vtrace


def learn(actor_model,
          model,
          state_embedding_model,
          forward_dynamics_model,
          inverse_dynamics_model,
          batch,
          initial_agent_state,
          optimizer,
          state_embedding_optimizer,
          forward_dynamics_optimizer,
          inverse_dynamics_optimizer,
          scheduler,
          flags,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
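        # Embed consecutive observations (o_t, o_{t+1}) so the curiosity module
        # can measure how surprising each transition is.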
        if flags.use_fullobs_intrinsic:
            state_emb = state_embedding_model(batch, next_state=False)\
                .reshape(flags.unroll_length, flags.batch_size, 128)
            next_state_emb = state_embedding_model(batch, next_state=True)\
                .reshape(flags.unroll_length, flags.batch_size, 128)
        else:
            state_emb = state_embedding_model(batch['partial_obs'][:-1].to(device=flags.device))
            next_state_emb = state_embedding_model(batch['partial_obs'][1:].to(device=flags.device))
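        # ICM heads: the forward model predicts the next embedding from the current
        # embedding and the action; the inverse model predicts the action from the
        # two embeddings. The entropy of the predicted actions is computed here but
        # not added to the total loss below.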
        pred_next_state_emb = forward_dynamics_model(
            state_emb, batch['action'][1:].to(device=flags.device))
        pred_actions = inverse_dynamics_model(state_emb, next_state_emb)
        entropy_emb_actions = losses.compute_entropy_loss(pred_actions)
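        # Intrinsic reward: L2 error of the forward model's prediction of the next
        # embedding, scaled by the intrinsic reward coefficient.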
        intrinsic_rewards = torch.norm(pred_next_state_emb - next_state_emb, dim=2, p=2)
        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        forward_dynamics_loss = flags.forward_loss_coef * \
            losses.compute_forward_dynamics_loss(pred_next_state_emb, next_state_emb)
        inverse_dynamics_loss = flags.inverse_loss_coef * \
            losses.compute_inverse_dynamics_loss(pred_actions, batch['action'][1:])
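        # Flattened per-step actions and intrinsic rewards (diagnostics; not used
        # in the losses below).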
        num_samples = flags.unroll_length * flags.batch_size
        actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy()
        intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy()
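        # Learner forward pass over the full unroll; the final baseline value is
        # used to bootstrap the V-trace targets.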
        learner_outputs, unused_state = model(batch, initial_agent_state)

        bootstrap_value = learner_outputs['baseline'][-1]
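        # Drop the first timestep from the batch and the last from the learner
        # outputs so both cover the same unroll_length transitions.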
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }
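        # Empirical action distribution over the batch (diagnostic only).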
        actions = batch['action'].reshape(flags.unroll_length * flags.batch_size).cpu().numpy()
        action_percentage = [0 for _ in range(model.num_actions)]
        for i in range(model.num_actions):
            action_percentage[i] = np.sum([a == i for a in actions]) / len(actions)
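        # Combine extrinsic and intrinsic rewards (intrinsic only when no_reward
        # is set), then clip to [-1, 1].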
        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)
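        # Zero the discount at episode boundaries, then compute V-trace targets and
        # advantages (IMPALA's off-policy correction between the actor's behavior
        # policy and the learner's target policy).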
        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)
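        # IMPALA losses: policy gradient with V-trace advantages, baseline
        # regression towards the V-trace targets, and an entropy bonus, plus the
        # curiosity (forward/inverse dynamics) losses.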
        pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
                                                      batch['action'],
                                                      vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss \
            + forward_dynamics_loss + inverse_dynamics_loss
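        # Per-episode statistics, gathered at terminal steps, for logging.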
        episode_returns = batch['episode_return'][batch['done']]
        episode_lengths = batch['episode_step'][batch['done']]
        episode_wins = batch['episode_win'][batch['done']]

        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'forward_dynamics_loss': forward_dynamics_loss.item(),
            'inverse_dynamics_loss': inverse_dynamics_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }
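        # Backpropagate the combined loss and update every module (policy/value
        # network, state embedding, forward and inverse dynamics models) with
        # gradient clipping.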
        optimizer.zero_grad()
        state_embedding_optimizer.zero_grad()
        forward_dynamics_optimizer.zero_grad()
        inverse_dynamics_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(state_embedding_model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(), flags.max_grad_norm)
        optimizer.step()
        state_embedding_optimizer.step()
        forward_dynamics_optimizer.step()
        inverse_dynamics_optimizer.step()
        # Step the LR scheduler after the optimizer updates (recommended ordering
        # since PyTorch 1.1).
        scheduler.step()
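        # Publish the updated learner weights to the actor model.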
        actor_model.load_state_dict(model.state_dict())
        return stats