in pretrain.py [0:0]
_MFDIM_EVAL_MODES = ('rollouts', 'running_avg', 'q_value', 'reachability')


def _mfdim_save_checkpoint(setup: TrainingSetup, cp_path: str) -> None:
    """Save the agent checkpoint to ``cp_path``; log (don't raise) on failure.

    If ``cfg.keep_all_checkpoints`` is set, the checkpoint is additionally
    copied to ``<stem>_<n_samples:08d><suffix>`` next to ``cp_path``.
    """
    try:
        log.debug(
            f'Checkpointing to {cp_path} after {setup.n_samples} samples'
        )
        with open(cp_path, 'wb') as f:
            setup.agent.save_checkpoint(f)
        if setup.cfg.keep_all_checkpoints:
            p = Path(cp_path)
            cp_unique_path = str(
                p.with_name(f'{p.stem}_{setup.n_samples:08d}{p.suffix}')
            )
            shutil.copy(cp_path, cp_unique_path)
    except Exception:
        # Best effort: a failed checkpoint must not abort training, but
        # KeyboardInterrupt / SystemExit still propagate.
        log.exception('Checkpoint saving failed')


def _mfdim_save_abs_info(
    setup: TrainingSetup, cp_path: str, perf: Dict[str, Dict[str, float]]
) -> None:
    """Dump task map, goal dims and controllability estimates as JSON.

    Writes ``<stem>_abs.json`` next to ``cp_path`` and copies it to
    ``<stem>_<n_samples:08d>_abs.json``. The entries of ``perf`` follow
    ``task_map`` and ``goal_dims`` in the order given. Failures are logged.
    """
    try:
        p = Path(cp_path)
        abs_path = p.with_name(f'{p.stem}_abs.json')
        with open(str(abs_path), 'wt') as ft:
            json.dump(
                {
                    'task_map': setup.task_map,
                    'goal_dims': setup.goal_dims,
                    **perf,
                },
                ft,
            )
        abs_unique_path = p.with_name(
            f'{p.stem}_{setup.n_samples:08d}_abs.json'
        )
        shutil.copy(str(abs_path), str(abs_unique_path))
    except Exception:
        log.exception('Saving abstraction info failed')


def _mfdim_plus_keys(perf: Dict[str, float], goal_dims) -> Dict[str, float]:
    """Return a copy of ``perf`` with keys renamed to the '+' syntax of goal_dims.

    For each goal key containing '+', a value stored under the same key
    spelled with ',' instead is moved to the '+' key. ``perf`` itself is
    never modified, so live accumulators can be passed safely.
    """
    out = dict(perf)
    for k in goal_dims:
        flat = k.replace('+', ',')
        if flat != k and flat in out:
            out[k] = out.pop(flat)
    return out


def _mfdim_cperf_for_mode(
    eval_mode: str,
    model,
    running_cperf: Dict[str, float],
    q_cperf: Dict[str, float],
    r_cperf: Dict[str, float],
) -> Dict[str, float]:
    """Return the controllability estimate for a non-rollout ``eval_mode``.

    A fresh dict is returned so callers may rename its keys without touching
    the running average or the estimator output.

    Raises:
        ValueError: for an unknown evaluation mode.
    """
    if eval_mode == 'running_avg':
        return dict(running_cperf)
    if eval_mode == 'q_value':
        return dict(q_cperf)
    if eval_mode == 'reachability':
        if not hasattr(model, 'reachability'):
            log.warning(
                'Reachability evaluations requested but no reachability model present'
            )
        return dict(r_cperf)
    raise ValueError(f'Unknown evaluation mode {eval_mode}')


