def create_batching_config()

in tensorflow/inference/docker/build_artifacts/sagemaker_neuron/tfs_utils.py


import logging
import multiprocessing
import os

# Module-level dependencies this function relies on.
log = logging.getLogger(__name__)


def create_batching_config(batching_config_file):
    """Write a TensorFlow Serving batching parameters file to batching_config_file.

    Values come from the SAGEMAKER_TFS_* environment variables when set and
    fall back to the defaults below; any defaulted parameter is logged.
    """
    class _BatchingParameter:
        def __init__(self, key, env_var, value, defaulted_message):
            self.key = key
            self.env_var = env_var
            self.value = value
            self.defaulted_message = defaulted_message

    cpu_count = multiprocessing.cpu_count()
    batching_parameters = [
        _BatchingParameter(
            "max_batch_size",
            "SAGEMAKER_TFS_MAX_BATCH_SIZE",
            8,
            "max_batch_size defaulted to {}. Set {} to override default. "
            "Tuning this parameter may yield better performance.",
        ),
        _BatchingParameter(
            "batch_timeout_micros",
            "SAGEMAKER_TFS_BATCH_TIMEOUT_MICROS",
            1000,
            "batch_timeout_micros defaulted to {}. Set {} to override "
            "default. Tuning this parameter may yield better performance.",
        ),
        _BatchingParameter(
            "num_batch_threads",
            "SAGEMAKER_TFS_NUM_BATCH_THREADS",
            cpu_count,
            "num_batch_threads defaulted to {}," "the number of CPUs. Set {} to override default.",
        ),
        _BatchingParameter(
            "max_enqueued_batches",
            "SAGEMAKER_TFS_MAX_ENQUEUED_BATCHES",
            # Batch transform limits the number of concurrent requests, which
            # limits the number of enqueued batches, so this can be set high
            # for Batch
            100000000 if "SAGEMAKER_BATCH" in os.environ else cpu_count,
            "max_enqueued_batches defaulted to {}. Set {} to override default. "
            "Tuning this parameter may be necessary to tune out-of-memory "
            "errors occur.",
        ),
    ]

    warning_message = ""
    for batching_parameter in batching_parameters:
        if batching_parameter.env_var in os.environ:
            batching_parameter.value = os.environ[batching_parameter.env_var]
        else:
            warning_message += batching_parameter.defaulted_message.format(
                batching_parameter.value, batching_parameter.env_var
            )
            warning_message += "\n"
    if warning_message:
        log.warning(warning_message)

    config = ""
    for batching_parameter in batching_parameters:
        config += "%s { value: %s }\n" % (batching_parameter.key, batching_parameter.value)

    log.info("batching config: \n%s\n", config)
    with open(batching_config_file, "w", encoding="utf8") as f:
        f.write(config)
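
For context, a minimal sketch of how this helper might be exercised. The output path and the override value are illustrative only; the SAGEMAKER_TFS_* variable names and the generated "key { value: ... }" lines come from the function above.

import os

# Illustrative override of one default via its environment variable.
os.environ["SAGEMAKER_TFS_MAX_BATCH_SIZE"] = "16"

# Hypothetical output path; the real path is chosen by the serving entrypoint.
create_batching_config("/tmp/batching_config.txt")

# The file now holds TensorFlow Serving batching parameters in text proto
# form, one per line, for example:
#   max_batch_size { value: 16 }
#   batch_timeout_micros { value: 1000 }
#   num_batch_threads { value: <cpu_count> }
#   max_enqueued_batches { value: <cpu_count> }
# Such a file is typically passed to tensorflow_model_server with
# --enable_batching=true --batching_parameters_file=/tmp/batching_config.txt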