lerobot/common/utils/wandb_utils.py (108 lines of code) (raw):

#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
from glob import glob
from pathlib import Path

from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from termcolor import colored

from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig


def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
    """Return a group name for logging. Optionally returns group name as list.

    The group is built from the policy type and the seed, followed by the dataset
    repo id and the env type when those config sections are set.

    Args:
        cfg: Training pipeline configuration.
        return_list: If True, return the individual "key:value" tags as a list;
            otherwise return them joined with "-".
    """
    lst = [
        f"policy:{cfg.policy.type}",
        f"seed:{cfg.seed}",
    ]
    if cfg.dataset is not None:
        lst.append(f"dataset:{cfg.dataset.repo_id}")
    if cfg.env is not None:
        lst.append(f"env:{cfg.env.type}")
    return lst if return_list else "-".join(lst)


def get_wandb_run_id_from_filesystem(log_dir: Path) -> str:
    """Recover the WandB run ID from the `wandb/latest-run` directory under `log_dir`.

    Exactly one entry matching `run-*` is expected in that directory, and its file
    name must look like `run-<id>.wandb`.

    Raises:
        RuntimeError: If zero or several candidates are found, or if the file name
            does not have the expected form.
    """
    # Get the WandB run ID.
    paths = glob(str(log_dir / "wandb/latest-run/run-*"))
    if len(paths) != 1:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    # `Path(...).name` instead of splitting on "/" so that platform-specific separators
    # returned by glob are handled. The dot before "wandb" is escaped so that it only
    # matches a literal extension separator.
    match = re.search(r"run-([^\.]+)\.wandb", Path(paths[0]).name)
    if match is None:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    return match.group(1)


def get_safe_wandb_artifact_name(name: str) -> str:
    """WandB artifacts don't accept ":" or "/" in their name."""
    return name.replace(":", "_").replace("/", "_")


class WandBLogger:
    """A helper class to log object using wandb."""

    def __init__(self, cfg: TrainPipelineConfig):
        """Start (or resume) a WandB run from the training config.

        Side effects: sets the `WANDB_SILENT` environment variable, calls `wandb.init`,
        and overwrites `cfg.wandb.run_id` with the ID of the run that was started.
        """
        self.cfg = cfg.wandb
        self.log_dir = cfg.output_dir
        self.job_name = cfg.job_name
        self.env_fps = cfg.env.fps if cfg.env else None
        self._group = cfg_to_group(cfg)

        # Set up WandB.
        os.environ["WANDB_SILENT"] = "True"
        # Imported here, after WANDB_SILENT is set, rather than at module level.
        import wandb

        # An explicit run id in the config wins; otherwise, when resuming, look up the
        # id of the latest run on disk; otherwise let wandb pick a fresh one.
        if cfg.wandb.run_id:
            wandb_run_id = cfg.wandb.run_id
        elif cfg.resume:
            wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir)
        else:
            wandb_run_id = None

        wandb.init(
            id=wandb_run_id,
            project=self.cfg.project,
            entity=self.cfg.entity,
            name=self.job_name,
            notes=self.cfg.notes,
            tags=cfg_to_group(cfg, return_list=True),
            dir=self.log_dir,
            config=cfg.to_dict(),
            # TODO(rcadene): try set to True
            save_code=False,
            # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
            job_type="train_eval",
            resume="must" if cfg.resume else None,
            # Any value other than the three listed modes falls back to "online".
            mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
        )
        run_id = wandb.run.id
        # NOTE: We will override the cfg.wandb.run_id with the wandb run id.
        # This is because we want to be able to resume the run from the wandb run id.
        cfg.wandb.run_id = run_id
        # Handle custom step key for rl asynchronous training.
        # Holds the mode-prefixed names ("<mode>/<key>") already registered with
        # `define_metric`; stays None until a custom step key is first used.
        self._wandb_custom_step_key: set[str] | None = None
        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
        logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
        self._wandb = wandb

    def log_policy(self, checkpoint_dir: Path):
        """Checkpoints the policy to wandb.

        Does nothing when `disable_artifact` is set in the wandb config. The artifact is
        named after the run group and the checkpoint directory name, and contains the
        single safetensors file found under the pretrained-model subdirectory.
        """
        if self.cfg.disable_artifact:
            return

        step_id = checkpoint_dir.name
        artifact_name = f"{self._group}-{step_id}"
        artifact_name = get_safe_wandb_artifact_name(artifact_name)
        artifact = self._wandb.Artifact(artifact_name, type="model")
        artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
        self._wandb.log_artifact(artifact)

    def log_dict(
        self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
    ):
        """Log every scalar/string entry of `d` under the "<mode>/" prefix.

        Args:
            d: Metrics to log. Values that are not `int`, `float` or `str` are skipped
                with a warning.
            step: WandB step to log at. Ignored when `custom_step_key` is given.
            mode: Either "train" or "eval"; used as the key prefix.
            custom_step_key: Key of `d` whose value is logged alongside each metric
                as "<mode>/<custom_step_key>" instead of passing `step`.

        Raises:
            ValueError: If `mode` is invalid, or if neither `step` nor
                `custom_step_key` is provided.
            KeyError: If `custom_step_key` is given but missing from `d` (raised when
                the first loggable entry is reached).
        """
        if mode not in {"train", "eval"}:
            raise ValueError(mode)
        if step is None and custom_step_key is None:
            raise ValueError("Either step or custom_step_key must be provided.")

        # NOTE: This is not simple. Wandb step must always monotonically increase and it
        # increases with each wandb.log call, but in the case of asynchronous RL for example,
        # multiple time steps is possible. For example, the interaction step with the environment,
        # the training step, the evaluation step, etc. So we need to define a custom step key
        # to log the correct step for each metric.
        if custom_step_key is not None:
            if self._wandb_custom_step_key is None:
                self._wandb_custom_step_key = set()
            new_custom_key = f"{mode}/{custom_step_key}"
            if new_custom_key not in self._wandb_custom_step_key:
                self._wandb_custom_step_key.add(new_custom_key)
                self._wandb.define_metric(new_custom_key, hidden=True)

        for k, v in d.items():
            if not isinstance(v, (int, float, str)):
                logging.warning(
                    f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
                )
                continue

            # Do not log the custom step key itself.
            # Registered keys are stored with the mode prefix, so the prefixed name must
            # be compared; comparing the raw key would never match.
            if self._wandb_custom_step_key is not None and f"{mode}/{k}" in self._wandb_custom_step_key:
                continue

            if custom_step_key is not None:
                value_custom_step = d[custom_step_key]
                data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
                self._wandb.log(data)
                continue

            self._wandb.log(data={f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):
        """Log the mp4 at `video_path` as "<mode>/video" at the given step.

        Raises:
            ValueError: If `mode` is neither "train" nor "eval".
        """
        if mode not in {"train", "eval"}:
            raise ValueError(mode)

        # `self.env_fps` is None when the config had no env section.
        wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
        self._wandb.log({f"{mode}/video": wandb_video}, step=step)