in pretrain.py [0:0]
def _ctrlb_project(
    gs: th.Tensor, psi: th.Tensor, offset: th.Tensor, feats: list
) -> th.Tensor:
    """Map goal-space states ``gs`` onto the world features listed in ``feats``.

    Computes ``gs @ psi[feats].T + offset[feats]`` row by row via ``th.bmm``.
    ``gs`` is first moved to ``psi``'s device, so callers may pass replay-buffer
    tensors that live on a different device than the model.
    """
    gs = gs.to(psi.device)
    n = gs.shape[0]
    proj = psi[feats].T.unsqueeze(0).expand(n, -1, -1)
    return th.bmm(gs.unsqueeze(1), proj).squeeze(1) + offset[feats]


def _ctrlb_sample_joint_goals(
    gs_buffer: th.Tensor,
    buffer_size: int,
    psi: th.Tensor,
    offset: th.Tensor,
    feats: list,
    n: int,
    method: str,
) -> th.Tensor:
    """Sample ``n`` joint goals over ``feats`` from a density fitted to buffer states.

    ``10 * n`` buffered goal-space states are drawn uniformly and projected onto
    ``feats``. ``method='gmm'`` fits a 32-component full-covariance Gaussian
    mixture and samples from it; ``method='kmeans'`` uses the ``n`` k-means
    cluster centres. Goals are clipped to [-1, 1] and returned as float32 on
    ``psi``'s device.

    Raises:
        ValueError: if ``method`` is neither ``'gmm'`` nor ``'kmeans'``.
    """
    sidx = th.randint(low=0, high=buffer_size, size=(n * 10,))
    # sklearn needs host memory.
    sample = _ctrlb_project(gs_buffer[sidx], psi, offset, feats).cpu()
    if method == 'gmm':
        clf = GaussianMixture(
            n_components=32, max_iter=100, n_init=10, covariance_type='full'
        )
        clf.fit(sample)
        goals = clf.sample(n)[0]
    elif method == 'kmeans':
        clf = KMeans(n_clusters=n)
        clf.fit(sample)
        goals = clf.cluster_centers_
    else:
        raise ValueError(f'unknown joint-space estimator: {method!r}')
    return th.tensor(goals.clip(-1, 1), device=psi.device, dtype=th.float32)


