from enum import Enum
import os
from jinja2 import Template
import subprocess
import json
from typing import List 

class Status(Enum):
    # INIT -> PENDING -> [RUNNING | FAIL | TIMEOUT OOM] -> COMPLETED
    INIT = "init"           # Job is created
    PENDING = "pending"     # Job is waiting for ressources
    RUNNING = "running"     # Job is running
    FAIL = "fail"           # Job failed
    OOM = "oom"             # Job failed due to out of memory (expected behavior)
    TIMEOUT = "timeout"     # Job failed due to timeout
    COMPLETED = "completed" # Job is completed

class Job:
    def __init__(self, root_path: str, qos: str) -> None:
        self.root_path = root_path        
        self.name = os.path.basename(root_path)
        self.config = os.path.join(root_path, "config.json")
        self.qos = qos
        
        # Check if the status.txt file exists
        status_file_path = os.path.join(self.root_path, "status.txt")
        if not os.path.exists(status_file_path):
            # Create the status.txt file with INIT status
            with open(status_file_path, 'w') as f:
                f.write(Status.INIT.value)
        self.status = self.get_status()
        
    def get_status(self) -> Status:
        """
        Read the status of the job from `status.txt` and return it
        """
        is_existing = lambda value_to_check: any(value.value == value_to_check for value in Status.__members__.values())

        status_file_path = os.path.join(self.root_path, "status.txt")
        with open(status_file_path, 'r') as f:
            status = f.read()
            if not is_existing(status):
                raise ValueError("Invalid status")
            return Status(status)
        
    def set_status(self, status: Status) -> Status:
        """
        Update the status of the job in `status.txt` and return the new status
        """
        status_file_path = os.path.join(self.root_path, "status.txt")
        with open(status_file_path, 'w') as f:
            f.write(status.value)
            return status        
    
