metaflow/plugins/cards/card_server.py (325 lines of code) (raw):
import os
import json
from http.server import BaseHTTPRequestHandler
from threading import Thread
from multiprocessing import Pipe
from multiprocessing.connection import Connection
from urllib.parse import urlparse
import time
try:
    # Prefer the threaded server shipped with the standard library.
    from http.server import ThreadingHTTPServer
except ImportError:
    # Older interpreters do not provide it: assemble an equivalent class
    # from the threading mix-in and the plain HTTP server.
    from http.server import HTTPServer
    from socketserver import ThreadingMixIn

    class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
        """Fallback HTTP server that handles each request in its own thread."""

        # Request threads are daemonic so they never block interpreter exit.
        daemon_threads = True
from .card_client import CardContainer
from .exception import CardNotPresentException
from .card_resolver import resolve_paths_from_task
from metaflow import namespace
from metaflow.exception import MetaflowNotFound
from metaflow.plugins.datastores.local_storage import LocalStorage
# Absolute path of the viewer page: `card_viewer/viewer.html` next to this file.
VIEWER_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "card_viewer", "viewer.html"
)

# Viewer markup read once at import time. A context manager is used so the
# file handle is closed deterministically instead of being left to the
# garbage collector.
with open(VIEWER_PATH) as _viewer_file:
    CARD_VIEWER_HTML = _viewer_file.read()

# NOTE(review): not referenced anywhere in the visible part of this module.
TASK_CACHE = {}

# Console echo callable; `create_card_server` assigns `ctx_obj.echo` to it.
# It stays `None` until then, so it must not be called before the server is
# created.
_ClickLogger = None
class RunWatcher(Thread):
    """
    Daemon thread that reports changes of a flow's latest run.

    It polls the `latest_run` file located at
    `<local datastore root>/<flow_name>/latest_run` and, every time the run
    id read from that file differs from the previously seen one, sends the
    new value through the connection handed to the constructor.
    """

    def __init__(self, flow_name, connection: Connection):
        """
        Parameters
        ----------
        flow_name :
            Name of the flow whose `latest_run` file is observed.
        connection : Connection
            Pipe end on which changed run ids are sent.
        """
        super().__init__(daemon=True)
        self._flow_name = flow_name
        self._connection = connection
        self._watch_file = self._initialize_watch_file()
        if self._watch_file is None:
            _ClickLogger(
                "Warning: Could not initialize watch file location.", fg="yellow"
            )
        # Baseline value; only later differences are reported.
        self._current_run_id = self.get_run_id()

    def _initialize_watch_file(self):
        """Return the path of the `latest_run` file, or None without a datastore root."""
        root = LocalStorage.datastore_root
        if root is None:
            # Root not set yet: look it up from the configuration without
            # creating it when it is absent.
            root = LocalStorage.get_datastore_root_from_config(
                lambda _: None, create_on_absent=False
            )
        if not root:
            return None
        return os.path.join(root, self._flow_name, "latest_run")

    def get_run_id(self):
        """
        Read the run id currently stored in the watch file.

        Returns the stripped file content, or None when the watch file
        location is unknown, the file does not exist, or it cannot be read.
        """
        # The location may become resolvable after construction, so retry.
        if not self._watch_file:
            self._watch_file = self._initialize_watch_file()

        watch_file = self._watch_file
        if not watch_file or not os.path.exists(watch_file):
            return None

        try:
            with open(watch_file, "r") as handle:
                contents = handle.read()
        except (IOError, OSError) as err:
            _ClickLogger(
                "Warning: Could not read run ID from watch file: %s" % err,
                fg="yellow",
            )
            return None
        return contents.strip()

    def watch(self):
        """Poll the watch file forever, sending every changed run id (2s period)."""
        while True:
            latest = self.get_run_id()
            changed = latest != self._current_run_id
            if changed:
                self._current_run_id = latest
                self._connection.send(latest)
            time.sleep(2)

    def run(self):
        """Thread entry point: delegate to `watch`."""
        self.watch()
