# oss-torch-connector/osstorchconnector/oss_checkpoint.py (32 lines of code) (raw)
from ._oss_bucket_iterable import parse_oss_uri
from ._oss_client import OssClient, DataObject
from ctypes import *
from typing import Any
class OssCheckpoint:
    """A checkpoint manager for OSS.

    To read a checkpoint from OSS, create a `DataObject` with `reader` by
    providing the oss_uri of the checkpoint stored in OSS. Similarly, to save
    a checkpoint to OSS, create a `DataObject` with `writer` by providing the
    destination oss_uri.

    `DataObject` can be passed to torch.load and torch.save.
    """

    def __init__(
        self,
        endpoint: str,
        cred_path: str = "",
        config_path: str = "",
        cred_provider: Any = None,
    ) -> None:
        """Validate the settings and build the underlying `OssClient`.

        Args:
            endpoint (str): Endpoint handed to `OssClient`; must be non-empty.
            cred_path (str): Forwarded to `OssClient` as its second positional
                argument. Any falsy value (including None) is stored as "".
            config_path (str): Forwarded to `OssClient` as its third
                positional argument. Any falsy value is stored as "".
            cred_provider (Any): Forwarded unchanged to `OssClient` through
                its `cred_provider` keyword.

        Raises:
            ValueError: If `endpoint` is empty or otherwise falsy.
        """
        if not endpoint:
            raise ValueError("endpoint must be non-empty")

        self._endpoint = endpoint
        # Normalise falsy values (e.g. None) to the empty string before they
        # reach OssClient.
        self._cred_path = cred_path or ""
        self._config_path = config_path or ""
        self._cred_provider = cred_provider
        self._client = OssClient(
            self._endpoint,
            self._cred_path,
            self._config_path,
            cred_provider=self._cred_provider,
        )

    def reader(self, oss_uri: str) -> DataObject:
        """Create a `DataObject` for reading from a given oss_uri.

        Args:
            oss_uri (str): A valid oss_uri. (i.e. oss://<BUCKET>/<KEY>)

        Returns:
            DataObject: a read-only binary stream of the OSS object's
                contents, specified by the oss_uri.
        """
        bucket, key = parse_oss_uri(oss_uri)
        # NOTE(review): type=1 is forwarded as-is; its meaning is defined by
        # OssClient.get_object -- confirm there before changing it.
        return self._client.get_object(bucket, key, type=1)

    def writer(self, oss_uri: str) -> DataObject:
        """Create a `DataObject` for writing to a given oss_uri.

        Args:
            oss_uri (str): A valid oss_uri. (i.e. oss://<BUCKET>/<KEY>)

        Returns:
            DataObject: a write-only binary stream. The content is saved to
                OSS using the specified oss_uri.
        """
        bucket, key = parse_oss_uri(oss_uri)
        return self._client.put_object(bucket, key)