in src/smspark/bootstrapper.py [0:0]
def set_yarn_spark_resource_config(self) -> None:
    processing_job_config = self.load_processing_job_config()
    instance_type_info = self.load_instance_type_info()

    if processing_job_config and instance_type_info:
        # Processing job: read instance type and count from the cluster
        # config, strip the "ml." prefix so the type matches the EC2 name,
        # and look up total memory and vCPUs for that type.
        instance_type = processing_job_config["ProcessingResources"]["ClusterConfig"]["InstanceType"].replace(
            "ml.", ""
        )
        instance_count = processing_job_config["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
        instance_type_info = instance_type_info[instance_type]
        instance_mem_mb = instance_type_info["MemoryInfo"]["SizeInMiB"]
        instance_cores = instance_type_info["VCpuInfo"]["DefaultVCpus"]
        logging.info(
            f"Detected instance type for processing: {instance_type} with "
            f"total memory: {instance_mem_mb}M and total cores: {instance_cores}"
        )
    elif all(key in self.resource_config for key in ["current_instance_type", "hosts"]) and instance_type_info:
        # TODO: Support training heterogeneous cluster with instance groups
        # Training job: take the instance type from the resource config,
        # count one instance per host, then do the same type lookup.
        instance_type = self.resource_config["current_instance_type"].replace("ml.", "")
        instance_count = len(self.resource_config["hosts"])
        instance_type_info = instance_type_info[instance_type]
        instance_mem_mb = instance_type_info["MemoryInfo"]["SizeInMiB"]
        instance_cores = instance_type_info["VCpuInfo"]["DefaultVCpus"]
        logging.info(
            f"Detected instance type for training: {instance_type} with "
            f"total memory: {instance_mem_mb}M and total cores: {instance_cores}"
        )
    else:
        # Fallback: no job or resource config was detected, so assume a
        # single node and measure this machine directly via psutil
        # (a standalone check of these calls follows the function).
        instance_count = 1
        instance_mem_mb = int(psutil.virtual_memory().total / (1024 * 1024))
        instance_cores = psutil.cpu_count(logical=True)
        logging.warning(
            f"Failed to detect instance type config. "
            f"Found total memory: {instance_mem_mb}M and total cores: {instance_cores}"
        )
    # Translate the detected resources into YARN and Spark defaults and
    # write both config files, logging each one as it is written.
    yarn_config, spark_config = self.get_yarn_spark_resource_config(instance_count, instance_mem_mb, instance_cores)

    logging.info(f"Writing default config to {yarn_config.path}")
    yarn_config_string = yarn_config.write_config()
    logging.info(f"Configuration at {yarn_config.path} is: \n{yarn_config_string}")

    logging.info(f"Writing default config to {spark_config.path}")
    spark_config_string = spark_config.write_config()
    logging.info(f"Configuration at {spark_config.path} is: \n{spark_config_string}")