class CardServerOptions:
    """
    Settings and mutable state shared by the card server request handlers.

    Besides the plain option values, the object owns a pipe: `child_conn` is
    the end handed to the run watcher, `parent_conn` is the end polled by
    `refresh_run` to pick up a new run id.
    """

    def __init__(
        self,
        flow_name,
        run_object,
        only_running,
        follow_resumed,
        flow_datastore,
        follow_new_runs,
        max_cards=20,
        poll_interval=5,
    ):
        # Imported here rather than at module level, as in the original code.
        from metaflow import Run

        self.RunClass = Run
        self.flow_name = flow_name
        self.run_object = run_object
        self.flow_datastore = flow_datastore
        self.only_running = only_running
        self.follow_resumed = follow_resumed
        self.follow_new_runs = follow_new_runs
        self.max_cards = max_cards
        self.poll_interval = poll_interval
        self._parent_conn, self._child_conn = Pipe()

    def refresh_run(self):
        """
        Switch `run_object` to a newly announced run, if there is one.

        Returns True only when a run id was received on the pipe and the
        corresponding run could be loaded; otherwise `run_object` is left
        untouched and False is returned.
        """
        if not self.follow_new_runs or not self.parent_conn.poll():
            return False

        run_id = self.parent_conn.recv()
        if run_id is None:
            return False

        # Clear the namespace filter before resolving the announced run.
        namespace(None)
        try:
            new_run = self.RunClass(f"{self.flow_name}/{run_id}")
        except MetaflowNotFound:
            return False
        else:
            self.run_object = new_run
            return True

    @property
    def parent_conn(self):
        """Pipe end read by `refresh_run`."""
        return self._parent_conn

    @property
    def child_conn(self):
        """Pipe end given to the producer of run ids."""
        return self._child_conn
def cards_for_task(
    flow_datastore, task_pathspec, card_type=None, card_hash=None, card_id=None
):
    """
    Generate the cards stored for a single task.

    The optional `card_type`, `card_hash` and `card_id` arguments are passed
    to the card path resolver as filters. Because this is a generator
    function, calling it always returns a generator; when the resolver
    reports that no card is present the generator simply yields nothing.
    """
    filters = dict(type=card_type, hash=card_hash, card_id=card_id)
    try:
        paths, card_ds = resolve_paths_from_task(
            flow_datastore, task_pathspec, **filters
        )
    except CardNotPresentException:
        return
    yield from CardContainer(paths, card_ds, origin_pathspec=None)
def cards_for_run(
    flow_datastore,
    run_object,
    only_running,
    card_type=None,
    card_hash=None,
    card_id=None,
    max_cards=20,
):
    """
    Generate `(task_pathspec, card)` pairs for the tasks of a run.

    Parameters
    ----------
    flow_datastore :
        Datastore handle forwarded to `cards_for_task`.
    run_object :
        Object exposing `steps()`; each step exposes `tasks()` and each task
        exposes `finished` and `pathspec`.
    only_running : bool
        When true, tasks whose `finished` attribute is truthy are skipped.
    card_type, card_hash, card_id :
        Optional filters forwarded to `cards_for_task`.
    max_cards : int
        Upper bound on the number of pairs yielded; a value <= 0 yields
        nothing.

    Yields
    ------
    tuple
        `(task.pathspec, card)` for every card found, in step/task order.
    """
    if max_cards <= 0:
        return

    emitted = 0
    for step in run_object.steps():
        for task in step.tasks():
            if only_running and task.finished:
                continue
            card_generator = cards_for_task(
                flow_datastore,
                task.pathspec,
                card_type=card_type,
                card_hash=card_hash,
                card_id=card_id,
            )
            # Defensive: tolerate a lookup that signals "no cards" with None.
            if card_generator is None:
                continue
            for card in card_generator:
                yield task.pathspec, card
                emitted += 1
                if emitted >= max_cards:
                    # A plain `return` ends the generator. Raising
                    # StopIteration here would surface as RuntimeError
                    # (PEP 479).
                    return
