# src/sagemaker/serve/builder/tgi_builder.py
def _tune_for_hf_tgi(self, max_tuning_duration: int = 1800):
"""Placeholder docstring"""
    if self.mode != Mode.LOCAL_CONTAINER:
        logger.warning(
            "Tuning is only supported in %s mode. Returning the original model.",
            Mode.LOCAL_CONTAINER,
        )
        return self.pysdk_model
    admissible_num_shard = _get_admissible_tensor_parallel_degrees(self.hf_model_config)
    admissible_dtypes = _get_admissible_dtypes()
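    # The search space: candidate tensor-parallel degrees are derived from the
    # Hugging Face model config, and dtypes come from a fixed admissible list.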

    benchmark_results = {}
    best_tuned_combination = None
    timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
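    # Tuning runs until the search space is exhausted or this deadline passes;
    # the deadline is checked once per num_shard iteration.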
    for num_shard in admissible_num_shard:
        if datetime.now() > timeout:
            logger.info("Max tuning duration reached. Tuning stopped.")
            break

        dtype_passes = 0
        for dtype in admissible_dtypes:
            logger.info("Trying num shard: %s, dtype: %s...", num_shard, dtype)
            if num_shard == 1:
                self.env_vars.update({"SHARDED": "false"})
            else:
                self.env_vars.update({"SHARDED": "true"})
            self.env_vars.update({"NUM_SHARD": str(num_shard), "DTYPE": dtype})
            self.pysdk_model = self._create_tgi_model()
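            # At this point the container env holds the combination under test,
            # e.g. {"SHARDED": "true", "NUM_SHARD": "2", "DTYPE": "bfloat16"}
            # (illustrative values).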

            try:
                predictor = self.pysdk_model.deploy(
                    model_data_download_timeout=max_tuning_duration
                )
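
                # Two benchmark passes (helpers assumed from this module's
                # tuning utilities): _serial_benchmark measures per-request
                # latency and token throughput; _concurrent_benchmark measures
                # aggregate requests/second under concurrent load.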
                avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
                    predictor, self.schema_builder.sample_input
                )
                throughput_per_second, standard_deviation = _concurrent_benchmark(
                    predictor, self.schema_builder.sample_input
                )

                tested_env = self.pysdk_model.env.copy()
                logger.info(
                    "Average latency: %s, throughput/s: %s for configuration: %s",
                    avg_latency,
                    throughput_per_second,
                    tested_env,
                )
                benchmark_results[avg_latency] = [
                    tested_env,
                    p90,
                    avg_tokens_per_second,
                    throughput_per_second,
                    standard_deviation,
                ]
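
                # best_tuned_combination layout:
                # [avg_latency, num_shard, dtype, p90, avg_tokens_per_second,
                #  throughput_per_second, standard_deviation]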
                if not best_tuned_combination:
                    best_tuned_combination = [
                        avg_latency,
                        num_shard,
                        dtype,
                        p90,
                        avg_tokens_per_second,
                        throughput_per_second,
                        standard_deviation,
                    ]
                else:
                    tuned_configuration = [
                        avg_latency,
                        num_shard,
                        dtype,
                        p90,
                        avg_tokens_per_second,
                        throughput_per_second,
                        standard_deviation,
                    ]
                    if _more_performant(best_tuned_combination, tuned_configuration):
                        best_tuned_combination = tuned_configuration
            except LocalDeepPingException as e:
                logger.warning(
                    "Deployment unsuccessful with num shard: %s, dtype: %s. "
                    "Failed to ping the model server: %s",
                    num_shard,
                    dtype,
                    str(e),
                )
                break
            except LocalModelOutOfMemoryException as e:
                logger.warning(
                    "Deployment unsuccessful with num shard: %s, dtype: %s. "
                    "Out of memory when loading the model: %s",
                    num_shard,
                    dtype,
                    str(e),
                )
                break
            except LocalModelInvocationException as e:
                logger.warning(
                    "Deployment unsuccessful with num shard: %s, dtype: %s. "
                    "Failed to invoke the model server: %s. "
                    "Please check that model server configurations are as expected "
                    "(Ex. serialization, deserialization, content_type, accept).",
                    num_shard,
                    dtype,
                    str(e),
                )
                break
            except LocalModelLoadException as e:
                logger.warning(
                    "Deployment unsuccessful with num shard: %s, dtype: %s. "
                    "Failed to load the model: %s.",
                    num_shard,
                    dtype,
                    str(e),
                )
                break
            except SkipTuningComboException as e:
                logger.warning(
                    "Deployment with num shard: %s, dtype: %s "
                    "was expected to be successful. However, it failed with: %s. "
                    "Trying the next combination.",
                    num_shard,
                    dtype,
                    str(e),
                )
            except Exception:  # pylint: disable=W0703
                logger.exception(
                    "Deployment unsuccessful with num shard: %s, dtype: %s "
                    "due to an unhandled exception",
                    num_shard,
                    dtype,
                )
                break
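
            # Reaching here (without a break) means this combination deployed
            # and benchmarked successfully.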
            dtype_passes += 1

        if dtype_passes == 0:
            logger.info(
                "Lowest admissible num shard: %s and highest dtype: "
                "%s combination has been attempted. Tuning stopped.",
                num_shard,
                dtype,
            )
            break
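
    # Pin the winning combination in the env vars and rebuild the model so the
    # returned object reflects it.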
    if best_tuned_combination:
        if best_tuned_combination[1] == 1:
            self.env_vars.update({"SHARDED": "false"})
        else:
            self.env_vars.update({"SHARDED": "true"})
        self.env_vars.update(
            {"NUM_SHARD": str(best_tuned_combination[1]), "DTYPE": best_tuned_combination[2]}
        )
        self.pysdk_model = self._create_tgi_model()

        _pretty_print_results_tgi(benchmark_results)
        logger.info(
            "Model Configuration: %s was most performant with avg latency: %s, "
            "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
            "standard deviation of requests: %s",
            self.pysdk_model.env,
            best_tuned_combination[0],
            best_tuned_combination[3],
            best_tuned_combination[4],
            best_tuned_combination[5],
            best_tuned_combination[6],
        )
    else:
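        # Nothing deployed successfully: fall back to default TGI
        # configurations derived from the Hugging Face model config.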
        self.hf_model_config = _get_model_config_properties_from_hf(
            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
        )
        default_tgi_configurations, _default_max_new_tokens = _get_default_tgi_configurations(
            self.model, self.hf_model_config, self.schema_builder
        )
        self.env_vars.update(default_tgi_configurations)
        self.schema_builder.sample_input["parameters"][
            "max_new_tokens"
        ] = _default_max_new_tokens
        self.pysdk_model = self._create_tgi_model()

        logger.debug(
            "Failed to gather any tuning results. "
            "Please inspect the stack trace emitted from live logging for more details. "
            "Falling back to default TGI configurations: %s",
            self.pysdk_model.env,
        )

    return self.pysdk_model
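
# Usage sketch (illustrative, not part of this module). Tuning is exposed on
# the built model when running in LOCAL_CONTAINER mode; the model id, sample
# payloads, and time budget below are assumptions for the example:
#
#     from sagemaker.serve.builder.model_builder import ModelBuilder
#     from sagemaker.serve.builder.schema_builder import SchemaBuilder
#     from sagemaker.serve.mode.function_pointers import Mode
#
#     builder = ModelBuilder(
#         model="HuggingFaceH4/zephyr-7b-beta",  # hypothetical model id
#         schema_builder=SchemaBuilder(sample_input, sample_output),
#         mode=Mode.LOCAL_CONTAINER,
#     )
#     model = builder.build()
#     tuned_model = model.tune(max_tuning_duration=900)  # assumed 15-minute budget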