def estimate_ctrlb()

in pretrain.py [0:0]


def _zero_ctrlb(dims) -> Dict[str, Dict[str, float]]:
    """Return an all-zero estimate with one entry per goal-dimension key."""
    return {'q': {d: 0.0 for d in dims}, 'r': {d: 0.0 for d in dims}}


def _project_features(gs, psi, offset, feats):
    """Project goal-space observations onto the selected abstraction features.

    Computes ``gs @ psi[feats].T + offset[feats]``. All tensors must live on
    the same device.
    # assumes gs is (batch, gsdim) and psi is (num_features, gsdim) -- TODO confirm
    """
    return th.matmul(gs, psi[feats].T) + offset[feats]


def estimate_ctrlb(setup: TrainingSetup) -> Dict[str, Dict[str, float]]:
    """Estimate per-goal-space controllability from the agent's replay buffer.

    For every key ``d`` in ``setup.goal_dims`` (feature indices joined by
    ``'+'`` or ``','``), 1024 start states are drawn from the replay buffer,
    a goal is sampled for each of them, and the policy's mean action is
    scored with the critic and the reachability head.

    Args:
        setup: Training setup providing ``agent``, ``model``, ``cfg``,
            ``envs``, ``goal_dims`` and ``task_map``.

    Returns:
        A dict with two sub-dicts keyed by goal-dimension key:

        * ``'q'``: fraction of samples whose min-of-two Q-value is at least
          ``(dist - ctrl_cost) * gamma ** horizon``, where ``dist`` is the
          L2 distance between sampled goal and start state in feature space.
        * ``'r'``: mean of the reachability output clamped to ``[0, 1]``.

        All values are ``0.0`` if the buffer holds fewer than the agent's
        warm-up samples, has no storage yet, or contains no start states.
    """
    agent = cast(SACMTAgent, setup.agent)
    model = setup.model
    cfg = setup.cfg
    dims = list(setup.goal_dims.keys())

    buffer = agent._buffer
    if buffer.size < agent._warmup_samples or buffer._b is None:
        return _zero_ctrlb(dims)

    # `== True` is an element-wise tensor comparison, not a truthiness test.
    starts = th.where(buffer._b['start_state'] == True)[0]  # noqa: E712
    if starts.shape[0] == 0:
        # th.randint(high=0) would raise; without start states there is
        # nothing to estimate from, so report zeros like the warm-up case.
        return _zero_ctrlb(dims)

    # Resolve the environment class to obtain its abstraction matrices.
    entry_point = gym.envs.registry.spec(cfg.env.name).entry_point
    mod_name, attr_name = entry_point.split(":")
    env_cls = getattr(importlib.import_module(mod_name), attr_name)

    gs_buf = buffer._b['gs_observation']
    gsdim = gs_buf.shape[1]
    psi_np, offset_np = env_cls.abstraction_matrix(
        cfg.robot, cfg.features, gsdim
    )
    delta_feats = env_cls.delta_features(cfg.robot, cfg.features)

    # Inverse mapping (feature space -> goal space): x = w @ psi_1 + offset_1.
    psi_1_np = np.linalg.inv(psi_np)
    offset_1_np = -np.matmul(offset_np, psi_1_np)

    def as_tensor(a):
        return th.tensor(a, dtype=th.float32, device=cfg.device)

    psi = as_tensor(psi_np)
    offset = as_tensor(offset_np)
    psi_1 = as_tensor(psi_1_np)
    offset_1 = as_tensor(offset_1_np)

    # task_idx[position] = feature index, i.e. the inverse of task_map.
    task_map = setup.task_map
    task_idx = [0] * len(task_map)
    for key, pos in task_map.items():
        task_idx[pos] = int(key)

    dscale = agent._gamma ** cfg.horizon
    # NOTE(review): the meaning of the 0.25 factor is not visible here.
    ctrl_cost = (
        cfg.horizon
        * cfg.env.args.ctrl_cost
        * 0.25
        * setup.envs.action_space.shape[0]
    )

    n = 1024
    cperf: Dict[str, Dict[str, float]] = {'q': {}, 'r': {}}
    for d in dims:
        # Query start states from replay buffer
        idx = th.randint(low=0, high=starts.shape[0], size=(n,))
        start_idx = starts[idx]
        obs = buffer._b['obs_observation'][start_idx].to(cfg.device)
        gsobs = gs_buf[start_idx].to(cfg.device)

        # Sample goals and project to input space
        # XXX assumes we train with backprojecting goals
        feats = list(map(int, d.replace('+', ',').split(',')))

        # Joint spaces may sample goals from a density fitted to the buffer;
        # single features (or any other setting) sample uniformly in [-1, 1).
        mode = cfg.estimate_joint_spaces if len(feats) > 1 else None
        if mode in ('gmm', 'kmeans'):
            sidx = th.randint(low=0, high=buffer.size, size=(n * 10,))
            # Move the rows to cfg.device first: psi/offset live there and
            # the buffer may be stored on a different device.
            sample = _project_features(
                gs_buf[sidx].to(cfg.device), psi, offset, feats
            ).cpu()
            if mode == 'gmm':
                clf = GaussianMixture(
                    n_components=32,
                    max_iter=100,
                    n_init=10,
                    covariance_type='full',
                )
                clf.fit(sample)
                drawn = clf.sample(n)[0]
            else:
                clf = KMeans(n_clusters=n)
                clf.fit(sample)
                drawn = clf.cluster_centers_
            wgoal = th.tensor(
                drawn.clip(-1, 1), device=obs.device, dtype=th.float32
            )
        else:
            wgoal = th.rand(size=(n, len(feats)), device=obs.device) * 2 - 1

        # Start states in feature space; also used for the distance below.
        wobs = _project_features(gsobs, psi, offset, feats)

        # Delta features are goals relative to the current feature value.
        for i, f in enumerate(feats):
            if f in delta_feats:
                wgoal[:, i] += wobs[:, i]

        # One-hot (bag-of-words) encoding of the controlled features; it is
        # both the task input and the mask applied to the goal vector.
        task_mask = th.zeros(len(task_map), device=cfg.device)
        for f in feats:
            task_mask[task_map[str(f)]] = 1

        # Back-project the goal and express it relative to the start state.
        s = gsobs[:, task_idx]
        gb = (th.matmul(wgoal, psi_1[feats]) + offset_1)[:, task_idx]
        goal = (gb - s) * task_mask

        # Record distances in goal space
        dist = th.linalg.norm(wgoal - wobs, ord=2, dim=1)

        # Finally, the bow task input
        task = task_mask.unsqueeze(0).expand(n, len(task_map))

        # Query mean action and corresponding Q-value
        inputs = {'observation': obs, 'task': task, 'desired_goal': goal}
        with th.no_grad():
            action = model.pi(dict(inputs)).mean
            q = model.q(dict(inputs, action=action))
            r = model.reachability(dict(inputs, action=action))

        q_min = th.min(q[:, 0], q[:, 1]).view(-1)
        cperf['q'][d] = (
            q_min >= (dist - ctrl_cost) * dscale
        ).sum().item() / n
        cperf['r'][d] = r.clamp(0, 1).mean().item()

    return cperf