class CardViewerRoutes(BaseHTTPRequestHandler):
    """
    HTTP GET handler of the local card viewer.

    The first path segment of the request selects a route (see `ROUTES`):

    - `runinfo`: JSON listing of the cards of the currently tracked run.
    - `card`: HTML of a single card, addressed as
      `/card/<flow>/<run_id>/<step>/<task_id>/<card_hash>`.
    - `data`: JSON data payload of a single card, addressed like `card`.

    Every other path is answered with the viewer HTML page.
    """

    # Both attributes are assigned on the class by `create_card_server`.
    # The annotations are quoted so that defining this class does not require
    # the referenced classes to be resolvable at class-creation time.
    card_options: "CardServerOptions" = None
    run_watcher: "RunWatcher" = None

    def do_GET(self):
        """Dispatch a GET request to the matching route or serve the viewer."""
        # Route on the path component only, so a query string
        # (e.g. `/runinfo?x=1`) does not hide the route name.
        request_path = urlparse(self.path).path
        # Drop everything up to and including the first "/".
        _, _, remainder = request_path.partition("/")
        prefix, separator, suffix = remainder.partition("/")
        if not separator:
            suffix = None

        handler = self.ROUTES.get(prefix)
        if handler is not None:
            handler(self, suffix)
            return

        # Unknown route: serve the viewer page, closing the file afterwards.
        with open(VIEWER_PATH) as viewer_file:
            page = viewer_file.read()
        self._response(page.encode("utf-8"))

    def get_runinfo(self, suffix):
        """
        Answer with the JSON description of the cards of the tracked run.

        Responds with a 404 "No Run Found" document when no run is being
        tracked. `suffix` is unused.
        """
        options = self.card_options
        if options.refresh_run():
            message = (
                "RunID changed in the background to %s"
                % options.run_object.pathspec
            )
            # Pass the text as an argument: `log_message` applies
            # %-formatting to its first parameter.
            self.log_message("%s", message)
            _ClickLogger(message, fg="blue")

        # Work on one snapshot of the run for the rest of the request.
        run_object = options.run_object
        if run_object is None:
            self._response(
                {"status": "No Run Found", "flow": options.flow_name},
                code=404,
                is_json=True,
            )
            return

        task_card_generator = cards_for_run(
            options.flow_datastore,
            run_object,
            options.only_running,
            max_cards=options.max_cards,
        )
        cards = []
        for pathspec, card in task_card_generator:
            step, task = pathspec.split("/")[-2:]
            cards.append(
                dict(
                    task=pathspec,
                    label="%s/%s %s" % (step, task, card.hash),
                    card_object=dict(
                        hash=card.hash,
                        type=card.type,
                        path=card.path,
                        id=card.id,
                    ),
                    finished=bool(run_object[step][task].finished),
                    card="%s/%s" % (pathspec, card.hash),
                )
            )

        self._response(
            {
                "status": "ok",
                "flow": run_object.parent.id,
                "run_id": run_object.id,
                "cards": cards,
                "poll_interval": options.poll_interval,
            },
            is_json=True,
        )

    def _resolve_card_request(self):
        """
        Parse a `/<route>/<flow>/<run_id>/<step>/<task_id>/<card_hash>` URL
        and look up the matching cards.

        Returns `(step, task_id, cards)` with `cards` a possibly empty list,
        or None when the URL does not have exactly six segments.
        """
        segments = urlparse(self.path).path.strip("/").split("/")
        if len(segments) != 6:
            return None
        _, flow, run_id, step, task_id, card_hash = segments
        pathspec = "/".join([flow, run_id, step, task_id])
        cards = list(
            cards_for_task(
                self.card_options.flow_datastore, pathspec, card_hash=card_hash
            )
        )
        return step, task_id, cards

    def get_card(self, suffix):
        """
        Answer with the HTML of the requested card.

        Responds with a 404 JSON document when the URL is malformed or no
        card matches. `suffix` is unused; the full request path is parsed.
        """
        resolved = self._resolve_card_request()
        if resolved is None or not resolved[2]:
            # The body is a dict, so it must be sent as JSON.
            self._response({"status": "Card Not Found"}, code=404, is_json=True)
            return
        selected_card = resolved[2][0]
        self._response(selected_card.get().encode("utf-8"))

    def get_data(self, suffix):
        """
        Answer with the JSON data payload of the requested card.

        Responds with 404 when the URL is malformed, the card or the task
        cannot be found, or the card has no data. `suffix` is unused; the
        full request path is parsed.
        """
        resolved = self._resolve_card_request()
        if resolved is None or not resolved[2]:
            self._response(
                {"status": "Card Not Found"},
                is_json=True,
                code=404,
            )
            return
        step, task_id, cards = resolved

        # The task is looked up in the tracked run; there may be none.
        run_object = self.card_options.run_object
        task_object = None
        if run_object is not None:
            try:
                task_object = run_object[step][task_id]
            except KeyError:
                task_object = None
        if task_object is None:
            self._response(
                {"status": "Task Not Found", "is_complete": False},
                is_json=True,
                code=404,
            )
            return

        is_complete = task_object.finished
        card_data = cards[0].get_data()
        if card_data is None:
            # NOTE(review): "ok" status with a 404 code is kept from the
            # original; presumably the viewer reads it as "no data yet" --
            # confirm against the client.
            self._response(
                {"status": "ok", "is_complete": is_complete},
                is_json=True,
                code=404,
            )
            return

        self.log_message(
            "Task Success: %s, Task Finished: %s",
            task_object.successful,
            is_complete,
        )
        status = "ok"
        if not task_object.successful and is_complete:
            status = "Task Failed"
        self._response(
            {"status": status, "payload": card_data, "is_complete": is_complete},
            is_json=True,
        )

    def _response(self, body, is_json=False, code=200):
        """
        Send a complete response.

        With `is_json` true, `body` is serialised with `json.dumps` and sent
        as `application/json`; otherwise `body` must already be bytes and is
        sent as `text/html`.
        """
        self.send_response(code)
        mime = "application/json" if is_json else "text/html"
        self.send_header("Content-type", mime)
        self.end_headers()
        payload = json.dumps(body).encode("utf-8") if is_json else body
        self.wfile.write(payload)

    # Route name -> plain function; invoked as `handler(self, suffix)`.
    ROUTES = {"runinfo": get_runinfo, "card": get_card, "data": get_data}
