hucc/agents/hsd3.py [760:878]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        d = dict(
            terminal=done,
            step=obs['time'].remainder(self._action_interval).long(),
        )
        for k, v in action_hi.items():
            d[f'action_hi_{k}'] = v
        for k in self._obs_keys:
            d[f'obs_{k}'] = obs_hi[k]
            if k != 'prev_task' and k != 'prev_subgoal':
                d[f'next_obs_{k}'] = next_obs[k]
        d['reward'] = reward

        self._staging.put_row(d)
        self._cur_rewards.append(reward)

        if self._staging.size == self._staging.max:
            self._staging_to_buffer()

        self._n_steps += 1
        self._n_samples += done.nelement()
        self._n_samples_since_update += done.nelement()
        ilv = self._staging.interleave
        if self._buffer.size + self._staging.size - ilv < self._warmup_samples:
            return
        if self._n_samples_since_update >= self._samples_per_update:
            self.update()
            self._cur_rewards.clear()
            self._n_samples_since_update = 0

    def _staging_to_buffer(self):
        """Convert the oldest staged steps into high-level transitions.

        Reads ``ilv * n_stack`` rows of the staging buffer starting at
        ``buf.start`` (wrapping modulo ``buf.max``), where ``ilv`` is the
        staging interleave and ``n_stack = max(action_interval, 2)``. Every
        staged entry is reshaped to ``(ilv, n_stack, ...)``. For each of the
        ``ilv`` rows a single transition is built. It starts at the first
        stacked step and ends at step ``next_idx``, the earliest of:

        - the last step before the next scheduled high-level action
          (``c - 1 - step``),
        - the first terminal step among the first ``c`` steps,
        - the step before an unscheduled high-level action (``step == 0``
          observed before the interval elapsed, e.g. after a timeout).

        Rewards up to ``next_idx`` are summed with discount ``gamma ** j``.
        ``not_done`` is the negated terminal flag at ``next_idx``.
        ``gamma_exp`` is ``gamma ** (next_idx + 1)``; presumably this is the
        discount applied to the bootstrapped target (its consumer is not
        visible here). The next-state ``time`` observation is zeroed. The
        resulting row is appended to ``self._buffer`` via ``put_row``.
        """
        ilv = self._staging.interleave
        buf = self._staging
        assert buf._b is not None
        c = self._action_interval
        # Stack at least two transitions because for training the low-level
        # policy we'll need the next high-level action.
        n_stack = max(c, 2)
        batch: Dict[str, th.Tensor] = dict()
        # Gather the oldest n_stack steps (ilv consecutive rows per step),
        # wrapping around the end of the staging storage.
        idx = (
            buf.start + th.arange(0, ilv * n_stack, device=buf.device)
        ) % buf.max
        for k in set(buf._b.keys()):
            b = buf._b[k].index_select(0, idx)
            # (n_stack * ilv, ...) -> (ilv, n_stack, ...)
            b = b.view((n_stack, ilv) + b.shape[1:]).transpose(0, 1)
            batch[k] = b

        # c = action_freq
        # i = batch['step']
        # Next action at c - i steps further, but we'll take next_obs so
        # access it at c - i - 1
        next_action_hi = (c - 1) - batch['step'][:, 0]
        # If we have a terminal before, use this instead. Only the first c
        # steps belong to the current high-level action. When c == 1 we still
        # stack two steps, and counting the extra one could yield an index of
        # -1. first_terminal is the index of the first terminal step within
        # the first c steps (the number of leading non-terminal steps), or c
        # if there is none.
        terminal = batch['terminal'][:, :c]
        first_terminal = (terminal.cumsum(dim=1) == 0).sum(dim=1)
        # Lastly, the episode could have ended with a timeout, which we can
        # detect if we took another action_hi (i == 0) prematurely. This will
        # screw up the reward summation, but hopefully it doesn't hurt too
        # much. Iterating j upwards means the earliest such step wins.
        next_real_action_hi = th.zeros_like(next_action_hi) + c
        for j in range(1, c):
            rows = th.where(batch['step'][:, j] == 0)[0]
            next_real_action_hi[rows] = next_real_action_hi[rows].clamp(
                0, j - 1
            )
        next_idx = th.min(
            th.min(next_action_hi, first_terminal), next_real_action_hi
        )

        # Sum up discounted rewards until next c - i - 1
        reward = batch['reward'][:, 0].clone()
        for j in range(1, c):
            reward += self._gamma ** j * batch['reward'][:, j] * (next_idx >= j)

        not_done = th.logical_not(dim_select(batch['terminal'], 1, next_idx))
        obs = {k: batch[f'obs_{k}'][:, 0] for k in self._obs_keys}
        obs['time'] = batch['step'][:, 0:1].clone()
        obs_p = {
            k: dim_select(batch[f'next_obs_{k}'], 1, next_idx)
            for k in self._obs_keys
        }
        # The next state is always the start of a new high-level action.
        obs_p['time'] = obs_p['time'].clone().unsqueeze(1)
        obs_p['time'].fill_(0)

        gamma_exp = th.zeros_like(reward) + self._gamma
        gamma_exp.pow_(next_idx + 1)

        db = dict(
            reward=reward,
            not_done=not_done,
            terminal=batch['terminal'][:, 0],
            gamma_exp=gamma_exp,
        )
        db[f'action_hi_{self._dkey}'] = batch[f'action_hi_{self._dkey}'][:, 0]
        db[f'action_hi_{self._ckey}'] = batch[f'action_hi_{self._ckey}'][:, 0]
        for k, v in obs.items():
            db[f'obs_{k}'] = v
        for k, v in obs_p.items():
            db[f'next_obs_{k}'] = v

        self._buffer.put_row(db)

    def _update(self):
        def act_logp_c(obs, mask):
            dist = self._model_pi_c(obs)
            action = dist.rsample()
            if mask is not None:
                log_prob = (dist.log_prob(action) * mask).sum(
                    dim=-1
                ) / mask.sum(dim=-1)
                action = action * mask * self._action_factor_c
            else:
                log_prob = dist.log_prob(action).sum(dim=-1)
                action = action * self._action_factor_c
            return action, log_prob

        def q_target(batch):
            reward = batch['reward']
            not_done = batch['not_done']
            obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
            alpha_c = self._log_alpha_c.detach().exp()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



