salina_examples/offline_rl/bc_on_full_episodes/bc.py [98:144]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        T = replay_workspace.time_size()
        length = replay_workspace["env/done"].float().argmax(0)
        mask = torch.arange(T).unsqueeze(-1).repeat(1, batch_size).to(length.device)
        length = length.unsqueeze(0).repeat(T, 1)
        mask = mask.le(length).float()
        target_action = replay_workspace["action"].detach()
        action_agent(replay_workspace)
        action = replay_workspace["action"]
        mse = (target_action - action) ** 2
        mse_loss = (mse.sum(2) * mask).sum() / mask.sum()
        logger.add_scalar("loss/mse", mse_loss.item(), epoch)
        optimizer_action.zero_grad()
        mse_loss.backward()
        if cfg_algorithm.clip_grad > 0:
            n = torch.nn.utils.clip_grad_norm_(
                action_agent.parameters(), cfg_algorithm.clip_grad
            )
            logger.add_scalar("monitor/grad_norm", n.item(), epoch)

        optimizer_action.step()
        _et=time.time()
        nsteps=batch_size*T
        nsteps_ps=nsteps/(_et-_st)
        nsteps_ps_cache.append(nsteps_ps)
        if len(nsteps_ps_cache)>1000: nsteps_ps_cache.pop(0)
        logger.add_scalar("monitor/steps_per_seconds", np.mean(nsteps_ps_cache), epoch)

@hydra.main(config_path=".", config_name="gym.yaml")
def main(cfg):
    """Entry point: build every component from the Hydra config and run ``run_bc``.

    Hydra loads ``gym.yaml`` from this directory and passes it in as ``cfg``.

    Steps:
      1. Instantiate the logger from ``cfg.logger`` and record the full config
         with ``logger.save_hps``.
      2. Instantiate the environment from ``cfg.env`` and build the replay
         workspace with ``salina_examples.offline_rl.d4rl.d4rl_episode_buffer``.
      3. Instantiate the agent from ``cfg.agent``.
      4. Hand everything to ``run_bc`` together with ``cfg.algorithm`` and ``cfg.env``.

    Args:
        cfg: Hydra/OmegaConf configuration with ``logger``, ``env``, ``agent``
            and ``algorithm`` sections.
    """
    logger = instantiate_class(cfg.logger)
    logger.save_hps(cfg)

    env = instantiate_class(cfg.env)
    workspace = salina_examples.offline_rl.d4rl.d4rl_episode_buffer(env)
    agent = instantiate_class(cfg.agent)
    run_bc(workspace, logger, agent, cfg.algorithm, cfg.env)


import os

if __name__ == "__main__":
    from torch import multiprocessing as mp

    # Select the "spawn" start method for torch.multiprocessing before the
    # Hydra-wrapped entry point runs.
    mp.set_start_method("spawn")
    main()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



salina_examples/offline_rl/bc_on_full_episodes/bc_with_torch_amp.py [99:146]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            T = replay_workspace.time_size()
            length = replay_workspace["env/done"].float().argmax(0)
            mask = torch.arange(T).unsqueeze(-1).repeat(1, batch_size).to(length.device)
            length = length.unsqueeze(0).repeat(T, 1)
            mask = mask.le(length).float()
            target_action = replay_workspace["action"].detach()
            action_agent(replay_workspace)
            action = replay_workspace["action"]
            mse = (target_action - action) ** 2
            mse_loss = (mse.sum(2) * mask).sum() / mask.sum()
            logger.add_scalar("loss/mse", mse_loss.item(), epoch)

        optimizer_action.zero_grad()
        mse_loss.backward()
        if cfg_algorithm.clip_grad > 0:
            n = torch.nn.utils.clip_grad_norm_(
                action_agent.parameters(), cfg_algorithm.clip_grad
            )
            logger.add_scalar("monitor/grad_norm", n.item(), epoch)

        optimizer_action.step()
        _et=time.time()
        nsteps=batch_size*T
        nsteps_ps=nsteps/(_et-_st)
        nsteps_ps_cache.append(nsteps_ps)
        if len(nsteps_ps_cache)>1000: nsteps_ps_cache.pop(0)
        logger.add_scalar("monitor/steps_per_seconds", np.mean(nsteps_ps_cache), epoch)

@hydra.main(config_path=".", config_name="gym.yaml")
def main(cfg):
    """Entry point: build every component from the Hydra config and run ``run_bc``.

    Hydra loads ``gym.yaml`` from this directory and passes it in as ``cfg``.

    Steps:
      1. Instantiate the logger from ``cfg.logger`` and record the full config
         with ``logger.save_hps``.
      2. Instantiate the environment from ``cfg.env`` and build the replay
         workspace with ``salina_examples.offline_rl.d4rl.d4rl_episode_buffer``.
      3. Instantiate the agent from ``cfg.agent``.
      4. Hand everything to ``run_bc`` together with ``cfg.algorithm`` and ``cfg.env``.

    Args:
        cfg: Hydra/OmegaConf configuration with ``logger``, ``env``, ``agent``
            and ``algorithm`` sections.
    """
    logger = instantiate_class(cfg.logger)
    logger.save_hps(cfg)

    env = instantiate_class(cfg.env)
    workspace = salina_examples.offline_rl.d4rl.d4rl_episode_buffer(env)
    agent = instantiate_class(cfg.agent)
    run_bc(workspace, logger, agent, cfg.algorithm, cfg.env)


import os

if __name__ == "__main__":
    from torch import multiprocessing as mp

    # Select the "spawn" start method for torch.multiprocessing before the
    # Hydra-wrapped entry point runs.
    mp.set_start_method("spawn")
    main()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



