def create_batching_config()

in tensorflow/inference/docker/build_artifacts/sagemaker_neuron/tfs_utils.py [0:0]


def create_batching_config(batching_config_file):
    """Write a TensorFlow Serving batching-parameters file derived from env vars.

    Four batching parameters are produced: ``max_batch_size`` (default 8),
    ``batch_timeout_micros`` (default 1000), ``num_batch_threads`` (default:
    ``multiprocessing.cpu_count()``) and ``max_enqueued_batches`` (default
    100000000 when ``SAGEMAKER_BATCH`` is present in the environment,
    otherwise the CPU count). Each one is overridden by its
    ``SAGEMAKER_TFS_*`` environment variable when that variable is set to a
    non-blank value. For every parameter left at its default, a single
    combined warning tells the user which variable to set.

    The resulting config (one ``<key> { value: <value> }`` line per
    parameter) is logged at INFO level and written to
    ``batching_config_file``, replacing any existing content.

    Args:
        batching_config_file: Path of the batching config file to write.
    """

    class _BatchingParameter:
        """One batching setting: its config key, override variable and value."""

        def __init__(self, key, env_var, value, defaulted_message):
            # Field name emitted into the config file.
            self.key = key
            # Environment variable that overrides the default.
            self.env_var = env_var
            # Default value; replaced by the env var's value when it is set.
            self.value = value
            # Warning template, formatted with (default value, env var name).
            self.defaulted_message = defaulted_message

    cpu_count = multiprocessing.cpu_count()
    batching_parameters = [
        _BatchingParameter("max_batch_size", "SAGEMAKER_TFS_MAX_BATCH_SIZE", 8,
                           "max_batch_size defaulted to {}. Set {} to override default. "
                           "Tuning this parameter may yield better performance."),
        _BatchingParameter("batch_timeout_micros", "SAGEMAKER_TFS_BATCH_TIMEOUT_MICROS", 1000,
                           "batch_timeout_micros defaulted to {}. Set {} to override "
                           "default. Tuning this parameter may yield better performance."),
        _BatchingParameter("num_batch_threads", "SAGEMAKER_TFS_NUM_BATCH_THREADS",
                           cpu_count, "num_batch_threads defaulted to {}, "
                                      "the number of CPUs. Set {} to override default."),
        _BatchingParameter("max_enqueued_batches", "SAGEMAKER_TFS_MAX_ENQUEUED_BATCHES",
                           # Batch limits number of concurrent requests, which limits number
                           # of enqueued batches, so this can be set high for Batch
                           100000000 if "SAGEMAKER_BATCH" in os.environ else cpu_count,
                           "max_enqueued_batches defaulted to {}. Set {} to override default. "
                           "Tuning this parameter may be necessary if out-of-memory "
                           "errors occur."),
    ]

    defaulted_messages = []
    for parameter in batching_parameters:
        # A blank override would render "value:  }", which is not a valid
        # integer field, so treat it the same as an unset variable.
        override = os.environ.get(parameter.env_var, "").strip()
        if override:
            parameter.value = override
        else:
            defaulted_messages.append(
                parameter.defaulted_message.format(parameter.value, parameter.env_var))
    if defaulted_messages:
        # One warning listing every defaulted parameter, each on its own line.
        log.warning("".join(message + "\n" for message in defaulted_messages))

    config = "".join(
        "%s { value: %s }\n" % (parameter.key, parameter.value)
        for parameter in batching_parameters
    )

    log.info("batching config: \n%s\n", config)
    with open(batching_config_file, "w", encoding="utf-8") as f:
        f.write(config)