dataflux_pytorch/lightning/gcs_filesystem.py (109 lines of code) (raw):

import io
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, Union

import torch.distributed as dist
from dataflux_core import user_agent
from google.cloud import storage
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

from dataflux_pytorch.dataflux_checkpoint import DatafluxCheckpointBuffer
from dataflux_pytorch.lightning.path_utils import parse_gcs_path


class GCSFileSystem():
    """Filesystem-style adapter that maps checkpoint file operations to GCS.

    Every path-taking method resolves its argument with ``parse_gcs_path``
    into a ``(bucket, object path)`` pair and then operates on the
    corresponding blob through ``self.storage_client``.
    """

    def __init__(
        self,
        project_name: str,
        debug: bool,
        storage_client: Optional[storage.Client] = None,
    ):
        """Store the configuration and make sure a storage client exists.

        Args:
            project_name: Project passed to ``storage.Client`` when no client
                is supplied by the caller.
            debug: When true, ``create_stream`` prints a line describing each
                stream it opens.
            storage_client: Optional pre-built client. A new
                ``storage.Client(project=project_name)`` is created when this
                is ``None``.
        """
        self.project_name = project_name
        self.storage_client = storage_client
        if storage_client is None:
            self.storage_client = storage.Client(project=self.project_name)
        # Tag the client (caller-supplied or freshly created) with the
        # Dataflux user agent.
        user_agent.add_dataflux_user_agent(self.storage_client)
        self.debug = debug

    @staticmethod
    def _debug_rank() -> Union[int, str]:
        """Return the distributed rank for debug output, or "N/A".

        ``dist.get_rank()`` raises when no default process group has been
        initialized, so debug logging must not call it unconditionally;
        otherwise enabling ``debug`` would break single-process use.
        """
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return "N/A"

    @contextmanager
    def create_stream(self, path: Union[str, os.PathLike],
                      mode: str) -> Generator[io.IOBase, None, None]:
        """Yield a binary stream for the GCS object at ``path``.

        Args:
            path: GCS location, resolved with ``parse_gcs_path``.
            mode: ``"wb"`` yields a ``DatafluxCheckpointBuffer`` bound to the
                blob; ``"rb"`` downloads the whole object into an in-memory
                ``io.BytesIO`` and yields it positioned at offset 0.

        Raises:
            ValueError: If ``mode`` is neither ``"rb"`` nor ``"wb"``.
        """
        bucket, path = parse_gcs_path(path)
        blob = self.storage_client.bucket(bucket).blob(path)
        if mode == "wb":  # write mode.
            if self.debug:
                print(
                    f"Creating Stream, Write Mode: Rank: {self._debug_rank()} Bucket: {bucket} path: {path}"
                )
            with DatafluxCheckpointBuffer(blob) as stream:
                yield stream
        elif mode == "rb":  # read mode.
            if self.debug:
                print(
                    f"Creating Stream, Read Mode: Rank: {self._debug_rank()} Bucket: {bucket} path: {path}"
                )
            stream = io.BytesIO()
            blob.download_to_file(stream)
            # Rewind so the consumer reads from the start of the object.
            stream.seek(0)
            yield stream
        else:
            raise ValueError(
                "Invalid mode argument, create_stream only supports rb (read mode) & wb (write mode)"
            )

    def concat_path(self, path: Union[str, os.PathLike],
                    suffix: str) -> Union[str, os.PathLike]:
        """Return ``path`` joined with ``suffix`` as a ``pathlib.Path``."""
        if not isinstance(path, Path):
            path = Path(path)
        return path / suffix

    def init_path(
            self, path: Union[str,
                              os.PathLike]) -> Union[str, os.PathLike]:
        """Return ``path`` as a ``pathlib.Path``, converting it if needed."""
        if not isinstance(path, Path):
            path = Path(path)
        return path

    def rename(self, path: Union[str, os.PathLike],
               new_path: Union[str, os.PathLike]) -> None:
        """Rename the object at ``path`` to ``new_path`` in the same bucket.

        Raises:
            Exception: If the two paths resolve to different buckets.
        """
        old_bucket, old_path = parse_gcs_path(path)
        new_bucket, new_path = parse_gcs_path(new_path)
        if old_bucket != new_bucket:
            raise Exception(
                f"When renaming objects, the old bucket name (got: {old_bucket}) must be the same as the new bucket name (got: {new_bucket})"
            )
        blob = self.storage_client.bucket(old_bucket).blob(old_path)
        self.storage_client.bucket(new_bucket).rename_blob(blob, new_path)

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        """Do nothing.

        NOTE(review): presumably a no-op because GCS object names need no
        directory to be created first - confirm.
        """
        pass

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        """Return whether the blob at ``path`` exists."""
        bucket, path = parse_gcs_path(path)
        blob = self.storage_client.bucket(bucket).blob(path)
        return blob.exists()

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        """Delete the blob at ``path``."""
        bucket, path = parse_gcs_path(path)
        blob = self.storage_client.bucket(bucket).blob(path)
        blob.delete()

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str,
                                                         os.PathLike]) -> bool:
        """Validate ``checkpoint_id`` as a checkpoint location.

        ``Path`` instances are accepted without further checks. Any other
        value is handed to ``parse_gcs_path``; this method never returns
        ``False`` itself, so an invalid value surfaces as whatever exception
        ``parse_gcs_path`` raises.
        """
        if isinstance(checkpoint_id, Path):
            return True
        # parse_gcs_path will raise an exception if the path is not valid.
        parse_gcs_path(checkpoint_id)
        return True


class GCSDistributedWriter(FileSystemWriter):
    """``FileSystemWriter`` whose filesystem backend is a ``GCSFileSystem``."""

    def __init__(self,
                 path,
                 project_name: str,
                 storage_client: Optional[storage.Client] = None,
                 debug: Optional[bool] = False,
                 **kwargs):
        """Initialize the parent writer, then install the GCS backend.

        Args:
            path: Checkpoint location forwarded to ``FileSystemWriter``.
            project_name: Project forwarded to ``GCSFileSystem``.
            storage_client: Optional client forwarded to ``GCSFileSystem``.
            debug: Debug flag forwarded to ``GCSFileSystem``.
            **kwargs: Extra keyword arguments for ``FileSystemWriter``.
        """
        super().__init__(path, **kwargs)
        self.fs = GCSFileSystem(project_name=project_name,
                                storage_client=storage_client,
                                debug=debug)
        # Forced off regardless of kwargs.
        # NOTE(review): presumably because the GCS write stream has no local
        # file descriptor to sync - confirm.
        self.sync_files = False


class GCSDistributedReader(FileSystemReader):
    """``FileSystemReader`` whose filesystem backend is a ``GCSFileSystem``."""

    def __init__(self,
                 path: Union[str, os.PathLike],
                 project_name: str,
                 storage_client: Optional[storage.Client] = None,
                 debug: Optional[bool] = False,
                 **kwargs):
        """Initialize the parent reader, then install the GCS backend.

        Args:
            path: Checkpoint location forwarded to ``FileSystemReader``.
            project_name: Project forwarded to ``GCSFileSystem``.
            storage_client: Optional client forwarded to ``GCSFileSystem``.
            debug: Debug flag forwarded to ``GCSFileSystem``.
            **kwargs: Extra keyword arguments for ``FileSystemReader``.
        """
        super().__init__(path, **kwargs)
        self.fs = GCSFileSystem(project_name=project_name,
                                storage_client=storage_client,
                                debug=debug)