in tensorflow/inference/docker/build_artifacts/sagemaker/tfs_utils.py [0:0]
# Module-level dependencies assumed from the top of tfs_utils.py
import logging
import multiprocessing
import os

log = logging.getLogger(__name__)

def create_batching_config(batching_config_file):
    class _BatchingParameter:
        def __init__(self, key, env_var, value, defaulted_message):
            self.key = key
            self.env_var = env_var
            self.value = value
            self.defaulted_message = defaulted_message

    cpu_count = multiprocessing.cpu_count()
    batching_parameters = [
        _BatchingParameter(
            "max_batch_size",
            "SAGEMAKER_TFS_MAX_BATCH_SIZE",
            8,
            "max_batch_size defaulted to {}. Set {} to override default. "
            "Tuning this parameter may yield better performance.",
        ),
        _BatchingParameter(
            "batch_timeout_micros",
            "SAGEMAKER_TFS_BATCH_TIMEOUT_MICROS",
            1000,
            "batch_timeout_micros defaulted to {}. Set {} to override "
            "default. Tuning this parameter may yield better performance.",
        ),
        _BatchingParameter(
            "num_batch_threads",
            "SAGEMAKER_TFS_NUM_BATCH_THREADS",
            cpu_count,
            "num_batch_threads defaulted to {}, "
            "the number of CPUs. Set {} to override default.",
        ),
        _BatchingParameter(
            "max_enqueued_batches",
            "SAGEMAKER_TFS_MAX_ENQUEUED_BATCHES",
            # Batch transform limits the number of concurrent requests, which
            # limits the number of enqueued batches, so this can be set high
            # when running under Batch
            100000000 if "SAGEMAKER_BATCH" in os.environ else cpu_count,
            "max_enqueued_batches defaulted to {}. Set {} to override default. "
            "Tuning this parameter may be necessary if out-of-memory "
            "errors occur.",
        ),
    ]

    # Apply env var overrides; collect a warning for every parameter that
    # falls back to its default.
    warning_message = ""
    for batching_parameter in batching_parameters:
        if batching_parameter.env_var in os.environ:
            batching_parameter.value = os.environ[batching_parameter.env_var]
        else:
            warning_message += batching_parameter.defaulted_message.format(
                batching_parameter.value, batching_parameter.env_var
            )
            warning_message += "\n"
    if warning_message:
        log.warning(warning_message)

    # Emit one "key { value: v }" line per parameter -- the text-proto format
    # TensorFlow Serving expects for batching parameters.
    config = ""
    for batching_parameter in batching_parameters:
        config += "%s { value: %s }\n" % (batching_parameter.key, batching_parameter.value)
    log.info("batching config: \n%s\n", config)
    with open(batching_config_file, "w", encoding="utf8") as f:
        f.write(config)
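
# Minimal usage sketch (not part of tfs_utils.py): the env var values and the
# temp path below are hypothetical, chosen only to show the defaulting
# behavior and the file format the function writes. Elsewhere in the
# container, the resulting file is handed to TensorFlow Serving as its
# batching parameters file.
import tempfile

os.environ["SAGEMAKER_TFS_MAX_BATCH_SIZE"] = "16"
os.environ["SAGEMAKER_TFS_BATCH_TIMEOUT_MICROS"] = "5000"

config_path = os.path.join(tempfile.gettempdir(), "batching-config.cfg")
create_batching_config(config_path)

with open(config_path, encoding="utf8") as f:
    print(f.read())
# Expected output (the last two values track multiprocessing.cpu_count(),
# since SAGEMAKER_BATCH is unset here):
# max_batch_size { value: 16 }
# batch_timeout_micros { value: 5000 }
# num_batch_threads { value: <cpu_count> }
# max_enqueued_batches { value: <cpu_count> }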