in salina_examples/offline_rl/decision_transformer/dt.py [0:0]
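# Context assumed by this excerpt (not shown here): the names used below
# (copy, time, np, torch, instantiate_class, get_class, get_arguments,
# Agents, TemporalAgent, NRemoteAgent, GymAgent, RewardToGoAgent,
# ControlAgent, _state_dict) are imported or defined elsewhere in dt.py.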
def run_bc(buffer, logger, action_agent, cfg_algorithm, cfg_env):
    action_agent.set_name("action_agent")
    env = instantiate_class(cfg_env)
    print("Computing normalized reward to go...")
    rtg_agent = RewardToGoAgent()
    rtg_agent(buffer)
    # Get the reward-to-go and normalize it by the configured reward scale
    rtg = buffer["reward_to_go"]
    env_name = cfg_env.env_name
    rtg = rtg / cfg_algorithm.reward_scale
    buffer.set_full("reward_to_go", rtg)
    length = buffer["env/done"].float().argmax(0)
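    # Evaluation setup: rollouts are conditioned on a target return through the
    # "control_rtg" variable passed below. ControlAgent (defined elsewhere in
    # this file) receives the same reward scale used to normalize the dataset,
    # and the Gym environments are split across the evaluation processes.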
    control_agent = ControlAgent(cfg_algorithm.reward_scale)
    env_evaluation_agent = GymAgent(
        get_class(cfg_env),
        get_arguments(cfg_env),
        n_envs=int(
            cfg_algorithm.evaluation.n_envs / cfg_algorithm.evaluation.n_processes
        ),
    )
    evaluation_rtg = cfg_algorithm.target_rewards
    print("Evaluation target reward: ", evaluation_rtg)
    evaluation_position = 0
    action_evaluation_agent = copy.deepcopy(action_agent)
    action_agent.to(cfg_algorithm.loss_device)
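    # The evaluation agent is a copy of the policy running asynchronously in
    # cfg_algorithm.evaluation.n_processes worker processes, so evaluation
    # episodes are collected while the main process keeps training; results are
    # read back whenever is_running() returns False in the loop below.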
    evaluation_agent, evaluation_workspace = NRemoteAgent.create(
        TemporalAgent(
            Agents(env_evaluation_agent, control_agent, action_evaluation_agent)
        ),
        num_processes=cfg_algorithm.evaluation.n_processes,
        t=0,
        n_steps=10,
        epsilon=0.0,
        time_size=cfg_env.max_episode_steps + 1,
        control_variable="control_rtg",
        control_value=evaluation_rtg[evaluation_position],
    )
    evaluation_agent.eval()
    evaluation_agent.seed(cfg_algorithm.evaluation.env_seed)
    evaluation_agent._asynchronous_call(
        evaluation_workspace,
        t=0,
        stop_variable="env/done",
        control_variable="control_rtg",
        control_value=evaluation_rtg[evaluation_position],
    )
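    # Behavioral cloning: the action agent is trained to regress the dataset
    # actions with a masked MSE loss, using the optimizer class and arguments
    # taken from the configuration.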
    logger.message("Learning")
    optimizer_args = get_arguments(cfg_algorithm.optimizer)
    optimizer_action = get_class(cfg_algorithm.optimizer)(
        action_agent.parameters(), **optimizer_args
    )
    nsteps_ps_cache = []
    for epoch in range(cfg_algorithm.max_epoch):
        if not evaluation_agent.is_running():
            # The previous asynchronous evaluation has finished: log its
            # returns, push the latest policy weights to the workers, and
            # relaunch the rollout with the next target reward-to-go.
            rtg = evaluation_rtg[evaluation_position]
            length = evaluation_workspace["env/done"].float().argmax(0)
            creward = evaluation_workspace["env/cumulated_reward"]
            crtg = evaluation_workspace["control_rtg"]
            l = (length[0] + 1).item()
            arange = torch.arange(length.size()[0], device=length.device)
            # Cumulated reward of each episode at its final timestep
            creward = creward[length, arange]
            if creward.size()[0] > 0:
                logger.add_scalar(
                    "evaluation/reward/" + str(rtg), creward.mean().item(), epoch
                )
                v = []
                for i in range(creward.size()[0]):
                    v.append(env.get_normalized_score(creward[i].item()))
                logger.add_scalar(
                    "evaluation/normalized_reward/" + str(rtg), np.mean(v), epoch
                )
            for a in evaluation_agent.get_by_name("action_agent"):
                a.load_state_dict(_state_dict(action_agent, "cpu"))
            evaluation_position += 1
            if evaluation_position >= len(evaluation_rtg):
                evaluation_position = 0
            evaluation_workspace.copy_n_last_steps(1)
            print("[EVALUATION] Launching for ", evaluation_rtg[evaluation_position])
            evaluation_agent._asynchronous_call(
                evaluation_workspace,
                t=0,
                stop_variable="env/done",
                epsilon=0.0,
                control_variable="control_rtg",
                control_value=evaluation_rtg[evaluation_position],
            )
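        # One supervised training step on a randomly sampled batch of trajectories.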
        batch_size = cfg_algorithm.batch_size
        replay_workspace = buffer.select_batch_n(batch_size).to(
            cfg_algorithm.loss_device
        )
        _st = time.time()
        T = replay_workspace.time_size()
        # Build a (T, batch_size) mask that is 1 up to (and including) each
        # trajectory's terminal timestep and 0 afterwards.
        length = replay_workspace["env/done"].float().argmax(0)
        mask = torch.arange(T).unsqueeze(-1).repeat(1, batch_size).to(length.device)
        length = length.unsqueeze(0).repeat(T, 1)
        mask = mask.le(length).float()
        target_action = replay_workspace["action"].detach()
        action_agent(replay_workspace)
        action = replay_workspace["action"]
        mse = (target_action - action) ** 2
        mse_loss = (mse.sum(2) * mask).sum() / mask.sum()
        logger.add_scalar("loss/mse", mse_loss.item(), epoch)
        optimizer_action.zero_grad()
        mse_loss.backward()
        if cfg_algorithm.clip_grad > 0:
            n = torch.nn.utils.clip_grad_norm_(
                action_agent.parameters(), cfg_algorithm.clip_grad
            )
            logger.add_scalar("monitor/grad_norm", n.item(), epoch)
        optimizer_action.step()
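        # Throughput monitoring: running average of transitions processed per
        # second over (roughly) the last 1000 updates.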
        _et = time.time()
        nsteps = batch_size * T
        nsteps_ps = nsteps / (_et - _st)
        nsteps_ps_cache.append(nsteps_ps)
        if len(nsteps_ps_cache) > 1000:
            nsteps_ps_cache.pop(0)
        logger.add_scalar("monitor/steps_per_seconds", np.mean(nsteps_ps_cache), epoch)