# docker_images/diffusers/app/healthchecks.py
"""
This file allows users to spawn some side service helping with giving a better view on the main ASGI app status.
The issue with the status route of the main application is that it gets unresponsive as soon as all workers get busy.
Thus, you cannot reliably use that route as a healthcheck to decide whether your app is healthy.
Instead this module allows you to distinguish between a dead service (not able to even tcp connect to app port)
and a busy one (able to connect but not to process a trivial http request in time) as both states should result in
different actions (restarting the service vs scaling it). It also exposes some data to be
consumed as custom metrics, for example to be used in autoscaling decisions.
"""
import asyncio
import functools
import logging
import os
from collections import namedtuple
from typing import Optional
import aiohttp
import psutil
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
logger = logging.getLogger(__name__)
# Latest rendered metrics payload; refreshed by the background loop and
# served verbatim by the metrics route. Plain str assignment is atomic.
METRICS = ""
# Health states exposed through the inference_app_status gauge.
STATUS_OK = 0
STATUS_BUSY = 1
STATUS_ERROR = 2
def metrics():
    """Return the latest metrics snapshot computed by the background loop."""
    # Consistency fix: use the module-level `logger` (as every other function
    # in this file does) rather than the root logger, so the message carries
    # the "healthchecks" format configured by reset_logging().
    logger.debug("Requesting metrics")
    return METRICS
async def metrics_route(_request: Request) -> Response:
    """Serve the cached metrics payload, whatever path was requested."""
    payload = metrics()
    return Response(content=payload)
# Catch-all route: any path on this side service returns the metrics payload.
routes = [
    Route("/{whatever:path}", metrics_route),
]
app = Starlette(routes=routes)
def reset_logging():
    """(Re)configure root logging: DEBUG when METRICS_DEBUG is truthy, else INFO."""
    debug_requested = os.environ.get("METRICS_DEBUG", "false").lower() in ["1", "true"]
    logging.basicConfig(
        level=logging.DEBUG if debug_requested else logging.INFO,
        format="healthchecks - %(asctime)s - %(levelname)s - %(message)s",
        # force=True replaces any handlers installed by the hosting framework.
        force=True,
    )
@app.on_event("startup")
async def startup_event():
    """Configure logging and launch the background metrics-computation task."""
    reset_logging()
    # Link between `api-inference-community` and framework code.
    asyncio.create_task(compute_metrics_loop(), name="compute_metrics")
@functools.lru_cache()
def get_listening_port():
    """Return the main app's TCP port from MAIN_APP_PORT (default 80).

    The result is cached after the first call.
    """
    logger.debug("Get listening port")
    raw_port = os.environ.get("MAIN_APP_PORT", "80")
    try:
        return int(raw_port)
    except ValueError:
        logger.warning(
            "Main app port cannot be converted to an int, skipping and defaulting to 80"
        )
        return 80
async def find_app_process(
    listening_port: int,
) -> Optional[namedtuple("addr", ["ip", "port"])]:  # noqa
    """Find the gunicorn process listening on `listening_port`.

    Returns the local address (ip, port) of the matching LISTEN socket, or
    None when no gunicorn listener is found on that port.
    """
    connections = psutil.net_connections()
    app_laddr = None
    for c in connections:
        if c.laddr.port != listening_port:
            logger.debug("Skipping listening connection bound to excluded port %s", c)
            continue
        if c.status == psutil.CONN_LISTEN:
            logger.debug("Found LISTEN conn %s", c)
            candidate = c.pid
            # Bug fix: net_connections() reports pid=None when the process
            # cannot be determined (e.g. insufficient permissions), and
            # psutil.Process(None) would then inspect the *current* process
            # instead of the listener.
            if candidate is None:
                continue
            try:
                p = psutil.Process(candidate)
            except psutil.NoSuchProcess:
                # Raced with process exit; keep scanning.
                continue
            if p.name() == "gunicorn":
                logger.debug("Found gunicorn process %s", p)
                app_laddr = c.laddr
                break
    return app_laddr
def count_current_conns(app_port: int) -> str:
    """Render a Prometheus gauge counting ESTABLISHED connections to `app_port`."""
    current_conns = sum(
        1
        for conn in psutil.net_connections()
        if conn.status == psutil.CONN_ESTABLISHED and conn.laddr.port == app_port
    )
    logger.info("Current count of established connections to app: %d", current_conns)
    # Doubled braces escape the literal Prometheus label braces under .format().
    return """# HELP inference_app_established_conns Established connection count for a given app.
# TYPE inference_app_established_conns gauge
inference_app_established_conns{{port="{:d}"}} {:d}
""".format(
        app_port, current_conns
    )
