in src/nanotron/serialize/main.py [0:0]
def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
    """Parse checkpoint path from config and download checkpoint from S3 if needed.

    Resolution of ``config.checkpoints.resume_checkpoint_path``:

    * Local path with a ``latest.txt``: the file holds an iteration number and the
      checkpoint is ``<resume_checkpoint_path>/<iteration>``.
    * Local path with a ``model_config.json``: the path itself is the checkpoint.
    * Non-local path with a ``latest.txt``: ``<resume_checkpoint_path>/<iteration>`` is
      downloaded into ``<checkpoints_path>/<iteration>``.
    * Non-local path that exists: it is downloaded into
      ``<checkpoints_path>/<resume_checkpoint_path.name>``.

    Args:
        config: Config object. Reads ``config.checkpoints`` and, for non-local
            checkpoints, the s5cmd settings in ``config.s3_upload``.
        parallel_context: Parallel context whose ``world_pg`` process group is passed to
            ``S3Mover.distributed_wait_for_completion`` around the download.

    Returns:
        Path to the local checkpoint, or None if no resume path is configured or no
        checkpoint was found.
    """
    resume_path = config.checkpoints.resume_checkpoint_path
    if resume_path is None:
        return None

    latest_meta_path: xPath = resume_path / "latest.txt"

    if check_path_is_local(resume_path):
        if latest_meta_path.exists():
            # latest.txt points at the most recent iteration sub-folder
            with fs_open(latest_meta_path, mode="r") as fi:
                latest_iteration = int(fi.read())
            checkpoint_path = resume_path / str(latest_iteration)
        elif (resume_path / "model_config.json").exists():
            # we assume that the checkpoint path is a path to a checkpoint
            checkpoint_path = resume_path
        else:
            log_rank(
                f"No previous checkpoint found in: {latest_meta_path}",
                logger=logger,
                level=logging.INFO,
                rank=0,
            )
            return None
        log_rank(
            f"Loading checkpoint from {checkpoint_path}",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )
        return checkpoint_path

    # Non-local (S3) checkpoint: resolve the remote source and the local destination.
    if latest_meta_path.exists():
        # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint
        with fs_open(latest_meta_path, mode="r") as fi:
            latest_iteration = int(fi.read())
        s3_path = resume_path / str(latest_iteration)  # load_path
        checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration)  # save_path
    elif resume_path.exists():
        # we assume that the checkpoint path is a path to a checkpoint
        s3_path = resume_path  # load_path
        checkpoint_path = config.checkpoints.checkpoints_path / resume_path.name  # save_path
    else:
        log_rank(
            f"No previous checkpoint found in: {resume_path}\n Initializing from scratch.",
            logger=logger,
            level=logging.WARNING,
            rank=0,
        )
        return None

    log_rank(
        f"Downloading checkpoint from S3 in {checkpoint_path} ",
        logger=logger,
        level=logging.WARNING,
        rank=0,
    )
    # Only LOCAL_RANK 0 performs the real transfer; other ranks get a dummy mover.
    # An unset LOCAL_RANK (e.g. a plain single-process run) is treated as rank 0
    # instead of crashing on int(None).
    is_dummy = int(os.environ.get("LOCAL_RANK", "0")) != 0
    s3_mover = S3Mover(
        local_path=os.path.join(checkpoint_path),
        s3_path=os.path.join(s3_path),
        s5cmd_numworkers=config.s3_upload.s5cmd_numworkers,
        s5cmd_concurrency=config.s3_upload.s5cmd_concurrency,
        s5cmd_path=config.s3_upload.s5cmd_path,
        dummy=is_dummy,
    )
    # NOTE(review): presumably distributed_wait_for_completion acts as a barrier over
    # world_pg, so all ranks line up before and after the download. Confirm in S3Mover.
    s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
    s3_mover.start_downloading()
    s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
    return checkpoint_path