in community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts/tpu.py [0:0]
def create_node(self, nodename):
    """Create a TPU node and register each of its worker VMs.

    Args:
        nodename: A single node name, or a list of node names (one per
            worker VM). When ``self.vmcount > 1`` a list of exactly
            ``self.vmcount`` names is required.

    Returns:
        True when the create request passed ``__check_resp`` and the
        workers were handed to ``self._register_node``; False when
        ``nodename`` does not match ``self.vmcount`` or the response check
        fails.
    """
    is_multi = isinstance(nodename, list)

    # A multi-VM TPU needs exactly one name per worker VM.
    if self.vmcount > 1:
        if not is_multi:
            log.error(
                f"Tried to create a {self.vmcount} node TPU on nodeset {self._nodeset.nodeset_name} but only received one nodename {nodename}"
            )
            return False
        if len(nodename) != self.vmcount:
            log.error(
                f"Expected to receive a list of {self.vmcount} nodenames for TPU node creation in nodeset {self._nodeset.nodeset_name}, but received this list {nodename}"
            )
            return False

    node = tpu.Node()
    node.accelerator_config = self.ac
    node.runtime_version = f"tpu-vm-tf-{self.tf_version}"

    # Fallback used only when startup.sh is missing. Previously the open()
    # below raised before this default could ever be used, and the redirect
    # sat inside the echo quotes so nothing would have been written to the
    # error log anyway.
    startup_script = (
        "#!/bin/bash\n"
        'echo "startup script not found" > /var/log/startup_error.log\n'
    )
    script_path = (
        Path(self.lkp.cfg.slurm_scripts_dir or util.dirs.scripts) / "startup.sh"
    )
    try:
        with open(script_path, "r") as script:
            startup_script = script.read()
    except FileNotFoundError:
        log.error(
            f"Startup script {script_path} not found, using fallback startup script"
        )

    # The first name doubles as the TPU node id; every name is mapped to a
    # worker index in the "slurm_names" metadata entry.
    names = nodename if is_multi else [nodename]
    node_id = names[0]
    slurm_names = [f"WORKER_{wid}:{name}" for wid, name in enumerate(names)]

    node.metadata = {
        "slurm_docker_image": self.nodeset.docker_image,
        "startup-script": startup_script,
        "slurm_instance_role": "compute",
        "slurm_cluster_name": self.lkp.cfg.slurm_cluster_name,
        "slurm_bucket_path": self.lkp.cfg.bucket_path,
        "slurm_names": ";".join(slurm_names),
        "universe_domain": util.universe_domain(),
    }
    node.tags = [self.lkp.cfg.slurm_cluster_name]
    if self.nodeset.service_account:
        node.service_account.email = self.nodeset.service_account.email
        node.service_account.scope = self.nodeset.service_account.scopes
    node.scheduling_config.preemptible = self.preemptible
    node.scheduling_config.reserved = self.reserved
    node.network_config.subnetwork = self.nodeset.subnetwork
    node.network_config.enable_external_ips = self.enable_public_ip
    if self.data_disks:
        node.data_disks = self.data_disks

    request = tpu.CreateNodeRequest(parent=self._parent, node=node, node_id=node_id)
    # .result() waits for the create operation to finish.
    resp = self._client.create_node(request=request).result()
    if not self.__check_resp(resp, "create"):
        return False

    if is_multi:
        # Pair each requested name with the endpoint at the same position.
        for worker_name, net_endpoint in zip(nodename, resp.network_endpoints):
            self._register_node(worker_name, net_endpoint.ip_address)
    else:
        ip_add = resp.network_endpoints[0].ip_address
        self._register_node(nodename, ip_add)
    return True