maga_transformer/tools/model_assistant_server.py (52 lines of code) (raw):

"""HTTP server that discovers and mounts FastAPI routers for model-assistant tools.

At app creation, every module found in the ``api`` sub-directory next to this
file is imported; each one exposing a module-level ``router`` attribute has
that router mounted on the application.
"""
import argparse
import glob
import importlib
import logging
import logging.config  # needed for dictConfig below; not implied by `import logging`
import os

import uvicorn
from anyio import CapacityLimiter
from anyio.lowlevel import RunVar
from fastapi import FastAPI
from fastapi.routing import APIRouter

from maga_transformer.config.log_config import LOGGING_CONFIG
from maga_transformer.config.uvicorn_config import UVICORN_LOGGING_CONFIG

# Passed to uvicorn as ``h11_max_incomplete_event_size`` (1 MiB).
MAX_INCOMPLETE_EVENT_SIZE = 1024 * 1024

# Capacity of the anyio CapacityLimiter installed as the default thread limiter at startup.
DEFAULT_THREAD_LIMIT = 40

# Dotted package that the directory containing this file corresponds to.
_TOOLS_PACKAGE = "maga_transformer.tools"


def _module_name_for(filename: str, base_dir: str) -> str:
    """Map a ``.py``/``.pyc`` file located under *base_dir* to its dotted module name.

    ``<base_dir>/api/foo.py`` and ``<base_dir>/api/foo.pyc`` both map to
    ``maga_transformer.tools.api.foo``. A package ``__init__`` file maps to the
    package itself so the package is not imported a second time under the
    distinct name ``<pkg>.__init__``.
    """
    relative, _ext = os.path.splitext(os.path.relpath(filename, base_dir))
    parts = relative.split(os.sep)
    if parts[-1] == "__init__":
        parts = parts[:-1]
    return ".".join([_TOOLS_PACKAGE, *parts])


class ModelAssistantServer(object):
    """Serve the model-assistant tool APIs over HTTP using uvicorn."""

    def __init__(self, server_port):
        # TCP port the server binds to (on all interfaces, see start()).
        self._server_port = server_port

    def start(self):
        """Build the FastAPI app and run it with uvicorn on 0.0.0.0 (blocking call)."""
        app = self.create_app()
        uvicorn.run(app,
                    host="0.0.0.0",
                    port=self._server_port,
                    log_config=UVICORN_LOGGING_CONFIG,
                    h11_max_incomplete_event_size=MAX_INCOMPLETE_EVENT_SIZE)

    def create_app(self):
        """Create the FastAPI app with all routers discovered under ``api`` mounted."""
        app = FastAPI()

        @app.on_event("startup")
        async def startup():
            # Replace anyio's default thread limiter. NOTE(review): this relies on
            # the private RunVar name "_default_thread_limiter" inside anyio.
            RunVar("_default_thread_limiter").set(CapacityLimiter(DEFAULT_THREAD_LIMIT))

        api_router = ModelAssistantServer.include_api_from_subpath("api")
        # add APIRouter to FastAPI app
        app.include_router(api_router)
        return app

    @staticmethod
    def include_api_from_subpath(sub_path: str):
        """Import every module in ``<this file's dir>/<sub_path>`` and merge their routers.

        Both ``*.py`` and ``*.pyc`` files are considered, so sourceless
        deployments still work. Each module is imported at most once (a
        ``foo.py``/``foo.pyc`` pair is de-duplicated), and modules are handled in
        sorted order so route registration is deterministic. Modules without a
        ``router`` attribute are imported but contribute no routes.

        Args:
            sub_path: directory, relative to this file's directory, to scan
                (non-recursively).

        Returns:
            An ``APIRouter`` that includes the ``router`` of every discovered module.
        """
        router = APIRouter()
        current_path = os.path.dirname(os.path.abspath(__file__))
        search_dir = os.path.join(current_path, sub_path)
        candidates = (glob.glob(os.path.join(search_dir, "*.py"))
                      + glob.glob(os.path.join(search_dir, "*.pyc")))

        # module name -> first file seen for it (kept only for the log message).
        # Sorting makes "foo.py" win over "foo.pyc" when both exist.
        modules = {}
        for filename in sorted(candidates):
            if os.path.isfile(filename):
                modules.setdefault(_module_name_for(filename, current_path), filename)

        for module_name in sorted(modules):
            module = importlib.import_module(module_name)
            if hasattr(module, "router"):
                # add router
                logging.info(f"register module:{module_name}, {modules[module_name]}")
                router.include_router(module.router)
        return router


def main():
    """Parse command-line arguments and start the server (blocks until shutdown)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', '-p', type=int, required=False, default=8088,
                        help='service port')
    args = parser.parse_args()
    server = ModelAssistantServer(args.port)
    server.start()


if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    main()