api_inference_community/routes.py (221 lines of code) (raw):

import base64
import io
import ipaddress
import logging
import os
import time
from typing import Any, Dict

import psutil
from api_inference_community.validation import (
    AUDIO,
    AUDIO_INPUTS,
    IMAGE,
    IMAGE_INPUTS,
    IMAGE_OUTPUTS,
    KNOWN_TASKS,
    ffmpeg_convert,
    normalize_payload,
    parse_accept,
)
from pydantic import ValidationError
from starlette.requests import Request
from starlette.responses import JSONResponse, Response


# Response headers reporting how long the request took and on which kind of
# hardware (taken from the COMPUTE_TYPE environment variable, "cpu" by default).
HF_HEADER_COMPUTE_TIME = "x-compute-time"
HF_HEADER_COMPUTE_TYPE = "x-compute-type"
COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "cpu")

logger = logging.getLogger(__name__)


def already_left(request: Request) -> bool:
    """
    Check if the caller has already left without waiting for the answer to come. This can help during burst to relieve
    the pressure on the worker by cancelling jobs whose results don't matter as they won't be fetched anyway

    The check looks for an ESTABLISHED TCP connection whose remote address
    matches the request's client host and port. Whenever the outcome is
    uncertain (missing host, invalid port, any exception) the caller is
    assumed to still be there.

    :param request: the incoming Starlette request
    :return: bool, True only when no matching established connection was found
    """
    # NOTE: Starlette method request.is_disconnected is totally broken, consumes the payload, does not return
    # the correct status. So we use the good old way to identify if the caller is still there.
    # In any case, if we are not sure, we return False
    logger.info("Checking if request caller already left")
    try:
        client = request.client
        host = client.host
        if not host:
            return False
        port = int(client.port)
        host = ipaddress.ip_address(host)
        if port <= 0 or port > 65535:
            logger.warning("Unexpected source port format for caller %s", port)
            return False
        counter = 0
        for connection in psutil.net_connections(kind="tcp"):
            counter += 1
            if connection.status != "ESTABLISHED":
                continue
            if not connection.raddr:
                continue
            if int(connection.raddr.port) != port:
                continue
            if (
                not connection.raddr.ip
                or ipaddress.ip_address(connection.raddr.ip) != host
            ):
                continue
            logger.info(
                "Found caller connection still established, caller is most likely still there, %s",
                connection,
            )
            return False
    except Exception as e:
        # Best effort only: never discard a request because the check itself failed.
        logger.warning(
            "Unexpected error while checking if caller already left, assuming still there"
        )
        logger.exception(e)
        return False

    logger.info(
        "%d connections checked. No connection found matching to the caller, probably left",
        counter,
    )
    return True


