in src/sagemaker/serve/builder/djl_builder.py [0:0]
def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):
    """Benchmark admissible (tensor_parallel_degree, dtype) combinations and keep the best.

    Only supported in LOCAL_CONTAINER mode; in any other mode a warning is logged and
    the current ``self.pysdk_model`` is returned unchanged. For each combination the
    model is re-created with the candidate env vars, deployed locally, and benchmarked
    serially (latency) and concurrently (throughput). The most performant combination
    (as judged by ``_more_performant``) is written back into ``self.env_vars`` and a
    final model is created from it. If no combination deploys successfully, the method
    falls back to the default DJL configurations.

    Args:
        max_tuning_duration (int): Overall tuning time budget in seconds; also reused
            as the per-deploy ``model_data_download_timeout``.

    Returns:
        The tuned (or original / default-configured) ``self.pysdk_model``.
    """
    # Tuning deploys real containers, so it is gated to local mode only.
    if self.mode != Mode.LOCAL_CONTAINER:
        logger.warning(
            "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
        )
        return self.pysdk_model

    # Candidate search space is derived from the HF model config (tensor parallel
    # degrees the model admits) plus the fixed set of admissible dtypes.
    admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
        self.hf_model_config
    )
    admissible_dtypes = _get_admissible_dtypes()

    # Keyed by avg_latency -> [env, p90, avg_tokens_per_second, throughput/s, std_dev].
    # NOTE(review): two configurations with an identical avg_latency would overwrite
    # each other here — presumably acceptable for a pretty-printed summary; confirm.
    benchmark_results = {}
    # Layout once populated:
    # [avg_latency, tensor_parallel_degree, dtype, p90, avg_tokens/s, throughput/s, std_dev]
    best_tuned_combination = None
    # Wall-clock budget; checked only at the top of each outer iteration, so an
    # in-flight combination is allowed to finish past the deadline.
    timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
    for tensor_parallel_degree in admissible_tensor_parallel_degrees:
        if datetime.now() > timeout:
            logger.info("Max tuning duration reached. Tuning stopped.")
            break

        # Counts dtypes that benchmarked successfully for this tensor parallel degree;
        # any failure breaks the dtype loop before the increment below.
        dtype_passes = 0
        for dtype in admissible_dtypes:
            logger.info(
                "Trying tensor parallel degree: %s, dtype: %s...", tensor_parallel_degree, dtype
            )

            # Candidate configuration is injected via env vars, then the model is
            # rebuilt so the new settings take effect.
            self.env_vars.update(
                {"TENSOR_PARALLEL_DEGREE": str(tensor_parallel_degree), "OPTION_DTYPE": dtype}
            )
            self.pysdk_model = self._create_djl_model()

            try:
                # NOTE(review): the predictor/local container is not explicitly torn
                # down between trials here — presumably handled by deploy() replacing
                # the previous local container; confirm against the local-mode impl.
                predictor = self.pysdk_model.deploy(
                    model_data_download_timeout=max_tuning_duration
                )

                # Serial pass yields latency stats; concurrent pass yields throughput.
                avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
                    predictor, self.schema_builder.sample_input
                )
                throughput_per_second, standard_deviation = _concurrent_benchmark(
                    predictor, self.schema_builder.sample_input
                )

                # Snapshot the env now: self.pysdk_model / self.env_vars are mutated
                # on every iteration.
                tested_env = self.pysdk_model.env.copy()
                logger.info(
                    "Average latency: %s, throughput/s: %s for configuration: %s",
                    avg_latency,
                    throughput_per_second,
                    tested_env,
                )
                benchmark_results[avg_latency] = [
                    tested_env,
                    p90,
                    avg_tokens_per_second,
                    throughput_per_second,
                    standard_deviation,
                ]

                if not best_tuned_combination:
                    # First successful combination seeds the best-so-far record.
                    best_tuned_combination = [
                        avg_latency,
                        tensor_parallel_degree,
                        dtype,
                        p90,
                        avg_tokens_per_second,
                        throughput_per_second,
                        standard_deviation,
                    ]
                else:
                    tuned_configuration = [
                        avg_latency,
                        tensor_parallel_degree,
                        dtype,
                        p90,
                        avg_tokens_per_second,
                        throughput_per_second,
                        standard_deviation,
                    ]
                    if _more_performant(best_tuned_combination, tuned_configuration):
                        best_tuned_combination = tuned_configuration
            # Each known failure mode logs its cause and abandons the remaining dtypes
            # for this tensor parallel degree (break leaves dtype_passes unchanged).
            except LocalDeepPingException as e:
                logger.warning(
                    "Deployment unsuccessful with tensor parallel degree: %s. dtype: %s. "
                    "Failed to invoke the model server: %s",
                    tensor_parallel_degree,
                    dtype,
                    str(e),
                )
                break
            except LocalModelOutOfMemoryException as e:
                logger.warning(
                    "Deployment unsuccessful with tensor parallel degree: %s, dtype: %s. "
                    "Out of memory when loading the model: %s",
                    tensor_parallel_degree,
                    dtype,
                    str(e),
                )
                break
            except LocalModelInvocationException as e:
                logger.warning(
                    "Deployment unsuccessful with tensor parallel degree: %s, dtype: %s. "
                    "Failed to invoke the model server: %s"
                    "Please check that model server configurations are as expected "
                    "(Ex. serialization, deserialization, content_type, accept).",
                    tensor_parallel_degree,
                    dtype,
                    str(e),
                )
                break
            except LocalModelLoadException as e:
                logger.warning(
                    "Deployment unsuccessful with tensor parallel degree: %s, dtype: %s. "
                    "Failed to load the model: %s.",
                    tensor_parallel_degree,
                    dtype,
                    str(e),
                )
                break
            except Exception:  # pylint: disable=W0703
                # Catch-all so one bad combination cannot crash the whole tuning run;
                # logger.exception preserves the traceback.
                logger.exception(
                    "Deployment unsuccessful with tensor parallel degree: %s, dtype: %s "
                    "with uncovered exception",
                    tensor_parallel_degree,
                    dtype,
                )
                break
            dtype_passes += 1
        if dtype_passes == 0:
            # Nothing worked even at this (lowest remaining) tensor parallel degree;
            # higher degrees are not expected to fare better, so stop tuning entirely.
            logger.info(
                "Lowest admissible tensor parallel degree: %s and highest dtype: "
                "%s combination has been attempted. Tuning stopped.",
                tensor_parallel_degree,
                dtype,
            )
            break
    if best_tuned_combination:
        # Persist the winning configuration and rebuild the model from it.
        self._default_tensor_parallel_degree = best_tuned_combination[1]
        self._default_data_type = best_tuned_combination[2]
        self.env_vars.update(
            {
                "TENSOR_PARALLEL_DEGREE": str(self._default_tensor_parallel_degree),
                "OPTION_DTYPE": self._default_data_type,
            }
        )
        self.pysdk_model = self._create_djl_model()

        _pretty_print_results(benchmark_results)
        logger.info(
            "Model Configuration: %s was most performant with avg latency: %s, "
            "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
            "standard deviation of request %s",
            self.pysdk_model.env,
            best_tuned_combination[0],
            best_tuned_combination[3],
            best_tuned_combination[4],
            best_tuned_combination[5],
            best_tuned_combination[6],
        )
    else:
        # No combination succeeded: restore the default DJL serving configuration
        # (env vars + sample-input max_new_tokens) and rebuild the model from it.
        default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
            self.model, self.hf_model_config, self.schema_builder
        )
        self.env_vars.update(default_djl_configurations)
        self.schema_builder.sample_input["parameters"][
            "max_new_tokens"
        ] = _default_max_new_tokens
        self.pysdk_model = self._create_djl_model()
        logger.debug(
            "Failed to gather any tuning results. "
            "Please inspect the stack trace emitted from live logging for more details. "
            "Falling back to default serving.properties: %s",
            self.pysdk_model.env,
        )

    return self.pysdk_model