def _is_debug_mode():
debug_flag = os.environ.get("METAFLOW_DEBUG_CARD_SERVER")
if debug_flag is None:
return False
return debug_flag.lower() in ["true", "1"]
def create_card_server(card_options: CardServerOptions, port, ctx_obj):
    """
    Configure the request handler class and serve cards until interrupted.

    Parameters
    ----------
    card_options : CardServerOptions
        Options shared with every request handler; stored on
        `CardViewerRoutes.card_options`.
    port : int
        TCP port to listen on (all interfaces).
    ctx_obj :
        Object whose `echo` callable is used for console output; it is also
        stored in the module-level `_ClickLogger`.

    This call blocks in `serve_forever()`. The listening socket is closed
    when serving stops for any reason, including an exception such as
    KeyboardInterrupt.
    """
    global _ClickLogger

    CardViewerRoutes.card_options = card_options
    # Must be set before the watcher is created: the watcher logs through it.
    _ClickLogger = ctx_obj.echo

    if card_options.follow_new_runs:
        CardViewerRoutes.run_watcher = RunWatcher(
            card_options.flow_name, card_options.child_conn
        )
        CardViewerRoutes.run_watcher.start()

    server_addr = ("", port)
    ctx_obj.echo(
        "Starting card server on port %d " % (port),
        fg="green",
        bold=True,
    )

    # Disable logging if not in debug mode
    if not _is_debug_mode():
        CardViewerRoutes.log_request = lambda *args, **kwargs: None
        CardViewerRoutes.log_message = lambda *args, **kwargs: None

    server = ThreadingHTTPServer(server_addr, CardViewerRoutes)
    try:
        server.serve_forever()
    finally:
        # Release the listening socket instead of leaving it to process exit.
        server.server_close()