docker_images/setfit/app/main.py (66 lines of code) (raw):

import functools import logging import os import pathlib from typing import Dict, Type from api_inference_community import hub from api_inference_community.routes import pipeline_route, status_ok from app.pipelines import Pipeline, TextClassificationPipeline from huggingface_hub import constants from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.routing import Route TASK = os.getenv("TASK") def get_model_id(): m_id = os.getenv("MODEL_ID") # Workaround, when sentence_transformers handles properly this env variable # this should not be needed anymore if constants.HF_HUB_OFFLINE: cache_dir = pathlib.Path(constants.HF_HUB_CACHE) m_id = hub.cached_revision_path( cache_dir=cache_dir, repo_id=m_id, revision=os.getenv("REVISION") ) return m_id MODEL_ID = get_model_id() logger = logging.getLogger(__name__) # Add the allowed tasks # Supported tasks are: # - text-generation # - text-classification # - token-classification # - translation # - summarization # - automatic-speech-recognition # - ... # For instance # from app.pipelines import AutomaticSpeechRecognitionPipeline # ALLOWED_TASKS = {"automatic-speech-recognition": AutomaticSpeechRecognitionPipeline} # You can check the requirements and expectations of each pipelines in their respective # directories. Implement directly within the directories. ALLOWED_TASKS: Dict[str, Type[Pipeline]] = { "text-classification": TextClassificationPipeline, } @functools.lru_cache() def get_pipeline() -> Pipeline: task = os.environ["TASK"] model_id = MODEL_ID if task not in ALLOWED_TASKS: raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}") return ALLOWED_TASKS[task](model_id) routes = [ Route("/{whatever:path}", status_ok), Route("/{whatever:path}", pipeline_route, methods=["POST"]), ] middleware = [Middleware(GZipMiddleware, minimum_size=1000)] if os.environ.get("DEBUG", "") == "1": from starlette.middleware.cors import CORSMiddleware middleware.append( Middleware( CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"], ) ) app = Starlette(routes=routes, middleware=middleware) @app.on_event("startup") async def startup_event(): logger = logging.getLogger("uvicorn.access") handler = logging.StreamHandler() handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) logger.handlers = [handler] # Link between `api-inference-community` and framework code. app.get_pipeline = get_pipeline try: get_pipeline() except Exception: # We can fail so we can show exception later. pass if __name__ == "__main__": try: get_pipeline() except Exception: # We can fail so we can show exception later. pass