pai/tensorboard.py (125 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import posixpath
import time
from typing import Optional
import requests
from .common.oss_utils import is_oss_uri
from .exception import UnexpectedStatusException
from .session import Session, get_default_session
class TensorBoardStatus(object):
Pending = "Pending"
Creating = "Creating"
Running = "Running"
Creating_Failed = "Creating_Failed"
Deleting = "Deleting"
Deleted = "Deleted"
Stopping = "Stopping"
Stopped = "Stopped"
@classmethod
def is_terminated(cls, status):
return status in [
cls.Creating_Failed,
cls.Stopped,
cls.Deleted,
]
@classmethod
def is_running(cls, status):
return status in [
cls.Running,
]
class TensorBoard(object):
def __init__(self, tensorboard_id: str, session: Optional[Session] = None):
self.session = session or get_default_session()
self.tensorboard_id = tensorboard_id
self._api_object = self.session.tensorboard_api.get(tensorboard_id)
def __repr__(self):
return "TensorBoard(tensorboard_id={}, name={}, status={})".format(
self.tensorboard_id,
self.display_name,
self._status(),
)
@property
def status(self):
self._refresh()
return self._status()
def _status(self):
return self._api_object.get("Status")
@property
def app_uri(self):
"""Get the TensorBoard application URI."""
return self._api_object.get("TensorboardUrl")
@property
def summary_uri(self):
return self._api_object.get("SummaryUri")
@property
def display_name(self):
return self._api_object.get("DisplayName")
def _refresh(self):
self._api_object = self.session.tensorboard_api.get(self.tensorboard_id)
@classmethod
def create(
cls,
uri: str,
wait: bool = True,
display_name: Optional[str] = None,
max_runtime_in_minutes: Optional[int] = None,
source_id: Optional[str] = None,
source_type: Optional[str] = None,
session: Optional[Session] = None,
) -> "TensorBoard":
"""Launch a TensorBoard Application.
Args:
uri (str): A OSS URI to the directory containing the TensorBoard logs.
wait (bool): Whether to wait for the TensorBoard application to be ready.
display_name (str, optional): Display name of the TensorBoard application.
Defaults to None.
max_runtime_in_minutes: Maximum running time in minutes.
source_type (str, optional): The type of the source object. Defaults to None.
source_id (str, optional): The ID of the source object. Defaults to None.
session: A Session object to use in interacting with PAI.
Returns:
TensorBoard: A TensorBoard object.
Examples:
Create a TensorBoard application from a OSS URI:
>>> from pai.tensorboard import TensorBoard
>>> tb = TensorBoard.create("oss://my-bucket/path/to/logs_dir/")
>>> # Get TensorBoard Application URL.
>>> print(tb.app_uri)
"""
session = session or get_default_session()
if not is_oss_uri(uri):
raise RuntimeError("Currently only support OSS uri to create TensorBoard.")
oss_uri = session.patch_oss_endpoint(uri)
data_source_type = "OSS"
if not display_name:
# Use the last part of the OSS URI as the display name.
display_name = posixpath.basename(uri.rstrip("/"))
if not display_name:
raise RuntimeError("Failed to infer display name from OSS URI.")
tb_id = session.tensorboard_api.create(
uri=oss_uri,
display_name=display_name,
data_source_type=data_source_type,
max_running_time_minutes=max_runtime_in_minutes,
# hack: summary_relative_path is required for CreateTensorBoard API.
summary_relative_path="/",
source_id=source_id,
source_type=source_type,
)
tensorboard = TensorBoard(tensorboard_id=tb_id, session=session)
if wait:
tensorboard.wait()
return tensorboard
def wait(self):
"""Wait for the TensorBoard application to be ready.
Raises:
UnExpectedStatusException: If the TensorBoard application is terminated
unexpectedly.
"""
while True:
status = self.status
if TensorBoardStatus.is_terminated(status):
raise UnexpectedStatusException(
"TensorBoard terminated unexpectedly in status: %s" % status,
status,
)
elif TensorBoardStatus.is_running(status):
self._wait_app_available()
return
else:
time.sleep(5)
self._refresh()
def start(self, wait: bool = True):
"""Start the TensorBoard application."""
self._refresh()
if TensorBoardStatus.is_running(self.status):
return
self.session.tensorboard_api.start(self.tensorboard_id)
if wait:
self.wait()
def stop(self):
"""Stop the TensorBoard application."""
self._refresh()
if not TensorBoardStatus.is_running(self.status):
return
self.session.tensorboard_api.stop(self.tensorboard_id)
def _wait_app_available(self):
"""Wait until the TensorBoard application is available."""
if not self.app_uri:
raise RuntimeError("TensorBoard application URL is not available.")
while True:
resp = requests.get(self.app_uri)
# Status code not equals 5xx means the TensorBoard application is available.
if resp.status_code // 100 != 5:
break
time.sleep(5)
def delete(self):
"""Delete the TensorBoard Application."""
self.session.tensorboard_api.delete(self.tensorboard_id)