async def pipeline_route(request: Request) -> Response:
    """
    Handle an inference request: validate the task, load the pipeline,
    normalize the payload and delegate the actual inference to `call_pipe`.

    Responses produced here:
      - 400 when the TASK environment variable is not in KNOWN_TASKS, or when
        the payload cannot be normalized;
      - 204 when DISCARD_LEFT is enabled and the caller already disconnected;
      - 500 when the pipeline cannot be obtained, or when an audio task's
        pipeline exposes no `sampling_rate`.

    :param request: the incoming Starlette request
    :return: the HTTP response
    :raises KeyError: if the TASK environment variable is not set
    """
    start = time.time()

    task = os.environ["TASK"]

    # Shortcut: quickly check the task is in enum: no need to go any further otherwise, as we know for sure that
    # normalize_payload will fail below: this avoids us to wait for the pipeline to be loaded to return
    if task not in KNOWN_TASKS:
        msg = f"The task `{task}` is not recognized by api-inference-community"
        logger.error(msg)
        # Special case: despite the fact that the task comes from environment (which could be considered a service
        # config error, thus triggering a 500), this var indirectly comes from the user
        # so we choose to have a 400 here
        return JSONResponse({"error": msg}, status_code=400)

    if os.getenv("DISCARD_LEFT", "0").lower() in [
        "1",
        "true",
        "yes",
    ] and already_left(request):
        logger.info("Discarding request as the caller already left")
        return Response(status_code=204)

    payload = await request.body()

    if os.getenv("DEBUG", "0") in {"1", "true"}:
        # In debug mode the pipeline is fetched outside of any try/except so
        # that a loading error propagates instead of becoming a JSON 500.
        pipe = request.app.get_pipeline()

    try:
        pipe = request.app.get_pipeline()
        try:
            sampling_rate = pipe.sampling_rate
        except Exception:
            sampling_rate = None
            if task in AUDIO_INPUTS:
                msg = f"Sampling rate is expected for model for audio task {task}"
                logger.error(msg)
                return JSONResponse({"error": msg}, status_code=500)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

    try:
        inputs, params = normalize_payload(payload, task, sampling_rate=sampling_rate)
    except ValidationError as e:
        errors = []
        for error in e.errors():
            if len(error["loc"]) > 0:
                errors.append(
                    f'{error["msg"]}: received `{error["loc"][0]}` in `parameters`'
                )
            else:
                errors.append(
                    f'{error["msg"]}: received `{error["input"]}` in `parameters`'
                )
        return JSONResponse({"error": errors}, status_code=400)
    except Exception as e:
        # We assume the payload is bad -> 400
        logger.warning("Error while parsing input %s", e)
        return JSONResponse({"error": str(e)}, status_code=400)

    accept = request.headers.get("accept", "")
    lora_adapter = request.headers.get("lora")
    if lora_adapter:
        params["lora_adapter"] = lora_adapter
    return call_pipe(pipe, inputs, params, start, accept)


def call_pipe(pipe: Any, inputs, params: Dict, start: float, accept: str) -> Response:
    """
    Run the pipeline on the normalized inputs and build the HTTP response.

    Messages logged at WARNING level or above while the pipeline runs are
    collected and, when the output is a dict, returned under the `warnings`
    key. The temporary logging handler used for that is always detached again
    before returning, so handlers do not pile up across requests.

    :param pipe: the pipeline callable, invoked as `pipe(inputs, **params)`
    :param inputs: the normalized inputs
    :param params: keyword parameters forwarded to the pipeline
    :param start: `time.time()` timestamp taken when the request started,
        used for the compute-time header
    :param accept: value of the request `accept` header, used to choose the
        audio / image output format
    :return: the HTTP response. Inference failures raising AssertionError,
        ValueError or TypeError give a 400, any other exception a 500.
    """
    root_logger = logging.getLogger()
    warnings = set()

    class RequestsHandler(logging.Handler):
        def emit(self, record):
            """Send the log records (created by loggers) to the appropriate destination."""
            warnings.add(record.getMessage())

    handler = RequestsHandler()
    handler.setLevel(logging.WARNING)

    # Remember every logger the handler is attached to, so that it can be
    # removed once inference is over. Without this, each request would leave
    # one more handler on every logger for the lifetime of the process.
    attached = [root_logger]
    root_logger.addHandler(handler)
    for _logger in logging.root.manager.loggerDict.values():  # type: ignore
        try:
            _logger.addHandler(handler)
        except Exception:
            # Some registry entries are not real loggers (e.g. logging.PlaceHolder).
            continue
        attached.append(_logger)

    status_code = 200
    try:
        if os.getenv("DEBUG", "0") in {"1", "true"}:
            # Debug mode: run once outside the try/except so that the raw
            # exception propagates. NOTE: on success the pipeline is run a
            # second time just below.
            outputs = pipe(inputs, **params)
        try:
            outputs = pipe(inputs, **params)
            task = os.getenv("TASK")
            metrics = get_metric(inputs, task, pipe)
        except (AssertionError, ValueError, TypeError) as e:
            outputs = {"error": str(e)}
            status_code = 400
        except Exception as e:
            outputs = {"error": "unknown error"}
            status_code = 500
            logger.error(f"There was an inference error: {e}")
            logger.exception(e)
    finally:
        # Runs after the except branches above, so the errors they log are
        # still captured in `warnings` before the handler goes away.
        for _logger in attached:
            _logger.removeHandler(handler)

    if warnings and isinstance(outputs, dict):
        outputs["warnings"] = list(sorted(warnings))

    compute_type = COMPUTE_TYPE
    headers = {
        HF_HEADER_COMPUTE_TIME: "{:.3f}".format(time.time() - start),
        HF_HEADER_COMPUTE_TYPE: compute_type,
        # https://stackoverflow.com/questions/43344819/reading-response-headers-with-fetch-api/44816592#44816592
        "access-control-expose-headers": f"{HF_HEADER_COMPUTE_TYPE}, {HF_HEADER_COMPUTE_TIME}",
    }

    if status_code == 200:
        # `metrics` is always bound here: status_code stays 200 only when the
        # try block above ran to completion.
        headers.update(**{k: str(v) for k, v in metrics.items()})
        task = os.getenv("TASK")
        if task == "text-to-speech":
            waveform, sampling_rate = outputs
            audio_format = parse_accept(accept, AUDIO)
            data = ffmpeg_convert(waveform, sampling_rate, audio_format)
            headers["content-type"] = f"audio/{audio_format}"
            return Response(data, headers=headers, status_code=status_code)
        elif task == "audio-to-audio":
            waveforms, sampling_rate, labels = outputs
            items = []
            headers["content-type"] = "application/json"

            audio_format = parse_accept(accept, AUDIO)
            for waveform, label in zip(waveforms, labels):
                data = ffmpeg_convert(waveform, sampling_rate, audio_format)
                items.append(
                    {
                        "label": label,
                        "blob": base64.b64encode(data).decode("utf-8"),
                        "content-type": f"audio/{audio_format}",
                    }
                )
            return JSONResponse(items, headers=headers, status_code=status_code)
        elif task in IMAGE_OUTPUTS:
            image = outputs
            image_format = parse_accept(accept, IMAGE)
            buffer = io.BytesIO()
            image.save(buffer, format=image_format.upper())
            buffer.seek(0)
            img_bytes = buffer.read()
            return Response(
                img_bytes,
                headers=headers,
                status_code=200,
                media_type=f"image/{image_format}",
            )

    return JSONResponse(
        outputs,
        headers=headers,
        status_code=status_code,
    )


