in mtrl/logger.py [0:0]
def __init__(self, log_dir, config, retain_logs: bool = False):
    """Create one ``MetersGroup`` per metric group declared in ``config.metrics``.

    The number of environments is derived from the config and used to expand
    per-environment metric templates (format keys ending in ``"_"``).

    Args:
        log_dir: Directory joined with ``"<group>.log"`` to form the path
            handed to each ``MetersGroup``.
        config: Experiment config. It is read both by attribute
            (``config.env.name``) and by key (``config.env[mode]``); presumably
            an OmegaConf/Hydra ``DictConfig`` -- confirm against callers.
        retain_logs: Forwarded unchanged to every ``MetersGroup``.

    Raises:
        ValueError: For a metaworld env whose benchmark name (second dotted
            component of ``config.env.benchmark._target_``) contains no digits,
            so the number of environments cannot be inferred.
    """
    self._log_dir = log_dir
    self.config = config

    if "metaworld" in self.config.env.name:
        # The env count is encoded in the benchmark name, e.g. a target of
        # "metaworld.MT10" yields 10.
        target = self.config.env.benchmark._target_
        benchmark_name = target.split(".")[1]
        digits = "".join(char for char in benchmark_name if char.isdigit())
        if not digits:
            # Without this check int("") raises an opaque "invalid literal" error.
            raise ValueError(
                "Cannot infer the number of environments from benchmark "
                f"target {target!r}: its name contains no digits."
            )
        num_envs = int(digits)
    else:
        env_list: List[str] = []
        for key in self.config.metrics:
            if "_" in key:
                # Split on the first underscore only: "eval_seen_envs" maps to
                # env["eval"]["seen_envs"] instead of failing to unpack.
                mode, submode = key.split("_", 1)
                # TODO: should we instead throw an error when the group is not
                # found in config.env? It is currently skipped silently.
                if mode in self.config.env and submode in self.config.env[mode]:
                    env_list += self.config.env[mode][submode]
            elif key in self.config.env:
                env_list += self.config.env[key]
        # The same env may be listed under several groups; count it once.
        num_envs = len(set(env_list))

    def _get_formatting(
        current_formatting: List[List[str]],
    ) -> Dict[str, List[str]]:
        """Map each metric key to its remaining format fields.

        Each entry of ``current_formatting`` is ``[key, abbr, *other_fields]``;
        ``abbr`` may be ``None``. When ``num_envs > 0``, every key ending in
        ``"_"`` is a per-environment template: it is removed and replaced by
        ``num_envs`` keys suffixed ``0 .. num_envs - 1`` (appended at the end),
        with a non-``None`` abbreviation suffixed the same way.
        """
        formatting: Dict[str, List[str]] = {
            entry[0]: entry[1:] for entry in current_formatting
        }
        if num_envs > 0:
            # Snapshot the keys: the dict is popped from and added to below.
            for key in list(formatting):
                if not key.endswith("_"):
                    continue
                value = formatting.pop(key)
                for index in range(num_envs):
                    abbr = None if value[0] is None else value[0] + str(index)
                    formatting[key + str(index)] = [abbr, *value[1:]]
        return formatting

    # One meters group per metric group, each given "<log_dir>/<key>.log".
    self.mgs = {
        key: MetersGroup(
            os.path.join(log_dir, f"{key}.log"),
            formating=_get_formatting(current_formatting=value),
            mode=key,
            retain_logs=retain_logs,
        )
        for key, value in self.config.metrics.items()
    }