hucc/agents/hsdb.py [417:535]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        d = dict(
            terminal=done,
            step=obs['time'].remainder(self._action_interval).long(),
        )
        for k, v in action_hi.items():
            d[f'action_hi_{k}'] = v
        for k in self._obs_keys:
            d[f'obs_{k}'] = obs_hi[k]
            if k != 'prev_task' and k != 'prev_subgoal':
                d[f'next_obs_{k}'] = next_obs[k]
        d['reward'] = reward

        self._staging.put_row(d)
        self._cur_rewards.append(reward)

        if self._staging.size == self._staging.max:
            self._staging_to_buffer()

        self._n_steps += 1
        self._n_samples += done.nelement()
        self._n_samples_since_update += done.nelement()
        ilv = self._staging.interleave
        if self._buffer.size + self._staging.size - ilv < self._warmup_samples:
            return
        if self._n_samples_since_update >= self._samples_per_update:
            self.update()
            self._cur_rewards.clear()
            self._n_samples_since_update = 0

    def _staging_to_buffer(self):
        """Convert the oldest staged steps into high-level transitions.

        Reads ``ilv * n_stack`` rows of the staging buffer starting at
        ``buf.start`` (wrapping modulo ``buf.max``), where ``ilv`` is the
        staging interleave and ``n_stack = max(action_interval, 2)``. Every
        staged entry is reshaped to ``(ilv, n_stack, ...)``. For each of the
        ``ilv`` rows a single transition is built. It starts at the first
        stacked step and ends at step ``next_idx``, the earliest of:

        - the last step before the next scheduled high-level action
          (``c - 1 - step``),
        - the first terminal step among the first ``c`` steps,
        - the step before an unscheduled high-level action (``step == 0``
          observed before the interval elapsed, e.g. after a timeout).

        Rewards up to ``next_idx`` are summed with discount ``gamma ** j``.
        ``not_done`` is the negated terminal flag at ``next_idx``.
        ``gamma_exp`` is ``gamma ** (next_idx + 1)``; presumably this is the
        discount applied to the bootstrapped target (its consumer is not
        visible here). The next-state ``time`` observation is zeroed. The
        resulting row is appended to ``self._buffer`` via ``put_row``.
        """
        ilv = self._staging.interleave
        buf = self._staging
        assert buf._b is not None
        c = self._action_interval
        # Stack at least two transitions because for training the low-level
        # policy we'll need the next high-level action.
        n_stack = max(c, 2)
        batch: Dict[str, th.Tensor] = dict()
        # Gather the oldest n_stack steps (ilv consecutive rows per step),
        # wrapping around the end of the staging storage.
        idx = (
            buf.start + th.arange(0, ilv * n_stack, device=buf.device)
        ) % buf.max
        for k in set(buf._b.keys()):
            b = buf._b[k].index_select(0, idx)
            # (n_stack * ilv, ...) -> (ilv, n_stack, ...)
            b = b.view((n_stack, ilv) + b.shape[1:]).transpose(0, 1)
            batch[k] = b

        # c = action_freq
        # i = batch['step']
        # Next action at c - i steps further, but we'll take next_obs so
        # access it at c - i - 1
        next_action_hi = (c - 1) - batch['step'][:, 0]
        # If we have a terminal before, use this instead. Only the first c
        # steps belong to the current high-level action. When c == 1 we still
        # stack two steps, and counting the extra one could yield an index of
        # -1. first_terminal is the index of the first terminal step within
        # the first c steps (the number of leading non-terminal steps), or c
        # if there is none.
        terminal = batch['terminal'][:, :c]
        first_terminal = (terminal.cumsum(dim=1) == 0).sum(dim=1)
        # Lastly, the episode could have ended with a timeout, which we can
        # detect if we took another action_hi (i == 0) prematurely. This will
        # screw up the reward summation, but hopefully it doesn't hurt too
        # much. Iterating j upwards means the earliest such step wins.
        next_real_action_hi = th.zeros_like(next_action_hi) + c
        for j in range(1, c):
            rows = th.where(batch['step'][:, j] == 0)[0]
            next_real_action_hi[rows] = next_real_action_hi[rows].clamp(
                0, j - 1
            )
        next_idx = th.min(
            th.min(next_action_hi, first_terminal), next_real_action_hi
        )

        # Sum up discounted rewards until next c - i - 1
        reward = batch['reward'][:, 0].clone()
        for j in range(1, c):
            reward += self._gamma ** j * batch['reward'][:, j] * (next_idx >= j)

        not_done = th.logical_not(dim_select(batch['terminal'], 1, next_idx))
        obs = {k: batch[f'obs_{k}'][:, 0] for k in self._obs_keys}
        obs['time'] = batch['step'][:, 0:1].clone()
        obs_p = {
            k: dim_select(batch[f'next_obs_{k}'], 1, next_idx)
            for k in self._obs_keys
        }
        # The next state is always the start of a new high-level action.
        obs_p['time'] = obs_p['time'].clone().unsqueeze(1)
        obs_p['time'].fill_(0)

        gamma_exp = th.zeros_like(reward) + self._gamma
        gamma_exp.pow_(next_idx + 1)

        db = dict(
            reward=reward,
            not_done=not_done,
            terminal=batch['terminal'][:, 0],
            gamma_exp=gamma_exp,
        )
        db[f'action_hi_{self._dkey}'] = batch[f'action_hi_{self._dkey}'][:, 0]
        db[f'action_hi_{self._ckey}'] = batch[f'action_hi_{self._ckey}'][:, 0]
        for k, v in obs.items():
            db[f'obs_{k}'] = v
        for k, v in obs_p.items():
            db[f'next_obs_{k}'] = v

        self._buffer.put_row(db)

    def _update(self):
        def act_logp_c(obs, mask):
            dist = self._model_pi_c(obs)
            action = dist.rsample()
            if mask is not None:
                log_prob = (dist.log_prob(action) * mask).sum(
                    dim=-1
                ) / mask.sum(dim=-1)
                action = action * mask * self._action_factor_c
            else:
                log_prob = dist.log_prob(action).sum(dim=-1)
                action = action * self._action_factor_c
            return action, log_prob

        def q_target(batch):
            reward = batch['reward']
            not_done = batch['not_done']
            obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
            alpha_c = self._log_alpha_c.detach().exp()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



