integration_tests/gaudi/conftest.py (200 lines of code) (raw):
import asyncio
import contextlib
import os
import shlex
import subprocess
import sys
import threading
import time
from tempfile import TemporaryDirectory
import docker
import pytest
from docker.errors import NotFound
import logging
from test_embed import TEST_CONFIGS
import aiohttp
logging.basicConfig(
level=logging.INFO,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
stream=sys.stdout,
)
logger = logging.getLogger(__file__)
# Use the latest image from the local docker build
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tei_hpu")
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
if DOCKER_VOLUME is None:
logger.warning(
"DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
)
LOG_LEVEL = os.getenv("LOG_LEVEL", "info")
BASE_ENV = {
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"LOG_LEVEL": LOG_LEVEL,
"HABANA_VISIBLE_DEVICES": "all",
}
HABANA_RUN_ARGS = {
"runtime": "habana",
}
def stream_container_logs(container, test_name):
"""Stream container logs in a separate thread."""
try:
for log in container.logs(stream=True, follow=True):
print(
f"[TEI Server Logs - {test_name}] {log.decode('utf-8')}",
end="",
file=sys.stderr,
flush=True,
)
except Exception as e:
logger.error(f"Error streaming container logs: {str(e)}")
class LauncherHandle:
def __init__(self, port: int):
self.port = port
self.base_url = f"http://127.0.0.1:{port}"
async def generate(self, prompt: str):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/embed",
json={"inputs": prompt},
headers={"Content-Type": "application/json"}
) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(f"Request failed with status {response.status}: {error_text}")
return await response.json()
def _inner_health(self):
raise NotImplementedError
async def health(self, timeout: int = 60):
assert timeout > 0
start_time = time.time()
logger.info(f"Starting health check with timeout of {timeout}s")
for attempt in range(timeout):
if not self._inner_health():
logger.error("Launcher crashed during health check")
raise RuntimeError("Launcher crashed")
try:
# Try to make a request using generate
await self.generate("test")
elapsed = time.time() - start_time
logger.info(f"Health check passed after {elapsed:.1f}s")
return
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if attempt == timeout - 1:
logger.error(f"Health check failed after {timeout}s: {str(e)}")
raise RuntimeError(f"Health check failed: {str(e)}")
if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt
logger.debug(f"Connection attempt {attempt}/{timeout} failed: {str(e)}")
await asyncio.sleep(1)
except Exception as e:
logger.error(f"Unexpected error during health check: {str(e)}")
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
class ContainerLauncherHandle(LauncherHandle):
def __init__(self, docker_client, container_name, port: int):
super().__init__(port)
self.docker_client = docker_client
self.container_name = container_name
def _inner_health(self) -> bool:
try:
container = self.docker_client.containers.get(self.container_name)
status = container.status
if status not in ["running", "created"]:
logger.warning(f"Container status is {status}")
# Get container logs for debugging
logs = container.logs().decode("utf-8")
logger.debug(f"Container logs:\n{logs}")
return False
return True
except Exception as e:
logger.error(f"Error checking container health: {str(e)}")
return False
class ProcessLauncherHandle(LauncherHandle):
def __init__(self, process, port: int):
super(ProcessLauncherHandle, self).__init__(port)
self.process = process
def _inner_health(self) -> bool:
return self.process.poll() is None
@pytest.fixture(scope="module")
def data_volume():
tmpdir = TemporaryDirectory()
yield tmpdir.name
try:
# Cleanup the temporary directory using sudo as it contains root files created by the container
subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True)
except subprocess.CalledProcessError as e:
logger.error(f"Error cleaning up temporary directory: {str(e)}")
@pytest.fixture(scope="function")
def gaudi_launcher(event_loop):
@contextlib.contextmanager
def docker_launcher(
model_id: str,
test_name: str,
):
logger.info(
f"Starting docker launcher for model {model_id} and test {test_name}"
)
port = 8080
client = docker.from_env()
container_name = f"tei-hpu-test-{test_name.replace('/', '-')}"
try:
container = client.containers.get(container_name)
logger.info(
f"Stopping existing container {container_name} for test {test_name}"
)
container.stop()
container.wait()
except NotFound:
pass
except Exception as e:
logger.error(f"Error handling existing container: {str(e)}")
tei_args = TEST_CONFIGS[test_name]["args"].copy()
# add model_id to tei args
tei_args.append("--model-id")
tei_args.append(model_id)
env = BASE_ENV.copy()
env["HF_TOKEN"] = os.getenv("HF_TOKEN")
# Add env config that is definied in the fixture parameter
if "env_config" in TEST_CONFIGS[test_name]:
env.update(TEST_CONFIGS[test_name]["env_config"].copy())
volumes = [f"{DOCKER_VOLUME}:/data"]
logger.debug(f"Using volume {volumes}")
try:
logger.info(f"Creating container with name {container_name}")
# Log equivalent docker run command for debugging, this is not actually executed
container = client.containers.run(
DOCKER_IMAGE,
command=tei_args,
name=container_name,
environment=env,
detach=True,
volumes=volumes,
ports={"80/tcp": port},
**HABANA_RUN_ARGS,
)
logger.info(f"Container {container_name} started successfully")
# Start log streaming in a background thread
log_thread = threading.Thread(
target=stream_container_logs,
args=(container, test_name),
daemon=True, # This ensures the thread will be killed when the main program exits
)
log_thread.start()
# Add a small delay to allow container to initialize
time.sleep(2)
# Check container status after creation
status = container.status
logger.debug(f"Initial container status: {status}")
if status not in ["running", "created"]:
logs = container.logs().decode("utf-8")
logger.error(f"Container failed to start properly. Logs:\n{logs}")
yield ContainerLauncherHandle(client, container.name, port)
except Exception as e:
logger.error(f"Error starting container: {str(e)}")
# Get full traceback for debugging
import traceback
logger.error(f"Full traceback:\n{traceback.format_exc()}")
raise
finally:
try:
container = client.containers.get(container_name)
logger.info(f"Stopping container {container_name}")
container.stop()
container.wait()
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
container.remove()
logger.info(f"Container {container_name} removed successfully")
except NotFound:
pass
except Exception as e:
logger.warning(f"Error cleaning up container: {str(e)}")
return docker_launcher