class Scheduler:
    
    def __init__(self, inp_dir: str, qos: str) -> None:
        jobs_directory_paths = [os.path.abspath(root) for root, dirs, _ in os.walk(inp_dir) if not dirs]
        jobs_directory_paths = [job_path.replace("/profiler", "") if "profiler" in job_path else job_path for job_path in jobs_directory_paths]
        self.job_lists = [Job(job_path, qos) for job_path in jobs_directory_paths]

    def keep_only_jobs(self, status: Status):
        return [job for job in self.job_lists if job.status == status]
    
    def filter_out_jobs(self, status: Status):
        return [job for job in self.job_lists if job.status != status]
     
    def create_slurm_script(self, job: Job):
        # Submit job to the cluster (edit jinja)    
        # load yaml config.yaml
        with open(job.config, 'r') as file:
            config = json.load(file)
        
        max_gpu_per_node = 8
        # Pick the right number of nodes and n_proc_per_node
        world_size = config["distributed"]["tp_size"] * config["distributed"]["cp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"]
        assert world_size <= max_gpu_per_node or world_size % max_gpu_per_node == 0
        nodes = max(1, world_size // max_gpu_per_node)
        n_proc_per_node = min(max_gpu_per_node, world_size // nodes)
        assert nodes * n_proc_per_node == world_size
        
        context_bench = {
            'nodes': nodes,
            'n_proc_per_node': n_proc_per_node,
            'root_path': job.root_path,
            "config": job.config,
            "qos": job.qos,
        }
        
        base_path = os.path.join(os.getcwd(), "template/base_job.slurm")

        with open(base_path, 'r') as file:
            base_job_file = file.read()
        
        base_job_template = Template(base_job_file)
                
        # Write the rendered script to a new file located at the job root_path
        output_file_path = os.path.join(job.root_path, "job.slurm")
        with open(output_file_path, 'w') as file:
            file.write(base_job_template.render(context_bench))

        print(f"Slurm script created at {output_file_path}")
    
    def launch_dependency(self, job_array: List[Job], env_vars):
        
        prev_job_id = None
        for job in job_array:
            if prev_job_id is None:         
                result = subprocess.run(["sbatch", '--parsable', os.path.join(job.root_path, "job.slurm")], env=env_vars, capture_output=True, text=True)
            else:
                result = subprocess.run(["sbatch", '--parsable', '--dependency=afterany:'+prev_job_id, os.path.join(job.root_path, "job.slurm")], env=env_vars, capture_output=True, text=True)
            job.set_status(Status.PENDING)
            prev_job_id = result.stdout.strip()
                    

    def check_status(self):
        # find all status files using self.jobs_directory_paths
        status_files = [os.path.join(job.root_path, "status.txt") for job in self.job_lists]
                
        status_counts = {
            "init": 0,
            "pending": 0,
            "running": 0,
            "fail": 0,
            "oom": 0,
            "timeout": 0,
            "completed": 0
        }
        
        for status_file in status_files:
            with open(status_file, 'r') as f:
                status = f.read().strip()
                if status in status_counts:
                    status_counts[status] += 1
                else:
                    raise ValueError(f"Invalid status: {status}")

        total = sum(status_counts.values())
        
        # Print the status counts in a formatted table
        print(f"{'Status':<10} | {'Count':<6}")
        print(f"{'-'*10}-|-{'-'*6}")
        for status, count in status_counts.items():
            print(f"{status.capitalize():<10} | {count:<6}")
        
        print(f"{'-'*10}-|-{'-'*6}")
        print(f"{'Total':<10} | {total:<6}")

def submit_jobs(inp_dir, qos, hf_token, nb_slurm_array, only: str = None):
    scheduler = Scheduler(inp_dir, qos)

    #TODO: batch into job arrays
    env_vars = os.environ.copy()
    env_vars["HUGGINGFACE_TOKEN"] = hf_token
    total_jobs = len(scheduler.job_lists)

    if only == "fail":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.FAIL)
    elif only == "pending":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.PENDING)
    elif only == "timeout":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.TIMEOUT)
    elif only == "running":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.RUNNING)
    
    if only is not None:
        filtered_jobs = len(scheduler.job_lists)
        if filtered_jobs == 0:
            print(f"No '{only}' jobs to resubmit")
            return
        print(f"Only {filtered_jobs}/{total_jobs} jobs with status '{only}' will be resubmitted")
    
    scheduler.job_lists = scheduler.filter_out_jobs(Status.COMPLETED)

    if nb_slurm_array > 0:
        # Use job dependecies
        
        # Distribute the jobs into the arrays        
        base_jobs_per_array = len(scheduler.job_lists) // nb_slurm_array
        extra_jobs = len(scheduler.job_lists) % nb_slurm_array
        distribution = [base_jobs_per_array] * nb_slurm_array
        for i in range(extra_jobs):
            distribution[i] += 1
        
        start = 0
        
        for i, nb_jobs in enumerate(distribution):
            previous_job_id = None
            end = start + nb_jobs
            job_array = scheduler.job_lists[start:end]
            
            print(f"Launching job Dependency array {i+1} with {nb_jobs} jobs")
            
            for job in job_array:
                scheduler.create_slurm_script(job)

            scheduler.launch_dependency(job_array, env_vars)
            
            start = end
    else:
        # Don't use job dependecies
        for job in scheduler.job_lists:
            scheduler.create_slurm_script(job)
            print(os.path.join(job.root_path, "job.slurm"))
            subprocess.run(["sbatch", os.path.join(job.root_path, "job.slurm")], env=env_vars)
            job.set_status(Status.PENDING)
            
if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description='Submit jobs to the cluster')
    parser.add_argument('--inp_dir', type=str, help='Input directory containing the jobs')
    parser.add_argument('--qos', type=str, help='QOS of the jobs')
    parser.add_argument('--nb_slurm_array', type=int, default=0, help='Number of slurm arrays')
    parser.add_argument('--only', type=str, default=None, help='Filter the jobs to submit')
    parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')

    args = parser.parse_args()
    #TODO: add more option like "python slurm.py submit_jobs --...." or "python slurm.py update_jobs --...." or "python slurm.py cancel_jobs --...." or "python slurm.py check_status --...."
    submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
