def sac_train()

in salina_cl/algorithms/sac_finetune/sac.py [0:0]


def sac_train(q_agent_1, q_agent_2, action_agent, env_agent, logger, cfg_sac, seed, n_max_interactions, control_env_agent=None):
    """Train ``action_agent`` with a Soft Actor-Critic loop using twin critics.

    The function alternates between (1) acquiring transitions with a copy of
    the policy and pushing them into a replay buffer, and (2) running
    ``cfg_sac.inner_epochs`` gradient steps on the critics, the policy and
    (optionally) the entropy temperature ``alpha``.

    Args:
        q_agent_1: First critic; writes ``"q"`` into the workspace it is called on.
        q_agent_2: Second critic, same interface as ``q_agent_1``.
        action_agent: Policy being trained; it is renamed to ``"action"`` and is
            expected to write ``"action"`` and ``"action_logprobs"``.
        env_agent: Environment agent used for data acquisition.
        logger: Object exposing ``message`` and ``add_scalar``.
        cfg_sac: Configuration object (devices, buffer sizes, optimizers,
            ``time_limit``, ``clip_grad``, ``learning_alpha``, ...).
        seed: Seed passed to the acquisition (and control) agents.
        n_max_interactions: Training stops once the interaction counter
            exceeds this value.
        control_env_agent: Optional environment agent used for periodic
            evaluation every ``cfg_sac.control_every_n_epochs`` epochs.

    Returns:
        A tuple ``(r, action_agent, q_agent_1, q_agent_2, replay_buffer)`` where
        ``r`` is a dict with ``"n_epochs"``, ``"training_time"`` and
        ``"n_interactions"``; agents and buffer are moved to CPU.
    """
    time_unit = None
    if cfg_sac.time_limit > 0:
        # The original code read `cfg_ppo.device`, a name that does not exist
        # in this function (NameError). The learning device is used instead.
        # NOTE(review): confirm this is the device the time calibration
        # should run on.
        time_unit = compute_time_unit(cfg_sac.learning_device)
        logger.message("Time unit is "+str(time_unit)+" seconds.")

    action_agent.set_name("action")

    # == Acquisition agent: a private copy of the policy stepping the environment
    acq_action_agent = copy.deepcopy(action_agent)
    acq_agent = TemporalAgent(Agents(env_agent, acq_action_agent)).to(cfg_sac.acquisition_device)
    acquisition_workspace = Workspace()
    if cfg_sac.n_processes > 1:
        acq_agent, acquisition_workspace = NRemoteAgent.create(
            acq_agent,
            num_processes=cfg_sac.n_processes,
            time_size=cfg_sac.n_timesteps,
            n_steps=1,
        )
    acq_agent.seed(seed)

    # == Optional control (evaluation) agent, run until every episode is done
    if control_env_agent is not None:
        control_action_agent = copy.deepcopy(action_agent)
        control_agent = TemporalAgent(
            Agents(control_env_agent, EpisodesDone(), control_action_agent)
        ).to(cfg_sac.acquisition_device)
        control_agent.seed(seed)

    # == Setting up the training agents (targets start as copies of the critics)
    q_target_agent_1 = copy.deepcopy(q_agent_1)
    q_target_agent_2 = copy.deepcopy(q_agent_2)
    q_agent_1.to(cfg_sac.learning_device)
    q_agent_2.to(cfg_sac.learning_device)
    q_target_agent_1.to(cfg_sac.learning_device)
    q_target_agent_2.to(cfg_sac.learning_device)
    action_agent.to(cfg_sac.learning_device)

    # == Setting up & initializing the replay buffer for SAC
    replay_buffer = ReplayBuffer(cfg_sac.buffer_size, device=cfg_sac.buffer_device)
    acq_agent.train()
    action_agent.train()

    logger.message("[SAC] Initializing replay buffer")
    acq_agent(
        acquisition_workspace,
        t=0,
        n_steps=cfg_sac.n_timesteps,
    )
    replay_buffer.put(acquisition_workspace, time_size=cfg_sac.buffer_time_size)

    # Keep acquiring until the buffer holds the requested number of samples.
    while replay_buffer.size() < cfg_sac.initial_buffer_size:
        acquisition_workspace.copy_n_last_steps(1)
        acq_agent(acquisition_workspace, t=1, n_steps=cfg_sac.n_timesteps - 1)
        acquisition_workspace.zero_grad()
        replay_buffer.put(acquisition_workspace, time_size=cfg_sac.buffer_time_size)

    # Target entropy is -0.5 * (number of action components); the action
    # shape is taken from everything after the first two workspace dimensions.
    action_shape = acquisition_workspace["action"].size()[2:]
    _target_entropy = -0.5 * torch.prod(torch.Tensor(action_shape).to(cfg_sac.learning_device)).item()
    _log_alpha = torch.tensor(math.log(cfg_sac.alpha), requires_grad=True, device=cfg_sac.learning_device)

    logger.message("[SAC] Learning")
    n_interactions = 0

    # == Optimizers: one per critic, one for the policy, one for log(alpha)
    optimizer_args = get_arguments(cfg_sac.optimizer_q)
    optimizer_q_1 = get_class(cfg_sac.optimizer_q)(
        q_agent_1.parameters(), **optimizer_args
    )
    optimizer_q_2 = get_class(cfg_sac.optimizer_q)(
        q_agent_2.parameters(), **optimizer_args
    )

    optimizer_args = get_arguments(cfg_sac.optimizer_policy)
    optimizer_action = get_class(cfg_sac.optimizer_policy)(
        action_agent.parameters(), **optimizer_args
    )

    optimizer_args = get_arguments(cfg_sac.optimizer_alpha)
    optimizer_alpha = get_class(cfg_sac.optimizer_alpha)(
        [_log_alpha], **optimizer_args
    )
    iteration = 0

    epoch = 0
    is_training = True
    _training_start_time = time.time()
    while is_training:

        # == Periodic evaluation with the current policy weights
        if control_env_agent is not None and epoch % cfg_sac.control_every_n_epochs == 0:
            for a in control_agent.get_by_name("action"):
                a.load_state_dict(_state_dict(action_agent, cfg_sac.acquisition_device))

            control_agent.eval()
            w = Workspace()
            control_agent(
                w,
                t=0,
                stop_variable="env/done"
            )
            # Index of the first `done` step along dim 0 for each episode,
            # used to read the cumulated reward at episode end.
            length = w["env/done"].max(0)[1]
            arange = torch.arange(length.size()[0], device=length.device)
            creward = (
                w["env/cumulated_reward"][length, arange]
                .mean()
                .item()
            )
            logger.add_scalar("validation/reward", creward, epoch)
            print("reward at ",epoch," = ",creward)

        # == Acquisition: sync policy weights, then roll the environment forward
        for a in acq_agent.get_by_name("action"):
            a.load_state_dict(_state_dict(action_agent, cfg_sac.acquisition_device))

        acquisition_workspace.copy_n_last_steps(1)
        acquisition_workspace.zero_grad()
        acq_agent(
            acquisition_workspace,
            t=1,
            n_steps=cfg_sac.n_timesteps - 1,
        )
        replay_buffer.put(acquisition_workspace, time_size=cfg_sac.buffer_time_size)
        done, creward = acquisition_workspace["env/done", "env/cumulated_reward"]

        # Only log the cumulated reward of entries flagged as done.
        creward = creward[done]
        if creward.size()[0] > 0:
            logger.add_scalar("monitor/reward", creward.mean().item(), epoch)
        logger.add_scalar("monitor/replay_buffer_size", replay_buffer.size(), epoch)

        # The first timestep is the one copied over from the previous
        # acquisition, hence the `- 1`.
        n_interactions += (
            acquisition_workspace.time_size() - 1
        ) * acquisition_workspace.batch_size()
        logger.add_scalar("monitor/n_interactions", n_interactions, epoch)

        _st_inner_epoch = time.time()
        for inner_epoch in range(cfg_sac.inner_epochs):
            _alpha = _log_alpha.exp().detach()
            batch_size = cfg_sac.batch_size
            _workspace = replay_buffer.get(batch_size)
            replay_workspace = _workspace.to(
                cfg_sac.learning_device
            )
            done, reward = replay_workspace["env/done", "env/reward"]
            not_done = 1.0 - done.float()
            reward = reward * cfg_sac.reward_scaling

            # ---- Critic loss
            # Q-values of the actions stored in the buffer.
            q_agent_1(replay_workspace)
            q_1 = replay_workspace["q"].squeeze(-1)
            q_agent_2(replay_workspace)
            q_2 = replay_workspace["q"].squeeze(-1)
            replay_workspace.clear("q")

            assert not q_1.eq(q_2).all()
            with torch.no_grad():
                # Re-sample actions with the current policy, then evaluate
                # them with the target critics.
                action_agent(replay_workspace)

                q_target_agent_1(replay_workspace)
                q_target_1 = replay_workspace["q"]

                q_target_agent_2(replay_workspace)
                q_target_2 = replay_workspace["q"]

            assert not q_target_1.eq(q_target_2).all()

            q_target = torch.min(q_target_1, q_target_2).squeeze(-1)
            _logp = replay_workspace["action_logprobs"].detach()
            target = (
                reward[1:]
                + cfg_sac.discount_factor
                * not_done[1:]
                * (q_target[1:] - _alpha * _logp[1:].detach())
            )

            # sqrt(td^2) is |td|; the small offset presumably keeps the sqrt
            # gradient finite when the masked TD error is exactly zero.
            td_1 = (q_1[:-1] - target) * not_done[:-1] + 0.000001
            td_2 = (q_2[:-1] - target) * not_done[:-1] + 0.000001
            error_1 = (td_1 ** 2).sqrt()
            error_2 = (td_2 ** 2).sqrt()

            optimizer_q_1.zero_grad()
            optimizer_q_2.zero_grad()
            error = error_1 + error_2
            loss = error.mean()
            logger.add_scalar("loss/td_loss_1", error_1.mean().item(), iteration)
            logger.add_scalar("loss/td_loss_2", error_2.mean().item(), iteration)

            loss.backward()

            if cfg_sac.clip_grad > 0:
                n = torch.nn.utils.clip_grad_norm_(
                    q_agent_1.parameters(), cfg_sac.clip_grad
                )
                logger.add_scalar("monitor/grad_norm_q_1", n.item(), iteration)
                n = torch.nn.utils.clip_grad_norm_(
                    q_agent_2.parameters(), cfg_sac.clip_grad
                )
                logger.add_scalar("monitor/grad_norm_q_2", n.item(), iteration)

            optimizer_q_1.step()
            optimizer_q_2.step()

            # ---- Actor loss
            done = replay_workspace["env/done"]
            not_done = 1.0 - done.float()

            # Re-sample actions with gradients enabled so the critics'
            # output can be back-propagated into the policy.
            action_agent(replay_workspace, deterministic=False)

            q_agent_1(replay_workspace)
            q1 = replay_workspace["q"].squeeze(-1)

            q_agent_2(replay_workspace)
            q2 = replay_workspace["q"].squeeze(-1)

            assert not q1.eq(q2).all()
            q = torch.min(q1, q2)

            optimizer_action.zero_grad()
            logp = replay_workspace["action_logprobs"]
            loss_1 = (not_done * (_alpha * logp)).mean()
            loss_2 = (not_done * (-q)).mean()
            loss = loss_1 + loss_2
            loss.backward()

            if "action_std" in list(replay_workspace.keys()):
                _std = replay_workspace["action_std"]
                logger.add_scalar("monitor/action_std", _std.exp().mean().item(), iteration)

            if cfg_sac.clip_grad > 0:
                n = torch.nn.utils.clip_grad_norm_(
                    action_agent.parameters(), cfg_sac.clip_grad
                )
                logger.add_scalar("monitor/grad_norm_action", n.item(), iteration)

            logger.add_scalar("loss/q_loss", loss.item(), iteration)
            logger.add_scalar("loss/q_loss/alpha_term", loss_1.item(), iteration)
            logger.add_scalar("loss/q_loss/q_term", loss_2.item(), iteration)
            optimizer_action.step()

            # ---- Alpha (entropy temperature) loss
            if cfg_sac.learning_alpha:
                action_agent(replay_workspace, deterministic=False)
                logp = replay_workspace["action_logprobs"]
                _alpha = _log_alpha.exp()
                alpha_loss = (
                    _alpha * (-logp - _target_entropy).detach() * not_done
                ).mean()
                logger.add_scalar("loss/alpha_loss", alpha_loss.item(), iteration)

                optimizer_alpha.zero_grad()
                alpha_loss.backward()
                # Clip only when enabled, like the critic/actor updates above:
                # a max-norm of 0 would otherwise scale the gradient to zero
                # and alpha would never be updated.
                if cfg_sac.clip_grad > 0:
                    n = torch.nn.utils.clip_grad_norm_(
                        [_log_alpha], cfg_sac.clip_grad
                    )
                    logger.add_scalar("monitor/grad_norm_alpha", n.item(), iteration)

                optimizer_alpha.step()
            _alpha = _log_alpha.exp().item()
            logger.add_scalar("monitor/alpha", _alpha, iteration)

            # ---- Soft update of the target critics
            tau = cfg_sac.update_target_tau
            soft_update_params(q_agent_1, q_target_agent_1, tau)
            soft_update_params(q_agent_2, q_target_agent_2, tau)

            iteration += 1
        _et_inner_epoch = time.time()
        logger.add_scalar("monitor/epoch_time", _et_inner_epoch - _st_inner_epoch, epoch)
        epoch += 1

        # == Stopping criteria: interaction budget first, then wall-clock budget
        if n_interactions > n_max_interactions:
            logger.message("== Maximum interactions reached")
            is_training = False
        elif cfg_sac.time_limit > 0:
            is_training = time.time() - _training_start_time < cfg_sac.time_limit * time_unit

    r = {"n_epochs": epoch, "training_time": time.time() - _training_start_time, "n_interactions": n_interactions}
    action_agent.to("cpu")
    q_agent_1.to("cpu")
    q_agent_2.to("cpu")
    if cfg_sac.n_processes > 1:
        acq_agent.close()
    return r, action_agent, q_agent_1, q_agent_2, replay_buffer.to("cpu")