def parse_ckpt_path()

in src/nanotron/serialize/main.py [0:0]


def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
    """Resolve the checkpoint to resume from, downloading it from S3 if needed.

    Args:
        config: Config object. Reads ``config.checkpoints.resume_checkpoint_path``,
            ``config.checkpoints.checkpoints_path`` and, for non-local resume paths,
            the ``config.s3_upload`` s5cmd settings.
        parallel_context: Parallel context whose ``world_pg`` process group is passed
            to the S3 mover to synchronize all ranks around the download.

    Returns:
        Path to the checkpoint directory to load from, or None if no resume path is
        configured or no checkpoint was found at the configured location.
    """
    resume_path = config.checkpoints.resume_checkpoint_path
    if resume_path is None:
        return None

    if check_path_is_local(resume_path):
        return _resolve_local_checkpoint(resume_path)
    return _download_remote_checkpoint(config, parallel_context, resume_path)


def _resolve_local_checkpoint(resume_path) -> Optional[Path]:
    """Resolve a local resume path to a concrete checkpoint directory, or None if absent."""
    latest_meta_path: xPath = resume_path / "latest.txt"
    if latest_meta_path.exists():
        # ``latest.txt`` stores the iteration number of the most recent checkpoint,
        # which lives in a sub-directory of ``resume_path`` named after that number.
        # TODO @thomasw21: make a better structure system so that we get typing correct
        with fs_open(latest_meta_path, mode="r") as fi:
            latest_iteration = int(fi.read())
        checkpoint_path = resume_path / str(latest_iteration)
    elif (resume_path / "model_config.json").exists():
        # we assume that the checkpoint path is a path to a checkpoint
        checkpoint_path = resume_path
    else:
        log_rank(
            f"No previous checkpoint found in: {latest_meta_path}",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )
        return None

    log_rank(
        f"Loading checkpoint from {checkpoint_path}",
        logger=logger,
        level=logging.INFO,
        rank=0,
    )
    return checkpoint_path


def _download_remote_checkpoint(config: Config, parallel_context: ParallelContext, resume_path) -> Optional[Path]:
    """Download a remote (S3) checkpoint into ``config.checkpoints.checkpoints_path``.

    Returns the local destination directory, or None if nothing exists at ``resume_path``.
    """
    latest_meta_path = resume_path / "latest.txt"
    if latest_meta_path.exists():
        # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint
        with fs_open(latest_meta_path, mode="r") as fi:
            latest_iteration = int(fi.read())
        s3_path = resume_path / str(latest_iteration)  # load_path
        checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration)  # save_path
    elif resume_path.exists():
        # we assume that the checkpoint path is a path to a checkpoint
        s3_path = resume_path  # load_path
        checkpoint_path = config.checkpoints.checkpoints_path / resume_path.name  # save_path
    else:
        log_rank(
            f"No previous checkpoint found in: {resume_path}\n Initializing from scratch.",
            logger=logger,
            level=logging.WARNING,
            rank=0,
        )
        return None

    log_rank(
        f"Downloading checkpoint from S3 in {checkpoint_path} ",
        logger=logger,
        level=logging.WARNING,
        rank=0,
    )
    # Default to local rank 0 when LOCAL_RANK is unset: ``int(None)`` would raise TypeError.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    # Processes with a non-zero local rank get ``dummy=True``; presumably a no-op mover,
    # so only one process per node transfers data — confirm against S3Mover.
    s3_mover = S3Mover(
        local_path=os.path.join(checkpoint_path),
        s3_path=os.path.join(s3_path),
        s5cmd_numworkers=config.s3_upload.s5cmd_numworkers,
        s5cmd_concurrency=config.s3_upload.s5cmd_concurrency,
        s5cmd_path=config.s3_upload.s5cmd_path,
        dummy=local_rank != 0,
    )
    # Synchronize every rank on the world process group before and after the download,
    # so no rank proceeds to load before the files are in place.
    s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
    s3_mover.start_downloading()
    s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
    return checkpoint_path