def run_bc()

in salina_examples/offline_rl/decision_transformer/dt.py [0:0]
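Behavior-cloning training loop for the decision transformer example: it normalizes the reward-to-go stored in the offline buffer, runs asynchronous evaluation rollouts conditioned on a list of target returns, and regresses the policy's predicted actions onto the dataset actions with a masked MSE loss.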


def run_bc(buffer, logger, action_agent, cfg_algorithm, cfg_env):
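    # Name the policy so its copies can later be retrieved from the remote
    # evaluation agent via get_by_name("action_agent")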
    action_agent.set_name("action_agent")

    env = instantiate_class(cfg_env)

    print("Computing normalized reward to go...")
    rtg_agent = RewardToGoAgent()
    rtg_agent(buffer)

    # Normalize the reward to go by the configured reward scale
    rtg = buffer["reward_to_go"]
    env_name = cfg_env.env_name

    rtg = rtg / cfg_algorithm.reward_scale
    buffer.set_full("reward_to_go", rtg)

    # Index of the terminal ("done") step of each trajectory in the buffer
    length = buffer["env/done"].float().argmax(0)

    # Control agent inserted in the evaluation pipeline; built with the same
    # reward scale used to normalize the reward to go
    control_agent = ControlAgent(cfg_algorithm.reward_scale)

    # Evaluation environments, split evenly across the evaluation worker processes
    env_evaluation_agent = GymAgent(
        get_class(cfg_env),
        get_arguments(cfg_env),
        n_envs=int(
            cfg_algorithm.evaluation.n_envs / cfg_algorithm.evaluation.n_processes
        ),
    )
    evaluation_rtg = cfg_algorithm.target_rewards
    print("Evaluation target reward: ", evaluation_rtg)
    evaluation_position = 0
    # Separate copy of the policy used for evaluation; the trained policy is
    # moved to the loss device
    action_evaluation_agent = copy.deepcopy(action_agent)
    action_agent.to(cfg_algorithm.loss_device)
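    # Remote evaluation agent: the env, control and policy agents are chained
    # and executed in n_processes worker processes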
    evaluation_agent, evaluation_workspace = NRemoteAgent.create(
        TemporalAgent(
            Agents(env_evaluation_agent, control_agent, action_evaluation_agent)
        ),
        num_processes=cfg_algorithm.evaluation.n_processes,
        t=0,
        n_steps=10,
        epsilon=0.0,
        time_size=cfg_env.max_episode_steps + 1,
        control_variable="control_rtg",
        control_value=evaluation_rtg[evaluation_position],
    )
    evaluation_agent.eval()

    # Seed the evaluation environments and launch the first asynchronous rollout,
    # conditioned on the first target reward-to-go
    evaluation_agent.seed(cfg_algorithm.evaluation.env_seed)
    evaluation_agent._asynchronous_call(
        evaluation_workspace,
        t=0,
        stop_variable="env/done",
        control_variable="control_rtg",
        control_value=evaluation_rtg[evaluation_position],
    )

    logger.message("Learning")
    optimizer_args = get_arguments(cfg_algorithm.optimizer)
    optimizer_action = get_class(cfg_algorithm.optimizer)(
        action_agent.parameters(), **optimizer_args
    )
    # Sliding window of per-update throughput measurements (transitions per second)
    nsteps_ps_cache = []
    for epoch in range(cfg_algorithm.max_epoch):
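        # When the asynchronous evaluation rollout has finished: log its return
        # (raw and normalized), refresh the evaluation policy with the latest
        # trained weights and relaunch it with the next target reward-to-go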
        if not evaluation_agent.is_running():
            rtg = evaluation_rtg[evaluation_position]
            length = evaluation_workspace["env/done"].float().argmax(0)
            creward = evaluation_workspace["env/cumulated_reward"]
            crtg = evaluation_workspace["control_rtg"]
            l = (length[0] + 1).item()

            # Cumulated reward at the terminal step of each evaluation episode
            arange = torch.arange(length.size()[0], device=length.device)
            creward = creward[length, arange]
            if creward.size()[0] > 0:
                logger.add_scalar(
                    "evaluation/reward/" + str(rtg), creward.mean().item(), epoch
                )
                v = []
                for i in range(creward.size()[0]):
                    v.append(env.get_normalized_score(creward[i].item()))
                logger.add_scalar(
                    "evaluation/normalized_reward/" + str(rtg), np.mean(v), epoch
                )
            for a in evaluation_agent.get_by_name("action_agent"):
                a.load_state_dict(_state_dict(action_agent, "cpu"))
            evaluation_position += 1
            if evaluation_position >= len(evaluation_rtg):
                evaluation_position = 0
            evaluation_workspace.copy_n_last_steps(1)
            print("[EVALUATION] Launching for ", evaluation_rtg[evaluation_position])
            evaluation_agent._asynchronous_call(
                evaluation_workspace,
                t=0,
                stop_variable="env/done",
                epsilon=0.0,
                control_variable="control_rtg",
                control_value=evaluation_rtg[evaluation_position],
            )

        # Behavior cloning update: sample a batch of trajectories, mask timesteps
        # beyond each episode's end and regress the predicted actions onto the
        # dataset actions with an MSE loss
        batch_size = cfg_algorithm.batch_size
        replay_workspace = buffer.select_batch_n(batch_size).to(
            cfg_algorithm.loss_device
        )
        _st = time.time()
        T = replay_workspace.time_size()
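        # Build a mask over timesteps that zeroes out padding beyond each
        # episode's end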
        length = replay_workspace["env/done"].float().argmax(0)
        mask = torch.arange(T).unsqueeze(-1).repeat(1, batch_size).to(length.device)
        length = length.unsqueeze(0).repeat(T, 1)
        mask = mask.le(length).float()
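        # Dataset actions are the regression target; the policy then overwrites
        # "action" in the workspace with its predictions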
        target_action = replay_workspace["action"].detach()
        action_agent(replay_workspace)
        action = replay_workspace["action"]
        mse = (target_action - action) ** 2
        mse_loss = (mse.sum(2) * mask).sum() / mask.sum()
        logger.add_scalar("loss/mse", mse_loss.item(), epoch)
        optimizer_action.zero_grad()
        mse_loss.backward()
        if cfg_algorithm.clip_grad > 0:
            n = torch.nn.utils.clip_grad_norm_(
                action_agent.parameters(), cfg_algorithm.clip_grad
            )
            logger.add_scalar("monitor/grad_norm", n.item(), epoch)

        optimizer_action.step()
        _et = time.time()
        nsteps = batch_size * T
        nsteps_ps = nsteps / (_et - _st)
        nsteps_ps_cache.append(nsteps_ps)
        if len(nsteps_ps_cache) > 1000:
            nsteps_ps_cache.pop(0)
        logger.add_scalar("monitor/steps_per_seconds", np.mean(nsteps_ps_cache), epoch)