def get_metric(inputs, task, pipe):
    """
    Build the usage-metric header for a request, depending on the task kind.

    :param inputs: the normalized inputs passed to the pipeline
    :param task: the task name
    :param pipe: the pipeline; its `sampling_rate` is read for audio tasks
    :return: a single-entry dict mapping the metric header name to its value
    """
    if task in AUDIO_INPUTS:
        return {"x-compute-audio-length": get_audio_length(inputs, pipe.sampling_rate)}
    elif task in IMAGE_INPUTS:
        return {"x-compute-images": 1}
    else:
        return {"x-compute-characters": get_input_characters(inputs)}


def get_audio_length(inputs, sampling_rate: int) -> float:
    """
    Return the audio duration: number of samples divided by the sampling rate.

    :param inputs: either a dict with `raw` and `sampling_rate` entries, or an
        object exposing `shape` (presumably a waveform array -- confirm
        against normalize_payload)
    :param sampling_rate: sampling rate used when `inputs` is not a dict
    :return: the length, in seconds assuming the rate is in samples per second
    """
    if isinstance(inputs, dict):
        # Should only apply for internal AsrLive
        length_in_s = inputs["raw"].shape[0] / inputs["sampling_rate"]
    else:
        length_in_s = inputs.shape[0] / sampling_rate
    return length_in_s


def get_input_characters(inputs) -> int:
    """
    Count the characters of every string found in `inputs`.

    Strings contribute their length; tuples, lists and dict values are
    traversed recursively; any other value counts as 0.

    :param inputs: a string, or a nesting of tuples / lists / dicts
    :return: the total number of characters
    """
    if isinstance(inputs, str):
        return len(inputs)
    elif isinstance(inputs, (tuple, list)):
        return sum(get_input_characters(input_) for input_ in inputs)
    elif isinstance(inputs, dict):
        return sum(get_input_characters(input_) for input_ in inputs.values())
    return 0


async def status_ok(request):
    """Health-check route: always answers with the JSON body {"ok": "ok"}."""
    return JSONResponse({"ok": "ok"})