def _update_with_rollouts()

in level_replay/level_sampler.py

Walks each actor's rollout and splits it into episodes at done boundaries. Each completed episode slice is scored with the supplied score_function and applies a full score update to its level seed; a still-running episode at the end of the rollout receives only a partial update.

    def _update_with_rollouts(self, rollouts, score_function):
        level_seeds = rollouts.level_seeds
        policy_logits = rollouts.action_log_dist
        done = ~(rollouts.masks > 0)  # masks are zeroed at episode boundaries
        total_steps, num_actors = policy_logits.shape[:2]
        num_decisions = len(policy_logits)

        for actor_index in range(num_actors):
            # Indices of episode boundaries within this actor's rollout
            done_steps = done[:, actor_index].nonzero()[:total_steps, 0]
            start_t = 0

            for t in done_steps:
                if start_t >= total_steps:
                    break

                # A done at t == 0 means the previous seed already got a
                # full update at the end of the last cycle
                if t == 0:
                    continue

                # Score the completed episode spanning [start_t, t)
                seed_t = level_seeds[start_t, actor_index].item()
                seed_idx_t = self.seed2index[seed_t]

                score_function_kwargs = {}
                episode_logits = policy_logits[start_t:t, actor_index]
                score_function_kwargs['episode_logits'] = torch.log_softmax(episode_logits, -1)

                if self.requires_value_buffers:
                    score_function_kwargs['returns'] = rollouts.returns[start_t:t, actor_index]
                    score_function_kwargs['rewards'] = rollouts.rewards[start_t:t, actor_index]
                    score_function_kwargs['value_preds'] = rollouts.value_preds[start_t:t, actor_index]

                score = score_function(**score_function_kwargs)
                num_steps = len(episode_logits)
                self.update_seed_score(actor_index, seed_idx_t, score, num_steps)

                start_t = t.item()

            # Any remaining steps belong to an episode that has not yet
            # terminated, so its seed only receives a partial score update
            if start_t < total_steps:
                seed_t = level_seeds[start_t, actor_index].item()
                seed_idx_t = self.seed2index[seed_t]

                score_function_kwargs = {}
                episode_logits = policy_logits[start_t:, actor_index]
                score_function_kwargs['episode_logits'] = torch.log_softmax(episode_logits, -1)

                if self.requires_value_buffers:
                    score_function_kwargs['returns'] = rollouts.returns[start_t:, actor_index]
                    score_function_kwargs['rewards'] = rollouts.rewards[start_t:, actor_index]
                    score_function_kwargs['value_preds'] = rollouts.value_preds[start_t:, actor_index]

                score = score_function(**score_function_kwargs)
                num_steps = len(episode_logits)
                self._partial_update_seed_score(actor_index, seed_idx_t, score, num_steps)
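
For reference, score_function is invoked with exactly the kwargs assembled above: episode_logits always, plus returns, rewards, and value_preds when self.requires_value_buffers is set. Below is a minimal sketch of a compatible scorer, assuming an L1 value-loss criterion in the spirit of Prioritized Level Replay; the name value_l1_score and the dummy buffers are illustrative, not the repo's implementation:

    import torch

    def value_l1_score(returns=None, value_preds=None, **kwargs):
        # Mean absolute gap between bootstrapped returns and value
        # predictions over one episode; extra kwargs such as
        # episode_logits and rewards are absorbed by **kwargs.
        return (returns - value_preds).abs().mean().item()

    # Dummy per-episode buffers shaped [num_steps, 1]:
    returns = torch.tensor([[1.0], [0.5], [0.25]])
    value_preds = torch.tensor([[0.8], [0.6], [0.2]])
    score = value_l1_score(returns=returns, value_preds=value_preds)
    assert abs(score - 0.35 / 3) < 1e-6

A logits-only criterion such as policy entropy would ignore the value buffers entirely; the requires_value_buffers flag gates exactly that difference.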
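
The boundary bookkeeping itself can also be checked in isolation. A standalone sketch of how the done tensor and done_steps indices are derived, assuming a single actor's masks column with dummy values:

    import torch

    # masks stay at 1.0 while an episode is running and drop to 0.0
    # at an episode boundary
    masks = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0])

    done = ~(masks > 0)                # True exactly where masks == 0
    done_steps = done.nonzero()[:, 0]

    print(done_steps.tolist())  # [2, 6]
    # The loop above would score the slices [0:2) and [2:6) as complete
    # episodes, then give the trailing steps [6:8) a partial update.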