in pretrain.py [0:0]
def _mfdim_is_index_spec(spec: str) -> bool:
    """Return True if *spec* is an explicit index spec such as ``'0#1+2#5'``.

    Groups are separated by ``#`` and indices inside a group by ``+``; every
    index must parse as an ``int``.
    """
    try:
        for group in spec.split('#'):
            # list() forces evaluation: a bare map() is lazy and never raises,
            # which previously made the regex fallback unreachable.
            list(map(int, group.split('+')))
    except ValueError:
        return False
    return True


def _mfdim_select_dims(spec: str, names) -> list:
    """Resolve a feature-dimension spec into a list of dimension strings.

    Args:
        spec: ``'all'``, ``'torso'``, an explicit index spec (``'0#1+2'``),
            or otherwise a regular expression matched (``re.match``) against
            each entry of *names*.
        names: Feature names, indexed by feature dimension.

    Returns:
        Dimension strings; each is a single index or several indices joined
        with ``+``.
    """
    if spec == 'all':
        return [str(i) for i in range(len(names))]
    if spec == 'torso':
        # Convenience alias: root/torso features.
        prefixes = (':', 'torso:', 'root')
        return [str(i) for i, name in enumerate(names) if name.startswith(prefixes)]
    if _mfdim_is_index_spec(spec):
        return spec.split('#')
    pattern = re.compile(spec)
    return [
        str(i) for i, name in enumerate(names) if pattern.match(name) is not None
    ]


def setup_training_mfdim(cfg: DictConfig):
    """Configure multi-feature-dimension pre-training and run ``setup_training``.

    Mutates *cfg* in place:
      * ``cfg.feature_dims`` becomes the ``'#'``-joined list of controllable
        dimension groups selected by the original spec;
      * ``cfg.feature_rank`` is replaced by the group count when it is ``'max'``;
      * ``cfg.env.args`` gains ``robot``, ``feature_dist`` (one entry per
        combination of ``feature_rank`` groups) and ``task_map`` (feature
        index -> dense task id);
      * ``cfg.agent.gamma`` is derived from ``cfg.horizon`` when it is
        ``'auto_horizon'``.

    Returns:
        The object returned by ``setup_training(cfg)``, with ``goal_dims`` and
        ``task_map`` attributes attached.

    Raises:
        ValueError: if no controllable features are selected, the rank is
            smaller than 1, or fewer features remain than the requested rank.
    """
    if not isinstance(cfg.feature_dims, str):
        cfg.feature_dims = str(cfg.feature_dims)
    gs = g_goal_spaces[cfg.features][cfg.robot]
    dims = _mfdim_select_dims(cfg.feature_dims, gs['str'])

    # Drop every dimension group that contains a feature the robot cannot
    # control.
    controllable = []
    for dim in dims:
        if all(
            CtrlgsPreTrainingEnv.feature_controllable(cfg.robot, cfg.features, d)
            for d in map(int, dim.split('+'))
        ):
            controllable.append(dim)
        else:
            log.warning(f'Removing uncontrollable feature {dim}')
    if not controllable:
        # An empty selection would otherwise yield a '' feature key that
        # crashes later in int('').
        raise ValueError(
            f'No controllable features selected by feature_dims={cfg.feature_dims!r}'
        )
    cfg.feature_dims = '#'.join(controllable)

    if cfg.feature_rank == 'max':
        cfg.feature_rank = len(controllable)
    rank = int(cfg.feature_rank)
    if rank < 1:
        raise ValueError(f'feature_rank must be >= 1, got {cfg.feature_rank!r}')
    # Compare against the number of feature groups, not the string length.
    if len(controllable) < rank:
        raise ValueError('Less features to control than the requested rank')

    # Setup custom environment arguments based on the selected robot
    prev_args: Dict[str, Any] = {}
    if isinstance(cfg.env.args, DictConfig):
        prev_args = dict(cfg.env.args)
    cfg.env.args = {**prev_args, 'robot': cfg.robot}

    fdist = {','.join(combo): 1.0 for combo in combinations(controllable, rank)}
    if cfg.task_weighting.startswith('lp'):
        total = len(fdist)
        fdist = {k: v / total for k, v in fdist.items()}

    # Map every individual feature index that appears in any task to a dense id.
    feats: Set[int] = set()
    for fs in fdist:
        feats.update(map(int, fs.replace('+', ',').split(',')))
    task_map: Dict[str, int] = {str(f): i for i, f in enumerate(sorted(feats))}
    cfg.env.args = {
        **cfg.env.args,
        'feature_dist': fdist,
        'task_map': task_map,
    }

    if cfg.agent.gamma == 'auto_horizon':
        cfg.agent.gamma = 1 - 1 / cfg.horizon
        log.info(f'gamma set to {cfg.agent.gamma}')

    setup = setup_training(cfg)
    if 'goal_dims' in cfg.env.args:
        setup.goal_dims = dict(cfg.env.args.goal_dims)
    else:
        setup.goal_dims = dict(cfg.env.args.feature_dist)
    setup.task_map = dict(cfg.env.args.get('task_map', {}))
    return setup