def _tune_for_hf_tgi()

in src/sagemaker/serve/builder/tgi_builder.py [0:0]


    def _tune_for_hf_tgi(self, max_tuning_duration: int = 1800):
        """Benchmark TGI shard/dtype combinations and keep the most performant one.

        For every admissible ``NUM_SHARD`` value and every admissible ``DTYPE``,
        a model is created and deployed, then benchmarked serially and
        concurrently with ``self.schema_builder.sample_input``. The best
        combination (as judged by ``_more_performant``) is written back to
        ``self.env_vars`` and ``self.pysdk_model`` is rebuilt from it.

        If no combination produces a benchmark result, the Hugging Face model
        config is re-fetched and the default TGI configuration is applied
        instead.

        Args:
            max_tuning_duration (int): Tuning budget in seconds. The budget is
                only checked before starting each new ``NUM_SHARD`` value, so a
                running shard/dtype sweep is not interrupted. The same value is
                passed to ``deploy`` as ``model_data_download_timeout``.

        Returns:
            The model held in ``self.pysdk_model``: unchanged when the builder
            is not in ``Mode.LOCAL_CONTAINER``, otherwise rebuilt with either
            the best tuned configuration or the default configuration.
        """
        if self.mode != Mode.LOCAL_CONTAINER:
            logger.warning(
                "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
            )
            return self.pysdk_model

        admissible_num_shard = _get_admissible_tensor_parallel_degrees(self.hf_model_config)
        admissible_dtypes = _get_admissible_dtypes()

        benchmark_results = {}
        best_tuned_combination = None
        timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
        for num_shard in admissible_num_shard:
            if datetime.now() > timeout:
                logger.info("Max tuning duration reached. Tuning stopped.")
                break

            # Bound up front so the "tuning stopped" log below cannot hit an
            # unbound name when there are no admissible dtypes to iterate over.
            dtype = None
            dtype_passes = 0
            for dtype in admissible_dtypes:
                logger.info("Trying num shard: %s, dtype: %s...", num_shard, dtype)

                self.env_vars.update(
                    {
                        "SHARDED": "false" if num_shard == 1 else "true",
                        "NUM_SHARD": str(num_shard),
                        "DTYPE": dtype,
                    }
                )
                self.pysdk_model = self._create_tgi_model()

                try:
                    predictor = self.pysdk_model.deploy(
                        model_data_download_timeout=max_tuning_duration
                    )

                    avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
                        predictor, self.schema_builder.sample_input
                    )
                    throughput_per_second, standard_deviation = _concurrent_benchmark(
                        predictor, self.schema_builder.sample_input
                    )

                    tested_env = self.pysdk_model.env.copy()
                    logger.info(
                        "Average latency: %s, throughput/s: %s for configuration: %s",
                        avg_latency,
                        throughput_per_second,
                        tested_env,
                    )
                    # Results are keyed by average latency; two runs with the
                    # exact same latency overwrite each other.
                    benchmark_results[avg_latency] = [
                        tested_env,
                        p90,
                        avg_tokens_per_second,
                        throughput_per_second,
                        standard_deviation,
                    ]

                    tuned_configuration = [
                        avg_latency,
                        num_shard,
                        dtype,
                        p90,
                        avg_tokens_per_second,
                        throughput_per_second,
                        standard_deviation,
                    ]
                    # The first successful run is accepted unconditionally;
                    # later runs must beat the current best.
                    if not best_tuned_combination or _more_performant(
                        best_tuned_combination, tuned_configuration
                    ):
                        best_tuned_combination = tuned_configuration
                except LocalDeepPingException as e:
                    logger.warning(
                        "Deployment unsuccessful with num shard: %s. dtype: %s. "
                        "Failed to invoke the model server: %s",
                        num_shard,
                        dtype,
                        str(e),
                    )
                    break
                except LocalModelOutOfMemoryException as e:
                    logger.warning(
                        "Deployment unsuccessful with num shard: %s, dtype: %s. "
                        "Out of memory when loading the model: %s",
                        num_shard,
                        dtype,
                        str(e),
                    )
                    break
                except LocalModelInvocationException as e:
                    # The separator after the error text was missing, which
                    # glued the exception message onto "Please check ...".
                    logger.warning(
                        "Deployment unsuccessful with num shard: %s, dtype: %s. "
                        "Failed to invoke the model server: %s. "
                        "Please check that model server configurations are as expected "
                        "(Ex. serialization, deserialization, content_type, accept).",
                        num_shard,
                        dtype,
                        str(e),
                    )
                    break
                except LocalModelLoadException as e:
                    logger.warning(
                        "Deployment unsuccessful with num shard: %s, dtype: %s. "
                        "Failed to load the model: %s.",
                        num_shard,
                        dtype,
                        str(e),
                    )
                    break
                except SkipTuningComboException as e:
                    # Unlike the other failures, this one does not abandon the
                    # current num shard; the next dtype is still attempted.
                    logger.warning(
                        "Deployment with num shard: %s, dtype: %s "
                        "was expected to be successful. However failed with: %s. "
                        "Trying next combination.",
                        num_shard,
                        dtype,
                        str(e),
                    )
                except Exception:  # pylint: disable=W0703
                    logger.exception(
                        "Deployment unsuccessful with num shard: %s, dtype: %s "
                        "with uncovered exception",
                        num_shard,
                        dtype,
                    )
                    break
                dtype_passes += 1

            # No dtype completed (or was skipped) for this num shard: stop the
            # whole sweep rather than trying further num shard values.
            if dtype_passes == 0:
                logger.info(
                    "Lowest admissible num shard: %s and highest dtype: "
                    "%s combination has been attempted. Tuning stopped.",
                    num_shard,
                    dtype,
                )
                break

        if best_tuned_combination:
            (
                best_latency,
                best_num_shard,
                best_dtype,
                best_p90,
                best_tokens_per_second,
                best_throughput,
                best_standard_deviation,
            ) = best_tuned_combination

            self.env_vars.update(
                {
                    "SHARDED": "false" if best_num_shard == 1 else "true",
                    "NUM_SHARD": str(best_num_shard),
                    "DTYPE": best_dtype,
                }
            )

            self.pysdk_model = self._create_tgi_model()

            _pretty_print_results_tgi(benchmark_results)
            logger.info(
                "Model Configuration: %s was most performant with avg latency: %s, "
                "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
                "standard deviation of request %s",
                self.pysdk_model.env,
                best_latency,
                best_p90,
                best_tokens_per_second,
                best_throughput,
                best_standard_deviation,
            )
        else:
            # Nothing could be benchmarked: fall back to the default TGI
            # configuration derived from a freshly fetched HF model config.
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )
            default_tgi_configurations, _default_max_new_tokens = _get_default_tgi_configurations(
                self.model, self.hf_model_config, self.schema_builder
            )
            self.env_vars.update(default_tgi_configurations)
            self.schema_builder.sample_input["parameters"][
                "max_new_tokens"
            ] = _default_max_new_tokens
            self.pysdk_model = self._create_tgi_model()

            # NOTE(review): this user-facing hint is emitted at debug level and
            # may be invisible with default logging - confirm that is intended.
            logger.debug(
                "Failed to gather any tuning results. "
                "Please inspect the stack trace emitted from live logging for more details. "
                "Falling back to default serving.properties: %s",
                self.pysdk_model.env,
            )

        return self.pysdk_model