salina_examples/offline_rl/bc_on_full_episodes/bc.py [91:140]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            )

        batch_size = cfg_algorithm.batch_size
        replay_workspace = buffer.select_batch_n(batch_size).to(
            cfg_algorithm.loss_device
        )
        _st = time.time()
        T = replay_workspace.time_size()
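        # Episode length = index of the first "done" flag along the time axis.
        # Build a (T, batch_size) mask that keeps timesteps up to and including
        # that index, so padded steps beyond the episode end are ignored.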
        length = replay_workspace["env/done"].float().argmax(0)
        mask = torch.arange(T).unsqueeze(-1).repeat(1, batch_size).to(length.device)
        length = length.unsqueeze(0).repeat(T, 1)
        mask = mask.le(length).float()
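        # Behavioral cloning objective: masked MSE between the agent's predicted
        # actions and the actions stored in the dataset.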
        target_action = replay_workspace["action"].detach()
        action_agent(replay_workspace)
        action = replay_workspace["action"]
        mse = (target_action - action) ** 2
        mse_loss = (mse.sum(2) * mask).sum() / mask.sum()
        logger.add_scalar("loss/mse", mse_loss.item(), epoch)
        optimizer_action.zero_grad()
        mse_loss.backward()
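        # Optionally clip the global gradient norm to stabilize training.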
        if cfg_algorithm.clip_grad > 0:
            n = torch.nn.utils.clip_grad_norm_(
                action_agent.parameters(), cfg_algorithm.clip_grad
            )
            logger.add_scalar("monitor/grad_norm", n.item(), epoch)

        optimizer_action.step()
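        # Track training throughput (transitions per second) over a sliding
        # window of the last 1000 updates.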
        _et = time.time()
        nsteps = batch_size * T
        nsteps_ps = nsteps / (_et - _st)
        nsteps_ps_cache.append(nsteps_ps)
        if len(nsteps_ps_cache) > 1000:
            nsteps_ps_cache.pop(0)
        logger.add_scalar("monitor/steps_per_seconds", np.mean(nsteps_ps_cache), epoch)

@hydra.main(config_path=".", config_name="gym.yaml")
def main(cfg):
    logger = instantiate_class(cfg.logger)
    logger.save_hps(cfg)

    env = instantiate_class(cfg.env)
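    # Load the offline D4RL dataset as a buffer of full episodes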
    workspace = salina_examples.offline_rl.d4rl.d4rl_episode_buffer(env)
    agent = instantiate_class(cfg.agent)
    run_bc(workspace, logger, agent, cfg.algorithm, cfg.env)


import os

if __name__ == "__main__":
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



salina_examples/offline_rl/decision_transformer/dt.py [155:207]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            )


        batch_size = cfg_algorithm.batch_size
        replay_workspace = buffer.select_batch_n(batch_size).to(
            cfg_algorithm.loss_device
        )
        _st = time.time()
        T = replay_workspace.time_size()
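        # As in the BC script above: mask out padded timesteps beyond the first
        # "done" flag of each episode.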
        length = replay_workspace["env/done"].float().argmax(0)
        mask = torch.arange(T).unsqueeze(-1).repeat(1, batch_size).to(length.device)
        length = length.unsqueeze(0).repeat(T, 1)
        mask = mask.le(length).float()
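        # Masked MSE between predicted and dataset actions (behavioral cloning loss).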
        target_action = replay_workspace["action"].detach()
        action_agent(replay_workspace)
        action = replay_workspace["action"]
        mse = (target_action - action) ** 2
        mse_loss = (mse.sum(2) * mask).sum() / mask.sum()
        logger.add_scalar("loss/mse", mse_loss.item(), epoch)
        optimizer_action.zero_grad()
        mse_loss.backward()
        if cfg_algorithm.clip_grad > 0:
            n = torch.nn.utils.clip_grad_norm_(
                action_agent.parameters(), cfg_algorithm.clip_grad
            )
            logger.add_scalar("monitor/grad_norm", n.item(), epoch)

        optimizer_action.step()
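        # Throughput logging over a sliding window of the last 1000 updates.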
        _et = time.time()
        nsteps = batch_size * T
        nsteps_ps = nsteps / (_et - _st)
        nsteps_ps_cache.append(nsteps_ps)
        if len(nsteps_ps_cache) > 1000:
            nsteps_ps_cache.pop(0)
        logger.add_scalar("monitor/steps_per_seconds", np.mean(nsteps_ps_cache), epoch)



@hydra.main(config_path=".", config_name="gym.yaml")
def main(cfg):
    logger = instantiate_class(cfg.logger)
    logger.save_hps(cfg)

    env = instantiate_class(cfg.env)
    workspace = salina_examples.offline_rl.d4rl.d4rl_episode_buffer(env)
    agent = instantiate_class(cfg.agent)
    run_bc(workspace, logger, agent, cfg.algorithm, cfg.env)


import os

if __name__ == "__main__":
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



