run_tests/piston_client.py (161 lines of code) (raw):

import asyncio
import os
import random
import re
import subprocess
from collections import Counter
from functools import lru_cache

import aiohttp


class PistonError(Exception):
    """Raised when a Piston worker returns an error or an unusable response."""


@lru_cache(maxsize=1)
def get_piston_client_from_env(session=None):
    """Build a :class:`PistonClient` from environment variables.

    Reads ``PISTON_ENDPOINTS`` (a comma-separated list of endpoint URLs, or the
    literal ``slurm`` to discover workers through ``squeue``) and
    ``PISTON_MAX_REQUESTS_PER_ENDPOINT`` (defaults to ``1``). The endpoint
    order is shuffled before the client is created.

    The result is memoized by ``lru_cache``, so repeated calls with the same
    ``session`` argument return the same client instance.

    Args:
        session: Optional session object forwarded to :class:`PistonClient`.

    Returns:
        The configured :class:`PistonClient`.

    Raises:
        ValueError: If ``PISTON_ENDPOINTS`` is not set, or if it yields no
            endpoints (raised by the :class:`PistonClient` constructor).
    """
    raw_endpoints = os.getenv("PISTON_ENDPOINTS")
    if raw_endpoints is None:
        raise ValueError(
            "For IOI problems Piston endpoints running our IOI package are required. "
            "Please add a list of valid Piston endpoints to a PISTON_ENDPOINTS varialbe in a `.env` file."
        )

    if raw_endpoints.strip() == "slurm":
        piston_endpoints = get_slurm_piston_endpoints()
    else:
        # Tolerate "a, b," style values: strip whitespace and drop empty
        # entries so we never build requests against a blank base URL.
        piston_endpoints = [
            entry.strip() for entry in raw_endpoints.split(",") if entry.strip()
        ]

    random.shuffle(piston_endpoints)
    max_requests_per_endpoint = os.getenv("PISTON_MAX_REQUESTS_PER_ENDPOINT", "1")
    return PistonClient(
        piston_endpoints,
        session,
        max_requests_per_endpoint=int(max_requests_per_endpoint),
    )