async def status_with_timeout(
    listening_port: int, app_laddr: Optional[namedtuple("addr", ["ip", "port"])]  # noqa
) -> str:
    """Probe the app's root route and render its health as a Prometheus gauge.

    Health mapping: 0 = ok, 1 = busy (timed out), 2 = error (no listener,
    bad response, or unexpected failure).
    """
    logger.debug("Checking application status")
    if not app_laddr:
        # No listener was found at all: the service is considered dead.
        status = STATUS_ERROR
    else:
        status = STATUS_OK
        try:
            client_timeout = aiohttp.ClientTimeout(total=0.5)
            async with aiohttp.ClientSession(timeout=client_timeout) as session:
                url = "http://{}:{:d}/".format(app_laddr.ip, app_laddr.port)
                async with session.get(url) as resp:
                    status_code = resp.status
                    status_text = await resp.text()
                    logger.debug(
                        "Status code %s and text %s", status_code, status_text
                    )
                    if status_code != 200 or status_text != '{"ok":"ok"}':
                        status = STATUS_ERROR
        except asyncio.TimeoutError:
            logger.debug("Asgi app seems busy, unable to reach it before timeout")
            status = STATUS_BUSY
        except Exception as e:
            logger.exception(e)
            status = STATUS_ERROR
    return """# HELP inference_app_status Application health status (0: ok, 1: busy, 2: error).
# TYPE inference_app_status gauge
inference_app_status{{port="{:d}"}} {:d}
""".format(
        listening_port, status
    )
async def single_metrics_compute():
    """Compute one metrics snapshot and publish it.

    Publishes to the module-level METRICS string and, when METRICS_FILE is
    set, persists the same payload to that file.
    """
    global METRICS
    listening_port = get_listening_port()
    app_laddr = await find_app_process(listening_port)
    current_conns = count_current_conns(listening_port)
    status = await status_with_timeout(listening_port, app_laddr)
    # Assignment is atomic, we should be safe without locking
    METRICS = current_conns + status
    # Persist metrics to the local ephemeral as well
    metrics_file = os.environ.get("METRICS_FILE")
    if metrics_file:
        # Bug fix: the file must be opened for writing; the previous
        # default (read) mode made f.write() raise io.UnsupportedOperation.
        with open(metrics_file, "w") as f:
            f.write(METRICS)
@functools.lru_cache()
def get_polling_sleep():
    """Return the polling interval in seconds (METRICS_POLLING_INTERVAL, default 10).

    The result is cached after the first call.
    """
    logger.debug("Get polling sleep interval")
    raw_value = os.environ.get("METRICS_POLLING_INTERVAL", 10)
    try:
        return float(raw_value)
    except ValueError:
        logger.warning(
            "Unable to cast METRICS_POLLING_INTERVAL env value %s to float. Defaulting to 10.",
            raw_value,
        )
        return 10.0
@functools.lru_cache()
def get_initial_delay():
    """Return the startup delay in seconds (METRICS_INITIAL_DELAY, default 30).

    The result is cached after the first call.
    """
    logger.debug("Get polling initial delay")
    raw_value = os.environ.get("METRICS_INITIAL_DELAY", 30)
    try:
        return float(raw_value)
    except ValueError:
        logger.warning(
            "Unable to cast METRICS_INITIAL_DELAY env value %s to float. "
            "Defaulting to 30.",
            raw_value,
        )
        return 30.0
async def compute_metrics_loop():
    """Background task: after an initial delay, refresh metrics forever."""
    await asyncio.sleep(get_initial_delay())
    interval = get_polling_sleep()
    while True:
        await asyncio.sleep(interval)
        try:
            await single_metrics_compute()
        except Exception as e:
            # Never let a transient failure kill the loop; log and retry
            # on the next tick.
            logger.error("Something wrong occurred while computing metrics")
            logger.exception(e)
if __name__ == "__main__":
    # Manual smoke test: compute one metrics snapshot and print it.
    reset_logging()
    try:
        # Bug fix: single_metrics_compute is a coroutine; calling it bare
        # only created a coroutine object and never ran the check.
        asyncio.run(single_metrics_compute())
        logger.info("Metrics %s", metrics())
    except Exception as exc:
        logging.exception(exc)
        raise