def create_slurm_script()

in bench_cluster/submit_jobs.py [0:0]


    def create_slurm_script(self, job: Job, cluster: str) -> None:
        """Render the Slurm batch script for ``job`` and write it to disk.

        The YAML file at ``job.config`` is read to obtain the parallelism
        degrees (``pp``, ``dp``, ``tp``). Their product is the world size,
        which is split into a node count and a per-node process count. Those
        two values, together with ``job.root_path``, ``job.config``,
        ``job.qos`` and a path built from the last three components of
        ``job.root_path``, are rendered into the cluster's template and
        written to ``<job.root_path>/bench.slurm``. Nothing is submitted here.

        Args:
            job: Job to generate the script for; ``config``, ``root_path``
                and ``qos`` are read from it.
            cluster: Target cluster, either ``"hf"`` or ``"swiss-ai"``.

        Raises:
            ValueError: If ``cluster`` is not a known cluster, or if the
                world size exceeds the per-node capacity without being a
                multiple of it.
        """
        # Per-cluster settings: (per-node capacity used to split the world
        # size across nodes, path of the template to render).
        # TODO: don't hardcode the template paths. They should be resolved
        # from $HOME/bench_cluster/template/ instead.
        cluster_settings = {
            "hf": (
                8,
                "/fsx/ferdinandmom/ferdinand-hf/bench_cluster/bench_cluster/template/base_bench.slurm",
            ),
            "swiss-ai": (
                4,
                "/users/fmom/project/bench_cluster/bench_cluster/template/base_bench_swiss.slurm",
            ),
        }

        # NOTE(review): yaml.FullLoader is not intended for untrusted input.
        # If configs can come from outside the project, prefer yaml.safe_load
        # -- confirm with the config format before switching.
        with open(job.config, 'r') as file:
            config = yaml.load(file, Loader=yaml.FullLoader)

        if cluster not in cluster_settings:
            raise ValueError("Invalid cluster")
        slots_per_node, template_path = cluster_settings[cluster]

        # Pick the right number of nodes and n_proc_per_node.
        parallelism = config['parallelism']
        world_size = parallelism['pp'] * parallelism['dp'] * parallelism['tp']
        # Explicit raise rather than `assert`: asserts are removed under
        # `python -O`, which would let an unsplittable world size through.
        if world_size > slots_per_node and world_size % slots_per_node != 0:
            raise ValueError(
                f"World size {world_size} exceeds the per-node capacity "
                f"({slots_per_node}) for cluster '{cluster}' and is not a "
                f"multiple of it"
            )
        nodes = max(1, world_size // slots_per_node)
        n_proc_per_node = min(8, world_size // nodes)
        # Internal sanity check; expected to always hold after the validation
        # above.
        assert nodes * n_proc_per_node == world_size, (
            f"nodes ({nodes}) * n_proc_per_node ({n_proc_per_node}) "
            f"!= world_size ({world_size})"
        )

        # Keep only the last three components of root_path.
        parent_dir = os.path.dirname(job.root_path)
        grandparent_dir = os.path.dirname(parent_dir)
        target_path_hf_hub = os.path.join(
            os.path.basename(grandparent_dir),
            os.path.basename(parent_dir),
            os.path.basename(job.root_path),
        )

        context_bench = {
            'nodes': nodes,
            'n_proc_per_node': n_proc_per_node,
            'root_path': job.root_path,
            'target_path_hf_hub': target_path_hf_hub,
            "config": job.config,
            "qos": job.qos,
        }

        with open(template_path, 'r') as file:
            base_bench_template = Template(file.read())

        # Render before opening the output file so that a rendering error
        # does not leave an empty or truncated bench.slurm behind.
        rendered_script = base_bench_template.render(context_bench)

        # Write the rendered script to a new file located at the job root_path
        output_file_path = os.path.join(job.root_path, "bench.slurm")
        with open(output_file_path, 'w') as file:
            file.write(rendered_script)

        print(f"Slurm script created at {output_file_path}")