oss-torch-connector/osstorchconnector/oss_checkpoint.py (32 lines of code) (raw):

from ._oss_bucket_iterable import parse_oss_uri
from ._oss_client import OssClient, DataObject
from ctypes import *
from typing import Any


class OssCheckpoint:
    """Entry point for loading and saving checkpoints stored in OSS.

    An instance wraps a single `OssClient`. For a given oss_uri
    (oss://<BUCKET>/<KEY>) it hands out `DataObject` streams:

    * `reader(oss_uri)` gives a `DataObject` to read an existing checkpoint,
      suitable for passing to torch.load.
    * `writer(oss_uri)` gives a `DataObject` to write a checkpoint,
      suitable for passing to torch.save.
    """

    def __init__(
        self,
        endpoint: str,
        cred_path: str = "",
        config_path: str = "",
        cred_provider: Any = None,
    ):
        """Build the underlying OSS client.

        Args:
            endpoint (str): OSS endpoint; must be non-empty.
            cred_path (str): Path handed to `OssClient` as the credential
                path. Any falsy value (e.g. None) is stored as "".
            config_path (str): Path handed to `OssClient` as the config
                path. Any falsy value (e.g. None) is stored as "".
            cred_provider (Any): Forwarded unchanged to `OssClient` as the
                `cred_provider` keyword argument.

        Raises:
            ValueError: If `endpoint` is empty or otherwise falsy.
        """
        # Guard clause: nothing is stored and no client is built without an endpoint.
        if not endpoint:
            raise ValueError("endpoint must be non-empty")

        self._endpoint = endpoint
        # Normalise falsy path arguments to the empty string before they reach the client.
        self._cred_path = cred_path or ""
        self._config_path = config_path or ""
        self._cred_provider = cred_provider
        self._client = OssClient(
            self._endpoint,
            self._cred_path,
            self._config_path,
            cred_provider=self._cred_provider,
        )

    def reader(self, oss_uri: str):
        """Open the OSS object named by `oss_uri` for reading.

        Args:
            oss_uri (str): A valid oss_uri. (i.e. oss://<BUCKET>/<KEY>)

        Returns:
            DataObject: a read-only binary stream of the OSS object's
                contents, specified by the oss_uri.
        """
        bucket_name, object_key = parse_oss_uri(oss_uri)
        # type=1 is passed straight through to OssClient.get_object; its
        # meaning is defined by that method, not here.
        return self._client.get_object(bucket_name, object_key, type=1)

    def writer(self, oss_uri: str) -> DataObject:
        """Open the OSS object named by `oss_uri` for writing.

        Args:
            oss_uri (str): A valid oss_uri. (i.e. oss://<BUCKET>/<KEY>)

        Returns:
            DataObject: a write-only binary stream. The content is saved to
                OSS using the specified oss_uri.
        """
        bucket_name, object_key = parse_oss_uri(oss_uri)
        return self._client.put_object(bucket_name, object_key)