generate/run_ioi_slurm.py
#!/usr/bin/env python3
import argparse
import logging
import os
import subprocess
from datetime import datetime
from math import ceil, gcd
from pathlib import Path

from transformers import AutoConfig

logger = logging.getLogger(__name__)

DEFAULT_TP = 16
MAX_CTX_LENGTH = None
MODEL_CONFIGS = {}

LOGS_DIR = "/fsx/hynek_kydlicek/logs/ioi-eval"
SLURM_SCRIPT_DIR = "/fsx/hynek_kydlicek/slurm/ioi-eval/output"
UV_ENV = "/fsx/hynek_kydlicek/projects/ioi-leaderboard/ioi-eval"
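
# MODEL_CONFIGS holds optional per-model overrides. Its shape is inferred from how
# get_concurrency() and get_tp() below read it; this entry is purely illustrative:
#   MODEL_CONFIGS = {"some-org/some-model": {"tp": 8, "concurrency": 50}}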


def get_concurrency(model_name: str, concurrency: int) -> int:
    """Get concurrency from model config."""
    return MODEL_CONFIGS.get(model_name, {}).get("concurrency", concurrency)


def get_tp(model_name: str, revision: str) -> int:
    """Get the tensor-parallel degree, adjusted to evenly divide the model's attention heads."""
    default_tp = MODEL_CONFIGS.get(model_name, {}).get("tp", DEFAULT_TP)
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
        # Check num_attention_heads and ensure it is divisible by tp
        if hasattr(config, 'num_attention_heads'):
            if config.num_attention_heads % default_tp != 0:
                # Adjust tp to the largest value that divides both num_attention_heads and default_tp
                new_tp = gcd(config.num_attention_heads, default_tp)
                print(f"Adjusted tp for {model_name} from {default_tp} to {new_tp}")
                return new_tp
        return default_tp
    except Exception as e:
        print(f"Could not get tp from config for {model_name}: {e}")
        return default_tp
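
# Worked example for the adjustment above: a model with 40 attention heads and a
# default tp of 16 cannot shard its heads evenly (40 % 16 != 0), so tp becomes
# gcd(40, 16) = 8, the largest degree that divides both.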


def get_context_length(model_name: str, revision: str) -> int:
    """Get maximum context length from model config."""
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
        # Check various possible context length attributes
        context_length = (
            getattr(config, 'max_position_embeddings', None) or
            getattr(config, 'sliding_window', None) or
            getattr(config, 'max_sequence_length', None) or
            getattr(config, 'max_seq_len', None) or
            4096  # Default fallback
        )
        # Some models (like Qwen) might have sliding_window set but disabled
        if hasattr(config, 'use_sliding_window') and not config.use_sliding_window:
            # If sliding window is disabled, use max_position_embeddings instead
            context_length = getattr(config, 'max_position_embeddings', context_length)
        # Cap to MAX_CTX_LENGTH if one is configured
        if MAX_CTX_LENGTH is not None:
            context_length = min(context_length, MAX_CTX_LENGTH)
        return context_length
    except Exception as e:
        logger.warning(f"Could not get context length from config for {model_name}: {e}")
        return 4096  # Default fallback
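
# Example of the Qwen-style case handled above: such configs can set
# use_sliding_window=False while still carrying a sliding_window value, in which
# case max_position_embeddings is the length that should actually be trusted.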


def parse_args():
    parser = argparse.ArgumentParser(description="Run IOI evaluation on a model using Slurm")
    parser.add_argument("--model", type=str, required=True,
                        help="Model to evaluate (predefined model name)")
    parser.add_argument("--eval_args", type=str, required=True,
                        help="Arguments to pass to the evaluation script")
    parser.add_argument("--time", type=str, default="7-00:00:00",
                        help="Job time limit (default: 7 days)")
    parser.add_argument("--partition", type=str, default="hopper-prod",
                        help="Slurm partition")
    parser.add_argument("--qos", type=str, default="normal",
                        help="Slurm QOS")
    parser.add_argument("--startup_delay", type=int, default=3600,
                        help="Seconds to wait for the server to become healthy before giving up")
    parser.add_argument("--dry_run", action="store_true",
                        help="Generate script but don't submit job")
    parser.add_argument("--revision", type=str, default=None, help="Revision to use for the model")
    parser.add_argument("--concurrency", type=int, default=100,
                        help="Number of concurrent requests to the server")
    parser.add_argument("--uv_env", type=str, default=None, help="Path to the uv env")
    parser.add_argument("--logs_dir", type=str, default=None)
    parser.add_argument("--slurm_dir", type=str, default=None)
    return parser.parse_args()
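
# Example invocation (the model name and eval_args payload are illustrative only,
# not defaults of this script):
#   python generate/run_ioi_slurm.py \
#     --model open-r1/OlympicCoder-7B \
#     --eval_args "--num_generations 50" \
#     --dry_run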


def create_slurm_script(args, logs_dir):
    # Override with custom per-model values if provided
    concurrency = get_concurrency(args.model, args.concurrency)
    tp = get_tp(args.model, args.revision)
    context_length = get_context_length(args.model, args.revision)

    # Create a sanitized model name for the job name
    job_name = f"ioi-eval-{args.model.replace('/', '-')}"
    log_dir = logs_dir / job_name
    log_dir.mkdir(parents=True, exist_ok=True)

    n_nodes = ceil(tp / 8)
    tasks = n_nodes
    revision_arg = f"--revision {args.revision}" if args.revision else ""
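
    # Node-count arithmetic above, for concreteness: with 8 GPUs per node,
    # tp=16 needs ceil(16 / 8) = 2 nodes, while tp=8 fits on a single node.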
    slurm_script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={args.partition}
#SBATCH --qos={args.qos}
#SBATCH --nodes={n_nodes}
#SBATCH --gpus-per-node=8
#SBATCH --exclusive
#SBATCH --output={log_dir}/%j-%x.out
#SBATCH --error={log_dir}/%j-%x.out
#SBATCH --time={args.time}
#SBATCH --ntasks-per-node=1

set -exuo pipefail

SERVER_PORT=39877
DIST_PORT=45000

# Random sleep (1-100s) to avoid all jobs hammering the server at once
sleep $((RANDOM % 100 + 1))

# Environment configuration
export OUTLINES_CACHE_DIR=/scratch/serve_r1/ocache/
export TRITON_HOME=/scratch/serve_r1/triton/
export GLOO_SOCKET_IFNAME="enp71s0"
export NCCL_SOCKET_IFNAME="enp71s0"

# Evaluation script path
EVAL_SCRIPT_PATH="/fsx/hynek_kydlicek/projects/ioi/generate/evaluate.py"

module load cuda/12.4
source ~/.bashrc

# Activate uv
source {args.uv_env or UV_ENV}/bin/activate

FIRST_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
FIRST_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$FIRST_NODE" hostname --ip-address)

# Launch servers synchronously across all nodes
srun --nodes={n_nodes} --ntasks={tasks} --ntasks-per-node=1 \\
    bash -c "python -m sglang.launch_server \\
        --model-path '{args.model}' \\
        --tp {tp} \\
        --dist-init-addr '$FIRST_NODE_IP:$DIST_PORT' \\
        {revision_arg} \\
        --nnodes {n_nodes} \\
        --node-rank \\$SLURM_PROCID \\
        --port '$SERVER_PORT' \\
        --host 0.0.0.0 \\
        --trust-remote-code \\
        --max-running-requests {concurrency} \\
        --context-length {context_length}" &

# Wait for server with timeout
TIMEOUT={args.startup_delay}  # default 1h, but model loading should take ~30min
START_TIME=$(date +%s)
echo "Waiting for SGLang server (http://$FIRST_NODE_IP:$SERVER_PORT)..."

while true; do
    if curl -s -o /dev/null -w "%{{http_code}}" "http://$FIRST_NODE_IP:$SERVER_PORT/health" >/dev/null 2>&1; then
        echo "Server is ready at http://$FIRST_NODE_IP:$SERVER_PORT"
        break
    fi

    CURRENT_TIME=$(date +%s)
    if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then
        echo "Error: Server failed to start within $TIMEOUT seconds"
        exit 1
    fi

    echo "Still waiting... ($(($CURRENT_TIME - $START_TIME)) seconds elapsed)"
    sleep 60
done

echo "Checking available models..."
curl "http://$FIRST_NODE_IP:$SERVER_PORT/v1/models"
sleep 10

echo "Executing sanity check..."
curl "http://$FIRST_NODE_IP:$SERVER_PORT/v1/completions" \\
    -H "Content-Type: application/json" \\
    -d '{{
        "model": "default",
        "prompt": "hi, how are you?",
        "max_tokens": 2048,
        "temperature": 0.6
    }}'

# Run the evaluation against the local server (this batch script runs on the first node)
python "$EVAL_SCRIPT_PATH" \\
    --model_id "sglang/{args.model}" \\
    {revision_arg} \\
    --api_base "http://localhost:$SERVER_PORT/v1" \\
    --concurrency {concurrency} \\
    {args.eval_args}

# Kill the server and exit
pkill -f "python -m sglang.launch_server"
exit 0
"""
    return slurm_script, job_name
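
# A note on escaping in the template above: "\\$SLURM_PROCID" renders as a literal
# \$SLURM_PROCID in the generated script, so the batch shell leaves it alone and
# each srun task expands its own rank, while $FIRST_NODE_IP and $SERVER_PORT are
# expanded by the batch shell before the bash -c command runs.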


def main():
    args = parse_args()
    logging.basicConfig(level=logging.INFO)

    # Create output directory if it doesn't exist
    output_dir = Path(args.slurm_dir or SLURM_SCRIPT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create logs directory if it doesn't exist
    logs_dir = Path(args.logs_dir or LOGS_DIR)
    logs_dir.mkdir(parents=True, exist_ok=True)

    # Generate the Slurm script
    slurm_script, job_name = create_slurm_script(args, logs_dir)

    # Create a timestamp for the filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Save the script to a file
    script_path = output_dir / f"{job_name}_{timestamp}.slurm"
    with open(script_path, "w") as f:
        f.write(slurm_script)
    logger.info(f"Slurm script saved to: {script_path}")

    # Make the script executable
    os.chmod(script_path, 0o755)

    # Submit the job if not a dry run
    if not args.dry_run:
        try:
            result = subprocess.run(
                ["sbatch", str(script_path)],
                check=True,
                capture_output=True,
                text=True,
            )
            print(f"Job submitted: {result.stdout.strip()} find logs at {logs_dir}/{job_name}")
        except subprocess.CalledProcessError as e:
            print(f"Error submitting job: {e}")
            print(f"Error output: {e.stderr}")
    else:
        print("Dry run - job not submitted")


if __name__ == "__main__":
    main()