def train_loop_mfdim_actor(setup: TrainingSetup)

in pretrain.py


import json
import shutil
from collections import defaultdict
from copy import copy, deepcopy
from pathlib import Path
from typing import Dict

import torch as th

# log, th_flatten, TrainingSetup, estimate_ctrlb, eval_mfdim, update_fdist
# and abstr_name are assumed to come from pretrain.py and its package.


def train_loop_mfdim_actor(setup: TrainingSetup):
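    """Actor half of distributed MFDim pre-training.

    Steps the vectorized environments, fans transitions out to the learner
    queues, periodically checkpoints and evaluates controllability, and
    mirrors the live model into the shared CPU copy after each round of
    updates.
    """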
    cfg = setup.cfg
    agent = setup.agent
    queues = setup.queues
    rq = setup.rq
    envs = setup.envs
    model = setup.model

    agent.train()

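    # Environment workers get a shared-memory CPU copy of the model.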
    shared_model = deepcopy(model)
    shared_model.to('cpu')
    # We'll never need gradients for the target network
    for param in shared_model.parameters():
        param.requires_grad_(False)
        param.share_memory_()
    envs.call('set_model', shared_model, agent._gamma)
    prev_n_updates = agent.n_updates

    n_envs = envs.num_envs
    cp_path = cfg.checkpoint_path
    record_videos = cfg.video is not None
    vwidth = int(cfg.video.size[0]) if record_videos else 0
    vheight = int(cfg.video.size[1]) if record_videos else 0
    max_steps = int(cfg.max_steps)
    obs = envs.reset()
    n_imgs = 0
    collect_img = False
    eval_mode = str(cfg.eval_mode)
    cperf: Dict[str, float] = {}
    running_cperf: Dict[str, float] = defaultdict(float)
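    # Main acting loop; runs until the sample budget is exhausted.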
    while setup.n_samples < max_steps:
        log.debug(f'actor loop {setup.n_samples}')
        if setup.n_samples % cfg.eval.interval == 0:
            # Checkpoint time
            try:
                log.debug(
                    f'Checkpointing to {cp_path} after {setup.n_samples} samples'
                )
                with open(cp_path, 'wb') as f:
                    agent.save_checkpoint(f)
                if cfg.keep_all_checkpoints:
                    p = Path(cp_path)
                    cp_unique_path = str(
                        p.with_name(f'{p.stem}_{setup.n_samples:08d}{p.suffix}')
                    )
                    shutil.copy(cp_path, cp_unique_path)
            except Exception:
                log.exception('Checkpoint saving failed')

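            # Controllability estimates derived from Q-values ('q') and from
            # the reachability model ('r')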
            est = estimate_ctrlb(setup)
            q_cperf = est['q']
            r_cperf = est['r']

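            # Full evaluation rollouts; also run once at startup to seed the
            # running average regardless of eval mode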
            if eval_mode == 'rollouts' or len(running_cperf) == 0:
                agent.eval()
                cperf_new = eval_mfdim(setup, setup.n_samples)
                agent.train()
                if len(running_cperf) == 0:
                    for k, v in cperf_new.items():
                        running_cperf[k] = v
                    del running_cperf['total']
            elif eval_mode == 'running_avg':
                cperf_new = running_cperf
            elif eval_mode == 'q_value':
                cperf_new = q_cperf
            elif eval_mode == 'reachability':
                if not hasattr(model, 'reachability'):
                    log.warning(
                        'Reachability evaluations requested but no reachability model present'
                    )
                cperf_new = r_cperf
            else:
                raise ValueError(f'Unknown evaluation mode {eval_mode}')

            # Evaluation results key goals by ','-joined feature names;
            # rename them to the '+'-joined syntax used in goal_dims
            run_cperf = copy(running_cperf)
            for k in setup.goal_dims.keys():
                flat = k.replace('+', ',')
                if flat == k:
                    continue
                if flat in cperf_new:
                    cperf_new[k] = cperf_new[flat]
                    del cperf_new[flat]
                if flat in run_cperf:
                    run_cperf[k] = run_cperf[flat]
                    del run_cperf[flat]

            if agent.tbw:
                agent.tbw.add_scalars(
                    'Training/GoalsReached', run_cperf, setup.n_samples
                )
                agent.tbw.add_scalars(
                    'Training/CtrlbEstimateQ', q_cperf, setup.n_samples
                )
                agent.tbw.add_scalars(
                    'Training/CtrlbEstimateR', r_cperf, setup.n_samples
                )

            try:
                p = Path(cp_path)
                abs_path = p.with_name(f'{p.stem}_abs.json')
                with open(str(abs_path), 'wt') as ft:
                    json.dump(
                        {
                            'task_map': setup.task_map,
                            'goal_dims': setup.goal_dims,
                            'cperf': cperf_new,
                            'cperf_r': r_cperf,
                            'cperf_q': q_cperf,
                            'cperf_running': run_cperf,
                        },
                        ft,
                    )
                abs_unique_path = p.with_name(
                    f'{p.stem}_{setup.n_samples:08d}_abs.json'
                )
                shutil.copy(str(abs_path), str(abs_unique_path))
            except Exception:
                log.exception('Saving abstraction info failed')

            update_fdist(setup, cperf, cperf_new, setup.n_samples)
            cperf = copy(cperf_new)

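        # Periodically record a short video of a single environment.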
        if record_videos and setup.n_samples % cfg.video.interval == 0:
            collect_img = True
        if collect_img:
            rq.push(
                img=envs.render_single(
                    mode='rgb_array', width=vwidth, height=vheight
                ),
                s_left=[
                    f'Samples {setup.n_samples}',
                    f'Frame {n_imgs}',
                ],
                s_right=[
                    'Train',
                ],
            )
            n_imgs += 1
            if n_imgs > cfg.video.length:
                rq.plot()
                n_imgs = 0
                collect_img = False

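        # sacmt consumes dict observations directly; other agents get them
        # flattened.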
        t_obs = (
            th_flatten(envs.observation_space, obs)
            if cfg.agent.name != 'sacmt'
            else obs
        )
        action, extra = agent.action(envs, t_obs)
        assert (
            extra is None
        ), "Distributed training doesn't work with extra info from action"
        next_obs, reward, done, info = envs.step(action)
        t_next_obs = (
            th_flatten(envs.observation_space, next_obs)
            if cfg.agent.name != 'sacmt'
            else next_obs
        )
        # XXX CPU transfer seems to be necessary :/
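        # Split the batched transition into one chunk per learner queue.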
        nq = len(queues)
        ct_obs = {k: v.cpu().chunk(nq) for k, v in t_obs.items()}
        c_action = action.cpu().chunk(nq)
        ct_next_obs = {k: v.cpu().chunk(nq) for k, v in t_next_obs.items()}
        c_done = done.cpu().chunk(nq)
        c_reward = reward.cpu().chunk(nq)
        pos = 0
        for i, queue in enumerate(queues):
            log.debug(
                f'put {c_action[i].shape[0]} of {action.shape[0]} elems into queue {i}'
            )
            n = c_action[i].shape[0]
            queue.put(
                (
                    {k: v[i] for k, v in ct_obs.items()},
                    c_action[i],
                    extra,
                    (
                        {k: v[i] for k, v in ct_next_obs.items()},
                        c_reward[i],
                        c_done[i],
                        info[pos : pos + n],
                    ),
                )
            )
            pos += n
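        # The local agent sees the full, unchunked transition as well.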
        agent.step(envs, t_obs, action, extra, (t_next_obs, reward, done, info))
        obs = envs.reset_if_done()
        setup.n_samples += n_envs

        # Maintain an exponential moving average (decay 0.9) of goal-reaching
        # success per feature set during training
        for i in range(n_envs):
            if info[i].get('LastStepOfTask', False):
                feats = info[i]['features']
                running_cperf[feats] *= 0.9
                if info[i]['reached_goal']:
                    running_cperf[feats] += 0.1

        # Mirror the updated model into the shared CPU copy
        if agent.n_updates != prev_n_updates:
            with th.no_grad():
                for tp, dp in zip(
                    shared_model.parameters(), model.parameters()
                ):
                    tp.copy_(dp)
            prev_n_updates = agent.n_updates

    # Final checkpoint & eval time
    try:
        log.debug(f'Checkpointing to {cp_path} after {setup.n_samples} samples')
        with open(cp_path, 'wb') as f:
            agent.save_checkpoint(f)
        if cfg.keep_all_checkpoints:
            p = Path(cp_path)
            cp_unique_path = str(
                p.with_name(f'{p.stem}_{setup.n_samples:08d}{p.suffix}')
            )
            shutil.copy(cp_path, cp_unique_path)
    except Exception:
        log.exception('Checkpoint saving failed')

    agent.eval()
    eval_cperf = eval_mfdim(setup, setup.n_samples)
    agent.train()
    est = estimate_ctrlb(setup)
    q_cperf = est['q']
    r_cperf = est['r']
    if eval_mode == 'rollouts':
        cperf_new = eval_cperf
    elif eval_mode == 'running_avg':
        cperf_new = running_cperf
    elif eval_mode == 'q_value':
        cperf_new = q_cperf
    elif eval_mode == 'reachability':
        if not hasattr(model, 'reachability'):
            log.warning(
                'Reachability evaluations requested but no reachability model present'
            )
        cperf_new = r_cperf
    else:
        raise ValueError(f'Unknown evaluation mode {eval_mode}')
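    # Log final per-feature controllability; '*' marks features within
    # ctrl_eps of full controllability.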
    for d in setup.goal_dims:
        suffix = ''
        if cperf_new[d] >= 1.0 - cfg.ctrl_eps:
            suffix = '*'
        log.info(
            f'Features {abstr_name(cfg, d)} at ctrl {cperf_new[d]:.04f}{suffix}'
        )

    try:
        p = Path(cp_path)
        abs_path = p.with_name(f'{p.stem}_abs.json')
        with open(str(abs_path), 'wt') as ft:
            json.dump(
                {
                    'task_map': setup.task_map,
                    'goal_dims': setup.goal_dims,
                    'cperf': cperf_new,
                    'cperf_eval': eval_cperf,
                    'cperf_r': r_cperf,
                    'cperf_q': q_cperf,
                    'cperf_running': running_cperf,
                },
                ft,
            )
        abs_unique_path = p.with_name(
            f'{p.stem}_{setup.n_samples:08d}_abs.json'
        )
        shutil.copy(str(abs_path), str(abs_unique_path))
    except Exception:
        log.exception('Saving abstraction info failed')
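
For context, a minimal sketch of the learner-side counterpart that consumes
these queues. This is hypothetical and not part of pretrain.py; the function
name and the sentinel-based shutdown are assumptions, and only the tuple
layout and the agent.step() call mirror the actor loop above.

def _learner_loop_sketch(setup: TrainingSetup, queue) -> None:
    # Hypothetical consumer: drain transitions that train_loop_mfdim_actor()
    # put into `queue` and feed them to the local agent replica.
    agent = setup.agent
    envs = setup.envs
    agent.train()
    while True:
        item = queue.get()
        if item is None:  # assumed shutdown sentinel
            break
        # Tuple layout matches the actor's queue.put() above:
        # (obs, action, extra, (next_obs, reward, done, info))
        obs, action, extra, result = item
        agent.step(envs, obs, action, extra, result)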