in tzrec/utils/checkpoint_util.py [0:0]
def list_distcp_param(checkpoint_dir: str) -> List[str]:
    """List the tensor parameter names stored in a distributed checkpoint.

    A directory counts as a checkpoint location when it contains a
    ``.metadata`` file. ``checkpoint_dir`` itself is checked first; only when
    it has no ``.metadata`` are its ``model`` and ``optimizer`` sub-directories
    checked (in that order). Every tensor entry found is also written to the
    log together with its size.

    Args:
        checkpoint_dir (str): directory to inspect.

    Returns:
        List[str]: keys of all metadata entries that are
            ``TensorStorageMetadata``, in the order the locations were
            visited; entries of any other kind are skipped.

    Raises:
        RuntimeError: if none of the candidate locations has a ``.metadata``
            file.
    """

    def _has_metadata(directory: str) -> bool:
        # The presence of a ".metadata" file is the only marker checked here.
        return os.path.exists(os.path.join(directory, ".metadata"))

    if _has_metadata(checkpoint_dir):
        distcp_dirs = [checkpoint_dir]
    else:
        # Fall back to the split layout; keep only sub-directories that exist
        # as checkpoints, preserving the model -> optimizer order.
        distcp_dirs = [
            sub_dir
            for sub_dir in (
                os.path.join(checkpoint_dir, "model"),
                os.path.join(checkpoint_dir, "optimizer"),
            )
            if _has_metadata(sub_dir)
        ]

    if not distcp_dirs:
        raise RuntimeError(f"Can't find distribute checkpoint in {checkpoint_dir}")

    tensor_names = []
    for distcp_dir in distcp_dirs:
        metadata = FileSystemReader(path=distcp_dir).read_metadata()
        logger.info(f"Params in {distcp_dir}:")
        for name, entry in metadata.state_dict_metadata.items():
            if not isinstance(entry, TensorStorageMetadata):
                continue
            tensor_names.append(name)
            logger.info(f"{name}: {entry.size}")
    return tensor_names