def estimate_ctrlb(setup: TrainingSetup) -> Dict[str, Dict[str, float]]:
    """Probe the agent's critic and reachability heads on sampled goals.

    For every goal-dimension key ``d`` in ``setup.goal_dims`` (feature ids joined
    by ``','`` or ``'+'``):

    * draw ``n = 1024`` start states from the replay buffer;
    * sample world-space goals for those features:
      - uniformly in [-1, 1] for a single feature, or when
        ``cfg.estimate_joint_spaces`` is not ``'gmm'``/``'kmeans'``;
      - otherwise from a GMM / KMeans fitted to buffered states;
      - delta features are shifted by the start state's value;
    * back-project the goals into goal space and form a masked
      goal-minus-state input;
    * query the policy mean action, the minimum of the two Q-value columns,
      and the reachability head.

    Returns:
        ``{'q': {d: frac}, 'r': {d: mean}}``. Here ``frac`` is the fraction of
        samples with ``Q >= (dist - ctrl_cost) * gamma**horizon``; ``dist`` is
        the L2 world-space distance between start state and goal. ``mean`` is
        the average reachability clamped to [0, 1]. All entries are 0.0 when
        the buffer is still warming up or holds no start states.
    """
    agent = cast(SACMTAgent, setup.agent)
    model = setup.model
    cfg = setup.cfg
    buffer = agent._buffer

    def zero_perf() -> Dict[str, Dict[str, float]]:
        dims = setup.goal_dims.keys()
        return {'q': {d: 0.0 for d in dims}, 'r': {d: 0.0 for d in dims}}

    if buffer.size < agent._warmup_samples or buffer._b is None:
        return zero_perf()
    # Buffer indices flagged as start states; the loop below samples from them.
    starts = th.where(buffer._b['start_state'] == True)[0]  # noqa: E712
    if starts.numel() == 0:
        # th.randint(high=0) would raise; there is nothing to evaluate yet.
        return zero_perf()

    # Resolve the registered env class to obtain its goal-space abstraction.
    entry_point = gym.envs.registry.spec(cfg.env.name).entry_point
    mod_name, attr_name = entry_point.split(":")
    env_cls = getattr(importlib.import_module(mod_name), attr_name)
    gs_buffer = buffer._b['gs_observation']
    gsdim = gs_buffer.shape[1]
    psi, offset = env_cls.abstraction_matrix(cfg.robot, cfg.features, gsdim)
    delta_feats = env_cls.delta_features(cfg.robot, cfg.features)
    psi_1 = np.linalg.inv(psi)
    offset_1 = -np.matmul(offset, psi_1)
    psi = th.tensor(psi, dtype=th.float32, device=cfg.device)
    offset = th.tensor(offset, dtype=th.float32, device=cfg.device)
    psi_1 = th.tensor(psi_1, dtype=th.float32, device=cfg.device)
    offset_1 = th.tensor(offset_1, dtype=th.float32, device=cfg.device)

    task_map = setup.task_map
    # task_idx[slot] is the goal-space column (int form of the key) for that slot.
    task_idx = [0] * len(task_map)
    for k, v in task_map.items():
        task_idx[v] = int(k)

    dscale = agent._gamma ** cfg.horizon
    ctrl_cost = (
        cfg.horizon
        * cfg.env.args.ctrl_cost
        * 0.25
        * setup.envs.action_space.shape[0]
    )
    n = 1024
    cperf: Dict[str, Dict[str, float]] = {'q': {}, 'r': {}}
    for d in setup.goal_dims.keys():
        feats = [int(f) for f in d.replace('+', ',').split(',')]
        # Query start states from replay buffer
        idx = th.randint(low=0, high=starts.shape[0], size=(n,))
        obs = buffer._b['obs_observation'][starts[idx]].to(cfg.device)
        gsobs = gs_buffer[starts[idx]].to(cfg.device)

        # Sample goals in world-feature space.
        # XXX assumes we train with backprojecting goals
        if len(feats) > 1 and cfg.estimate_joint_spaces in ('gmm', 'kmeans'):
            wgoal = _ctrlb_sample_joint_goals(
                gs_buffer,
                buffer.size,
                psi,
                offset,
                feats,
                n,
                cfg.estimate_joint_spaces,
            )
        else:
            wgoal = th.rand(size=(n, len(feats)), device=obs.device) * 2 - 1

        # World-space features of the start states.
        ws = _ctrlb_project(gsobs, psi, offset, feats)
        # Goals for delta features are relative to the start state.
        for i, f in enumerate(feats):
            if f in delta_feats:
                wgoal[:, i] += ws[:, i]

        # One-hot task vector over the selected features; also the goal mask.
        task = th.zeros(len(task_map), device=cfg.device)
        for f in feats:
            task[task_map[str(f)]] = 1

        # Back-project goals into goal space.
        # NOTE(review): the forward map uses psi[feats].T while this inverse
        # uses psi_1[feats] (rows of psi^-1); the two agree only if psi is
        # symmetric (e.g. diagonal) -- confirm for all abstraction matrices.
        s = gsobs[:, task_idx]
        gb = (
            th.bmm(
                wgoal.unsqueeze(1),
                psi_1[feats].unsqueeze(0).expand(n, -1, -1),
            ).squeeze(1)
            + offset_1
        )[:, task_idx]
        goal = (gb - s) * task
        # Record distances in world-feature space.
        dist = th.linalg.norm(wgoal - ws, ord=2, dim=1)
        task = task.unsqueeze(0).expand(n, len(task_map))

        # Query mean action and corresponding Q-value / reachability.
        base = {'observation': obs, 'task': task, 'desired_goal': goal}
        with th.no_grad():
            action = model.pi(dict(base)).mean
            q = model.q({**base, 'action': action})
            r = model.reachability({**base, 'action': action})
        # Minimum over the two Q-value columns.
        q = th.min(q[:, 0], q[:, 1]).view(-1)
        cperf['q'][d] = (q >= (dist - ctrl_cost) * dscale).sum().item() / n
        cperf['r'][d] = r.clamp(0, 1).mean().item()
    return cperf