class PistonClient:
    r"""
    A client that will automatically load balance across multiple Piston
    (https://github.com/engineer-man/piston) workers.

    This assumes piston is running our custom cms_ioi package:
    https://github.com/guipenedo/piston/releases/

    We recommend starting the instances with the following script as otherwise
    some IOI problems will hit default limits:
    ```
    export PISTON_COMPILE_TIMEOUT=60000
    export PISTON_RUN_TIMEOUT=60000
    export PISTON_OUTPUT_MAX_SIZE=1000000000
    export PISTON_MAX_FILE_SIZE=1000000000
    export PISTON_DISABLE_NETWORKING=true
    export PISTON_REPO_URL=https://github.com/guipenedo/piston/releases/download/pkgs/index
    mkdir /piston

    sed -i '/app.use(body_parser.urlencoded/c\ app.use(body_parser.urlencoded({ extended: true, limit: "512mb" }));' src/index.js
    sed -i '/app.use(body_parser.json/c\ app.use(body_parser.json({ limit: "512mb" }));' src/index.js

    # Start server in background
    node src
    ```

    Piston docs for API usage: https://piston.readthedocs.io/en/latest/api-v2/

    Load balancing is token based: the client keeps a queue holding
    ``max_requests_per_endpoint`` tokens per endpoint. An execute request takes
    a token before it is sent and returns it afterwards, so at most that many
    execute requests are in flight per endpoint.
    """

    def __init__(
        self,
        base_endpoint: str | list[str] = "http://ip-10-53-80-65:3223/api/v2",
        session=None,
        max_requests_per_endpoint=1,
    ):
        """Create a client for one or more Piston endpoints.

        Args:
            base_endpoint: A single endpoint URL or a list of endpoint URLs.
            session: Optional pre-built session; when ``None`` an
                ``aiohttp.ClientSession`` is created lazily on first use.
            max_requests_per_endpoint: Number of concurrent execute requests
                allowed per endpoint.

        Raises:
            ValueError: If no endpoints are provided.
        """
        self.max_requests_per_endpoint = max_requests_per_endpoint
        self.base_endpoints = (
            [base_endpoint] if isinstance(base_endpoint, str) else base_endpoint
        )
        if len(self.base_endpoints) == 0:
            raise ValueError(
                "No Piston endpoints provided. Please check your PISTON_ENDPOINTS environment variable."
            )

        # Stable numeric ids, used only to make log lines easier to read.
        self.endpoint_ids = {
            endpoint: i for i, endpoint in enumerate(self.base_endpoints)
        }
        self._session = session

        # One token per allowed in-flight request. Tokens are interleaved
        # across endpoints so consecutive requests go to different workers.
        self.endpoint_tokens = asyncio.Queue(
            maxsize=max_requests_per_endpoint * len(self.base_endpoints)
        )
        for _ in range(max_requests_per_endpoint):
            for endpoint in self.base_endpoints:
                self.endpoint_tokens.put_nowait(endpoint)

        self._endpoint_failures = Counter()
        self._unhealthy_endpoints = set()
        self._endpoint_failures_lock = asyncio.Lock()

    @property
    def session(self):
        """Return the HTTP session, creating it on first access."""
        if self._session is None:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(sock_read=10),
                connector=aiohttp.TCPConnector(
                    limit=self.max_requests_per_endpoint * len(self.base_endpoints),
                    ttl_dns_cache=300,
                    keepalive_timeout=5 * 60,
                ),
            )
        return self._session

    async def _wait_for_endpoint(self):
        """Wait until an endpoint token is available and return its endpoint."""
        endpoint = await self.endpoint_tokens.get()
        return endpoint

    async def _release_endpoint(self, endpoint):
        """Return an endpoint token to the queue."""
        await self.endpoint_tokens.put(endpoint)

    async def _send_request(self, endpoint, route, data=None, method="post"):
        """Send one JSON request to ``endpoint``/``route`` and return the decoded JSON body."""
        async with self.session.request(
            method,
            f"{endpoint.rstrip('/')}/{route}",
            json=data,
            headers={"Content-Type": "application/json"},
        ) as response:
            # content_type=None: decode the body as JSON whatever the
            # response's Content-Type header says.
            return await response.json(content_type=None)

    async def _send_to_all(self, route, data=None, method="post"):
        """Send the same request to every endpoint concurrently; return the list of responses."""
        return await asyncio.gather(
            *[
                self._send_request(endpoint, route, data, method)
                for endpoint in self.base_endpoints
            ]
        )

    async def _send_to_one(self, endpoint, route, data=None, method="post"):
        """Send a request to a single endpoint."""
        return await self._send_request(endpoint, route, data, method)

    async def install_package(self, language, version):
        """Request installation of a package on every endpoint."""
        return await self._send_to_all(
            "packages", {"language": language, "version": version}, method="post"
        )

    async def uninstall_package(self, language, version):
        """Request removal of a package on every endpoint."""
        return await self._send_to_all(
            "packages", {"language": language, "version": version}, method="delete"
        )

    async def get_supported_runtimes(self):
        """Query the ``runtimes`` route on every endpoint."""
        return await self._send_to_all("runtimes", method="get")

    async def execute(self, data) -> tuple[str, str]:
        """
        Requests to the IOI package return the score as a float in the stdout,
        as well as optional feedback/errors in stderr.
        Returns a tuple of (score, feedback).

        Raises:
            PistonError: If the response carries a ``message`` field, has no
                ``run`` section, or the run exited with a non-zero code without
                producing stdout (and was not classified as a memory or time
                limit failure).
        """
        response = await self._send_execute(data)

        if 'message' in response:
            raise PistonError(response['message'])

        if 'compile' in response and response['compile']['code'] != 0:
            return "0", "Compilation error exit code " + str(response['compile']['code']) + "\n" + response['compile']['stderr']

        if 'run' not in response:
            raise PistonError(response)

        run = response['run']

        if run['code'] == 1 and "MemoryError" in run['stderr']:
            return "0", "Memory limit exceeded"

        # successful result
        if run['stdout']:
            return run['stdout'], run['stderr']

        if run['signal'] == 'SIGKILL':
            return "0", "Time limit exceeded"

        # other issues
        if run['code'] != 0:
            raise PistonError(
                f"language={response['language']}, version={response['version']}, "
                f"exit code={run['code']}, stderr={run['stderr']}, signal={run['signal']}"
            )
        return '0', 'Unknown error'

    async def _check_failed_endpoint(self, endpoint) -> bool:
        """Probe an endpoint whose connection attempt failed.

        After a 5 second pause the ``runtimes`` route of *this* endpoint is
        queried. If the probe fails the endpoint is recorded as unhealthy.
        Checks are serialized by a lock, so a task arriving while another is
        probing the same endpoint sees the recorded result instead of probing
        again.

        Returns:
            ``True`` if the endpoint answered the probe; the caller still owns
            the endpoint token and must release it. ``False`` if the endpoint
            is (or already was) unhealthy; the caller should drop the token.

        Raises:
            PistonError: If the endpoint is unhealthy and every configured
                endpoint is now unhealthy. Raising on the already-unhealthy
                path as well keeps later tasks from silently dropping their
                token and then waiting forever on an empty token queue.
        """
        async with self._endpoint_failures_lock:
            if endpoint not in self._unhealthy_endpoints:
                try:
                    await asyncio.sleep(5)
                    # Probe only the suspect endpoint: querying all endpoints
                    # would blame this one for another worker's outage.
                    await self._send_to_one(endpoint, "runtimes", method="get")
                except Exception as e:
                    print(f"Error checking endpoint {endpoint}, dropping it ({e})")
                    self._unhealthy_endpoints.add(endpoint)
                else:
                    return True

            if len(self._unhealthy_endpoints) >= len(self.base_endpoints):
                raise PistonError("All endpoints are unhealthy. Please check your Piston workers.")
            return False

    async def _send_execute(self, data):
        """Send an execute request to an available endpoint, with retries.

        ``data`` is merged with the ``cms_ioi`` language/version selector and
        posted to the ``execute`` route of whichever endpoint token is obtained
        first. ``PistonError``, ``asyncio.TimeoutError``,
        ``aiohttp.ClientConnectionError`` and ``RuntimeError`` are retried up
        to 5 times with capped exponential backoff; any other exception
        propagates immediately.

        Returns:
            The decoded JSON response of the execute call.

        Raises:
            PistonError: On a non-200 status, an empty body, or an overloaded
                worker once retries are exhausted; also when all endpoints
                have been found unhealthy.
        """
        data = data | {
            "language": "cms_ioi",
            "version": "*",
        }

        max_retries = 5
        base_delay = 1.0

        endpoint = None

        for attempt in range(max_retries + 1):
            try:
                endpoint = await self._wait_for_endpoint()
                if attempt > 0:
                    await asyncio.sleep(1)
                async with self.session.post(
                    f"{endpoint.rstrip('/')}/execute",
                    json=data,
                    headers={"Content-Type": "application/json"},
                ) as response:
                    status = response.status
                    # Check the status before decoding: an error response with
                    # a non-JSON body would otherwise raise a decode error and
                    # bypass the retry logic below.
                    if status != 200:
                        raise PistonError(f"Server error. status={status}")

                    res_json = await response.json(content_type=None)
                    if res_json is None:
                        raise PistonError(f"Empty response. status={status}")
                    # piston overloaded
                    if 'run' in res_json and "Resource temporarily unavailable" in res_json['run'].get('stderr', ''):
                        raise PistonError(f"Piston overloaded: {res_json['run']['stderr']}")
                    return res_json

            except (PistonError, asyncio.TimeoutError, aiohttp.ClientConnectionError, RuntimeError) as e:
                # Only retry if we haven't reached max retries yet
                if attempt >= max_retries:
                    print(f"Giving up on retries. {e}")
                    raise

                # Calculate backoff with jitter
                delay = min(base_delay * (2 ** attempt), 10)  # Exponential backoff, capped at 10 seconds
                jitter = delay * 0.2 * (2 * asyncio.get_event_loop().time() % 1 - 0.5)  # Add ±10% jitter
                retry_delay = delay + jitter
                print(f"Retrying in {retry_delay:.2f} seconds [{self.endpoint_ids.get(endpoint)}] {endpoint}")

                if endpoint is not None:
                    # special case: worker died
                    if isinstance(e, aiohttp.ClientConnectionError) and "Connect call failed" in str(e):
                        # Give the token back only if the worker turned out to
                        # be alive; a dead worker's token is dropped for good.
                        if await self._check_failed_endpoint(endpoint):
                            await self._release_endpoint(endpoint)
                    else:
                        # hopefully we won't get this one again
                        await self._release_endpoint(endpoint)
                    # Token handled above; stop `finally` from releasing it twice.
                    endpoint = None

                await asyncio.sleep(retry_delay)
            except Exception as e:
                print(f"Propagating exception {type(e)}: {e}")
                raise
            finally:
                # Ensure endpoint is always released, even if an exception occurs
                if endpoint is not None:
                    try:
                        await self._release_endpoint(endpoint)
                    except Exception as e:
                        print(f"Error releasing endpoint {endpoint}: {e}")
                    endpoint = None


def get_slurm_piston_endpoints():
    """Get list of active piston worker endpoints from squeue output.

    Runs ``squeue`` restricted to RUNNING jobs and, for every job whose name
    starts with ``piston-worker-<port>``, builds
    ``http://<hostname>:<port>/api/v2``.

    Returns:
        A list of endpoint URLs; empty when no matching job is found.
    """
    # Run squeue command to get job name, hostname and status, filtering for RUNNING state
    result = subprocess.run(
        ['squeue', '--format="%j %N %T"', '--noheader', '--states=RUNNING'],
        capture_output=True,
        text=True,
    )

    endpoints = []
    for line in result.stdout.strip().split('\n'):
        fields = line.split()
        # Empty squeue output yields a single blank line; skip it (and any
        # other line without both a job name and a hostname) instead of
        # failing with IndexError.
        if len(fields) < 2:
            continue

        # The format string is passed without a shell, so its literal quotes
        # end up in the output; remove them from the job name.
        job_name = fields[0].strip('"')
        hostname = fields[1]

        # Extract port if job name matches pattern
        match = re.match(r'piston-worker-(\d+)', job_name)
        if match:
            port = match.group(1)
            endpoints.append(f"http://{hostname}:{port}/api/v2")

    return endpoints