def train_loop_mfdim_actor(setup: TrainingSetup):
    """Run the acting side of distributed multi-feature-dim pre-training.

    Steps ``setup.envs`` with actions from ``setup.agent`` until
    ``setup.n_samples`` reaches ``cfg.max_steps``. Every transition is moved
    to the CPU, split into ``len(setup.queues)`` chunks and put on the queues
    (presumably consumed by learner processes -- confirm against the caller);
    it is also passed to ``agent.step``.

    Every ``cfg.eval.interval`` samples the agent is checkpointed,
    controllability is estimated according to ``cfg.eval_mode`` ('rollouts',
    'running_avg', 'q_value' or 'reachability'), the estimates are logged to
    TensorBoard (if ``agent.tbw`` is set) and written to ``<stem>_abs.json``
    next to the checkpoint, and ``update_fdist`` is called. A final
    checkpoint and evaluation run after the loop.

    Raises:
        ValueError: if ``cfg.eval_mode`` is not a supported mode.
    """
    cfg = setup.cfg
    agent = setup.agent
    queues = setup.queues
    rq = setup.rq
    envs = setup.envs
    model = setup.model

    eval_mode = str(cfg.eval_mode)
    if eval_mode not in _MFDIM_EVAL_MODES:
        # Fail fast rather than after the first training interval
        raise ValueError(f'Unknown evaluation mode {eval_mode}')

    agent.train()
    # CPU copy of the model handed to the environments via 'set_model'. Its
    # parameters live in shared memory and are overwritten in place after
    # agent updates (end of loop); it is never trained, so no gradients.
    shared_model = deepcopy(model)
    shared_model.to('cpu')
    for param in shared_model.parameters():
        param.requires_grad_(False)
        param.share_memory_()
    envs.call('set_model', shared_model, agent._gamma)

    prev_n_updates = agent.n_updates
    n_envs = envs.num_envs
    cp_path = cfg.checkpoint_path
    record_videos = cfg.video is not None
    vwidth = int(cfg.video.size[0]) if record_videos else 0
    vheight = int(cfg.video.size[1]) if record_videos else 0
    max_steps = int(cfg.max_steps)
    obs = envs.reset()
    n_imgs = 0
    collect_img = False
    cperf: Dict[str, float] = {}
    # Moving average (decay 0.9) of the goal-reached indicator per feature
    # set, updated from training episodes; seeded by the first rollout eval.
    running_cperf: Dict[str, float] = defaultdict(float)

    while setup.n_samples < max_steps:
        log.debug(f'actor loop {setup.n_samples}')
        if setup.n_samples % cfg.eval.interval == 0:
            _mfdim_save_checkpoint(setup, cp_path)
            est = estimate_ctrlb(setup)
            q_cperf = est['q']
            r_cperf = est['r']
            if eval_mode == 'rollouts' or not running_cperf:
                # Rollouts also provide the initial running average
                agent.eval()
                cperf_new = eval_mfdim(setup, setup.n_samples)
                agent.train()
                if not running_cperf:
                    running_cperf.update(cperf_new)
                    running_cperf.pop('total', None)
            else:
                cperf_new = _mfdim_cperf_for_mode(
                    eval_mode, model, running_cperf, q_cperf, r_cperf
                )
            # Report goal keys in the '+' syntax of setup.goal_dims. Copies
            # are made so the live running average keeps its original keys.
            cperf_new = _mfdim_plus_keys(cperf_new, setup.goal_dims)
            run_cperf = _mfdim_plus_keys(running_cperf, setup.goal_dims)
            if agent.tbw:
                agent.tbw.add_scalars(
                    'Training/GoalsReached', run_cperf, setup.n_samples
                )
                agent.tbw.add_scalars(
                    'Training/CtrlbEstimateQ', q_cperf, setup.n_samples
                )
                agent.tbw.add_scalars(
                    'Training/CtrlbEstimateR', r_cperf, setup.n_samples
                )
            _mfdim_save_abs_info(
                setup,
                cp_path,
                {
                    'cperf': cperf_new,
                    'cperf_r': r_cperf,
                    'cperf_q': q_cperf,
                    'cperf_running': run_cperf,
                },
            )
            update_fdist(setup, cperf, cperf_new, setup.n_samples)
            cperf = copy(cperf_new)

        if record_videos and setup.n_samples % cfg.video.interval == 0:
            collect_img = True
        if collect_img:
            rq.push(
                img=envs.render_single(
                    mode='rgb_array', width=vwidth, height=vheight
                ),
                s_left=[
                    f'Samples {setup.n_samples}',
                    f'Frame {n_imgs}',
                ],
                s_right=[
                    'Train',
                ],
            )
            n_imgs += 1
            if n_imgs > cfg.video.length:
                rq.plot()
                n_imgs = 0
                collect_img = False

        t_obs = (
            th_flatten(envs.observation_space, obs)
            if cfg.agent.name != 'sacmt'
            else obs
        )
        action, extra = agent.action(envs, t_obs)
        assert (
            extra is None
        ), "Distributed training doesn't work with extra info from action"
        next_obs, reward, done, info = envs.step(action)
        t_next_obs = (
            th_flatten(envs.observation_space, next_obs)
            if cfg.agent.name != 'sacmt'
            else next_obs
        )

        # XXX CPU transfer seems to be necessary :/
        nq = len(queues)
        ct_obs = {k: v.cpu().chunk(nq) for k, v in t_obs.items()}
        c_action = action.cpu().chunk(nq)
        ct_next_obs = {k: v.cpu().chunk(nq) for k, v in t_next_obs.items()}
        c_done = done.cpu().chunk(nq)
        c_reward = reward.cpu().chunk(nq)
        pos = 0
        for i, queue in enumerate(queues):
            log.debug(
                f'put {c_action[i].shape[0]} of {action.shape[0]} elems into queue {i}'
            )
            n = c_action[i].shape[0]
            queue.put(
                (
                    {k: v[i] for k, v in ct_obs.items()},
                    c_action[i],
                    extra,
                    (
                        {k: v[i] for k, v in ct_next_obs.items()},
                        c_reward[i],
                        c_done[i],
                        # info is per-env, so slice it like the chunks
                        info[pos : pos + n],
                    ),
                )
            )
            pos += n

        agent.step(envs, t_obs, action, extra, (t_next_obs, reward, done, info))
        obs = envs.reset_if_done()
        setup.n_samples += n_envs

        # Maintain running average of controllability during training
        for env_info in info[:n_envs]:
            if env_info.get('LastStepOfTask', False):
                feats = env_info['features']
                running_cperf[feats] *= 0.9
                if env_info['reached_goal']:
                    running_cperf[feats] += 0.1

        # Copy model after update
        if agent.n_updates != prev_n_updates:
            with th.no_grad():
                for tp, dp in zip(
                    shared_model.parameters(), model.parameters()
                ):
                    tp.copy_(dp)
            prev_n_updates = agent.n_updates

    # Final checkpoint & eval time
    _mfdim_save_checkpoint(setup, cp_path)
    agent.eval()
    eval_cperf = eval_mfdim(setup, setup.n_samples)
    agent.train()
    est = estimate_ctrlb(setup)
    q_cperf = est['q']
    r_cperf = est['r']
    if eval_mode == 'rollouts':
        cperf_new = eval_cperf
    else:
        cperf_new = _mfdim_cperf_for_mode(
            eval_mode, model, running_cperf, q_cperf, r_cperf
        )
    # Use the same '+' key syntax as the periodic evaluations so the
    # goal_dims lookups below match and the final JSON is consistent.
    cperf_new = _mfdim_plus_keys(cperf_new, setup.goal_dims)
    run_cperf = _mfdim_plus_keys(running_cperf, setup.goal_dims)
    for d in setup.goal_dims:
        if d not in cperf_new:
            log.warning(
                f'No controllability estimate for features {abstr_name(cfg, d)}'
            )
            continue
        suffix = '*' if cperf_new[d] >= 1.0 - cfg.ctrl_eps else ''
        log.info(
            f'Features {abstr_name(cfg, d)} at ctrl {cperf_new[d]:.04f}{suffix}'
        )
    _mfdim_save_abs_info(
        setup,
        cp_path,
        {
            'cperf': cperf_new,
            'cperf_eval': eval_cperf,
            'cperf_r': r_cperf,
            'cperf_q': q_cperf,
            'cperf_running': run_cperf,
        },
    )