# pretrain.py
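# Assumed imports for this excerpt (a sketch): the project-specific helpers
# th_flatten, discounted_bwd_cumsum_, the module-level logger `log`, and the
# setup/agent/env objects are defined elsewhere in the repository.
from collections import defaultdict
from typing import Any, Dict, List

import torch as th

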
def eval_mfdim(setup, n_samples: int) -> Dict[str, float]:
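    """Evaluate the agent on its evaluation environments.

    Rolls out the environments until the configured number of episodes has
    finished, optionally recording video frames, and returns the fraction of
    episodes that reached their goal, keyed by feature abstraction (plus a
    'total' entry).
    """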
cfg = setup.cfg
agent = setup.agent
rq = setup.rq
envs = setup.eval_envs
n_episodes = cfg.eval.episodes_per_task * len(setup.goal_dims)
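    # setup.task_map maps feature ids to one-hot positions in obs['task'];
    # invert it so observed task bits can be translated back to feature ids.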
    task_map_r: Dict[int, int] = {v: int(k) for k, v in setup.task_map.items()}
envs.seed(list(range(envs.num_envs)))
obs = envs.reset()
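    # Per-environment bookkeeping: episode counter, reward/done histories,
    # per-abstraction goal-reach records, and frame buffers for rendered videos.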
n_done = 0
reached_goala: Dict[str, List[bool]] = defaultdict(list)
reward = th.zeros(envs.num_envs)
rewards: List[th.Tensor] = []
dones: List[th.Tensor] = [th.tensor([False] * envs.num_envs)]
rq_in: List[List[Dict[str, Any]]] = [[] for _ in range(envs.num_envs)]
n_imgs = 0
collect_img = cfg.eval.video is not None
collect_all = collect_img and cfg.eval.video.record_all
vwidth = int(cfg.eval.video.size[0]) if collect_img else 0
vheight = int(cfg.eval.video.size[1]) if collect_img else 0
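    # Roll out all evaluation environments in lock-step until enough episodes
    # have finished.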
while True:
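        # Decode each env's active task bits back into feature ids; these label
        # the video overlays and key the per-abstraction statistics.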
abstractions = []
for i in range(envs.num_envs):
bits = list(th.where(obs['task'][i] == 1)[0].cpu().numpy())
abstractions.append([task_map_r.get(b, b) for b in bits])
if collect_img:
if collect_all:
# TODO This OOMs if we do many evaluations since we record way
# more frames than we need to.
for i, img in enumerate(
envs.render_all(
mode='rgb_array', width=vwidth, height=vheight
)
):
if dones[-1][i].item():
continue
rq_in[i].append(
{
'img': img,
's_left': [
f'Eval',
f'Samples {n_samples}',
],
's_right': [
f'Trial {i+1}',
f'Frame {len(rewards)}',
f'Features {abstractions[i]}',
f'Reward {reward[i].item():+.02f}',
],
}
)
else:
if not dones[-1][0].item():
rq_in[0].append(
{
'img': envs.render_single(
mode='rgb_array', width=vwidth, height=vheight
),
's_left': [
f'Eval',
f'Samples {n_samples}',
],
's_right': [
f'Frame {n_imgs}',
f'Features {abstractions[0]}',
f'Reward {reward[0].item():+.02f}',
],
}
)
n_imgs += 1
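                # Stop collecting frames once the configured video length is
                # exceeded.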
if n_imgs > cfg.eval.video.length:
collect_img = False
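        # The multi-task SAC agent ('sacmt') consumes the dict observation
        # directly; other agents expect a flattened observation tensor.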
t_obs = (
th_flatten(envs.observation_space, obs)
if cfg.agent.name != 'sacmt'
else obs
)
action, _ = agent.action(envs, t_obs)
next_obs, reward, done, info = envs.step(action)
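        # Environments that reset themselves mid-rollout report 'SoftReset' in
        # their info dict; treat these as episode boundaries as well.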
soft_reset = th.tensor(['SoftReset' in inf for inf in info])
done = done.view(-1).cpu()
rewards.append(reward.view(-1).to('cpu', copy=True))
dones.append(done | soft_reset)
        # Record whether the goal was reached for every environment that just
        # finished an episode
        for d in th.where(dones[-1])[0].numpy():
            key = ','.join([str(a) for a in abstractions[d]])
            reached_goala[key].append(info[d]['reached_goal'])
n_done += dones[-1].sum().item()
if n_done >= n_episodes:
break
obs = envs.reset_if_done()
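    # Stack the per-step records into (num_envs, num_steps) tensors and compute
    # discounted and undiscounted returns per environment.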
reward = th.stack(rewards, dim=1)
not_done = th.logical_not(th.stack(dones, dim=1))
    # discounted_bwd_cumsum_ accumulates in place; keep only the return from
    # the first step of each environment's rollout
    r_discounted = discounted_bwd_cumsum_(
        reward.clone(), cfg.agent.gamma, mask=not_done[:, 1:]
    )[:, 0]
    r_undiscounted = discounted_bwd_cumsum_(
        reward.clone(), 1.0, mask=not_done[:, 1:]
    )[:, 0]
# Gather stats regarding which goals were reached
goals_reached = 0.0
goalsa_reached: Dict[str, float] = defaultdict(float)
for abstr, reached in reached_goala.items():
goalsa_reached[abstr] = th.tensor(reached).sum().item() / len(reached)
goals_reached += goalsa_reached[abstr] * len(reached)
goals_reached /= n_done
goalsa_reached['total'] = goals_reached
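    # Log returns, goal-reach rates, and trial counts to TensorBoard if a
    # summary writer is configured.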
if agent.tbw:
agent.tbw_add_scalars('Eval/ReturnDisc', r_discounted)
agent.tbw_add_scalars('Eval/ReturnUndisc', r_undiscounted)
agent.tbw.add_scalars(
'Eval/GoalsReached', goalsa_reached, agent.n_samples
)
agent.tbw.add_scalars(
'Eval/NumTrials',
{a: len(d) for a, d in reached_goala.items()},
agent.n_samples,
)
    log.info(
        f'eval done, goals reached {goals_reached:.03f}, '
        f'avg return {r_discounted.mean().item():+.03f}, '
        f'undisc avg {r_undiscounted.mean():+.03f} '
        f'min {r_undiscounted.min():+0.3f} max {r_undiscounted.max():+0.3f}'
    )
if sum([len(q) for q in rq_in]) > 0:
# Display cumulative reward in video
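        # Running per-episode reward sums: the not_done mask zeroes rewards from
        # finished episodes and resets the accumulator at episode boundaries.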
c_rew = reward * not_done[:, :-1]
for i in range(c_rew.shape[1] - 1):
c_rew[:, i + 1] += c_rew[:, i]
c_rew[:, i + 1] *= not_done[:, i]
n_imgs = 0
for i, ep in enumerate(rq_in):
            for j, entry in enumerate(ep):
                if n_imgs <= cfg.eval.video.length:
                    entry['s_right'].append(f'Acc. Reward {c_rew[i][j]:+.02f}')
                    rq.push(**entry)
                    n_imgs += 1
rq.plot()
return goalsa_reached