in level_replay/level_sampler.py [0:0]
def _update_with_rollouts(self, rollouts, score_function):
    level_seeds = rollouts.level_seeds
    policy_logits = rollouts.action_log_dist
    done = ~(rollouts.masks > 0)
    total_steps, num_actors = policy_logits.shape[:2]
    num_decisions = len(policy_logits)

    for actor_index in range(num_actors):
        # Step indices at which an episode ended for this actor.
        done_steps = done[:, actor_index].nonzero()[:total_steps, 0]
        start_t = 0

        for t in done_steps:
            if not start_t < total_steps:
                break

            if t == 0:  # if t is 0, then this done step caused a full update of previous seed last cycle
                continue

            # Seed of the level played over the episode spanning [start_t, t).
            seed_t = level_seeds[start_t, actor_index].item()
            seed_idx_t = self.seed2index[seed_t]

            score_function_kwargs = {}
            episode_logits = policy_logits[start_t:t, actor_index]
            score_function_kwargs['episode_logits'] = torch.log_softmax(episode_logits, -1)

            if self.requires_value_buffers:
                score_function_kwargs['returns'] = rollouts.returns[start_t:t, actor_index]
                score_function_kwargs['rewards'] = rollouts.rewards[start_t:t, actor_index]
                score_function_kwargs['value_preds'] = rollouts.value_preds[start_t:t, actor_index]

            # Score the completed episode and fully update this seed's priority.
            score = score_function(**score_function_kwargs)
            num_steps = len(episode_logits)
            self.update_seed_score(actor_index, seed_idx_t, score, num_steps)

            start_t = t.item()

        # Any trailing, unfinished episode only contributes a partial score update.
        if start_t < total_steps:
            seed_t = level_seeds[start_t, actor_index].item()
            seed_idx_t = self.seed2index[seed_t]

            score_function_kwargs = {}
            episode_logits = policy_logits[start_t:, actor_index]
            score_function_kwargs['episode_logits'] = torch.log_softmax(episode_logits, -1)

            if self.requires_value_buffers:
                score_function_kwargs['returns'] = rollouts.returns[start_t:, actor_index]
                score_function_kwargs['rewards'] = rollouts.rewards[start_t:, actor_index]
                score_function_kwargs['value_preds'] = rollouts.value_preds[start_t:, actor_index]

            score = score_function(**score_function_kwargs)
            num_steps = len(episode_logits)
            self._partial_update_seed_score(actor_index, seed_idx_t, score, num_steps)
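
For context, here is a minimal sketch of a score function that could be passed in as score_function above. The kwargs mirror the ones assembled in _update_with_rollouts, but the name average_value_l1_score and the L1 value-loss formula are illustrative assumptions for this sketch, not necessarily the scoring function the repository uses.

def average_value_l1_score(**kwargs):
    # Hypothetical example score: mean absolute gap between returns and
    # value predictions over one episode, a value-loss-style priority signal.
    # Assumes 'returns' and 'value_preds' are present, i.e. the sampler was
    # configured with requires_value_buffers=True.
    returns = kwargs['returns']
    value_preds = kwargs['value_preds']
    advantages = returns - value_preds
    return advantages.abs().mean().item()

A score function of this shape is called once per completed episode segment (and once for the trailing partial segment), so it should reduce its per-step inputs to a single scalar priority for the seed.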