def set_yarn_spark_resource_config()

in src/smspark/bootstrapper.py


    def set_yarn_spark_resource_config(self) -> None:
        processing_job_config = self.load_processing_job_config()
        instance_type_info = self.load_instance_type_info()

        if processing_job_config and instance_type_info:
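            # Processing job: read instance type and count from the job's ClusterConfig.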
            instance_type = processing_job_config["ProcessingResources"]["ClusterConfig"]["InstanceType"].replace(
                "ml.", ""
            )
            instance_count = processing_job_config["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
            instance_type_info = instance_type_info[instance_type]
            instance_mem_mb = instance_type_info["MemoryInfo"]["SizeInMiB"]
            instance_cores = instance_type_info["VCpuInfo"]["DefaultVCpus"]
            logging.info(
                f"Detected instance type for processing: {instance_type} with "
                f"total memory: {instance_mem_mb}M and total cores: {instance_cores}"
            )
        elif all(key in self.resource_config for key in ["current_instance_type", "hosts"]) and instance_type_info:
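            # Training job: derive instance type and host count from the training resource config.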
            # TODO: Support training heterogeneous cluster with instance groups
            instance_type = self.resource_config["current_instance_type"].replace("ml.", "")
            instance_count = len(self.resource_config["hosts"])
            instance_type_info = instance_type_info[instance_type]
            instance_mem_mb = instance_type_info["MemoryInfo"]["SizeInMiB"]
            instance_cores = instance_type_info["VCpuInfo"]["DefaultVCpus"]
            logging.info(
                f"Detected instance type for training: {instance_type} with "
                f"total memory: {instance_mem_mb}M and total cores: {instance_cores}"
            )
        else:
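            # No job config detected: fall back to probing the local host with psutil.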
            instance_count = 1
            instance_mem_mb = int(psutil.virtual_memory().total / (1024 * 1024))
            instance_cores = psutil.cpu_count(logical=True)
            logging.warning(
                f"Failed to detect instance type config. "
                f"Found total memory: {instance_mem_mb}M and total cores: {instance_cores}"
            )

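        # Derive YARN and Spark settings from the detected cluster resources and write them to disk.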
        yarn_config, spark_config = self.get_yarn_spark_resource_config(instance_count, instance_mem_mb, instance_cores)

        logging.info("Writing default config to {}".format(yarn_config.path))
        yarn_config_string = yarn_config.write_config()
        logging.info("Configuration at {} is: \n{}".format(yarn_config.path, yarn_config_string))

        logging.info("Writing default config to {}".format(spark_config.path))
        spark_config_string = spark_config.write_config()
        logging.info("Configuration at {} is: \n{}".format(spark_config.path, spark_config_string))