def _tune_for_js()

in src/sagemaker/serve/builder/jumpstart_builder.py [0:0]


    def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
        """Tune for Jumpstart Models in Local Mode.

        Deploys the model locally once per admissible tensor parallel degree, benchmarks
        each successful deployment, and sets the shard-count environment variable on
        ``self.pysdk_model`` to the most performant degree found.

        Args:
            sharded_supported (bool): Indicates whether sharding is supported by this ``Model``
            max_tuning_duration (int): The maximum timeout to deploy this ``Model`` locally.
                It is also used as the overall wall-clock budget for trying configurations.
                Default: ``1800``

        Returns:
            Tuned Model. If no configuration could be benchmarked, the model's environment
            variables are restored to exactly what they were before tuning started.
        """
        if self.mode == Mode.SAGEMAKER_ENDPOINT:
            logger.warning(
                "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
            )
            return self.pysdk_model

        # Prefer OPTION_TENSOR_PARALLEL_DEGREE when the model env already defines it;
        # otherwise the shard count is driven through SM_NUM_GPUS.
        num_shard_env_var_name = "SM_NUM_GPUS"
        if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env:
            num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"

        # Snapshot of the env so it can be fully restored if every trial fails.
        initial_env_vars = copy.deepcopy(self.pysdk_model.env)
        admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
            self.js_model_config
        )

        if len(admissible_tensor_parallel_degrees) > 1 and not sharded_supported:
            admissible_tensor_parallel_degrees = [1]
            logger.warning(
                "Sharding across multiple GPUs is not supported for this model. "
                "Model can only be sharded across [1] GPU"
            )

        # Maps avg latency -> [env, p90, avg tokens/s, throughput/s, std dev].
        benchmark_results = {}
        # Layout: [avg_latency, tp_degree, None, p90, avg_tokens/s, throughput/s, std_dev]
        best_tuned_combination = None
        timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
        for tensor_parallel_degree in admissible_tensor_parallel_degrees:
            if datetime.now() > timeout:
                logger.info("Max tuning duration reached. Tuning stopped.")
                break

            self.pysdk_model.env.update({num_shard_env_var_name: str(tensor_parallel_degree)})
            try:
                logger.info("Trying tensor parallel degree: %s", tensor_parallel_degree)

                predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)

                avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
                    predictor, self.schema_builder.sample_input
                )
                throughput_per_second, standard_deviation = _concurrent_benchmark(
                    predictor, self.schema_builder.sample_input
                )

                tested_env = copy.deepcopy(self.pysdk_model.env)
                logger.info(
                    "Average latency: %s, throughput/s: %s for configuration: %s",
                    avg_latency,
                    throughput_per_second,
                    tested_env,
                )
                benchmark_results[avg_latency] = [
                    tested_env,
                    p90,
                    avg_tokens_per_second,
                    throughput_per_second,
                    standard_deviation,
                ]

                tuned_configuration = [
                    avg_latency,
                    tensor_parallel_degree,
                    None,
                    p90,
                    avg_tokens_per_second,
                    throughput_per_second,
                    standard_deviation,
                ]
                if best_tuned_combination is None or _more_performant(
                    best_tuned_combination, tuned_configuration
                ):
                    best_tuned_combination = tuned_configuration
            except LocalDeepPingException as e:
                logger.warning(
                    "Deployment unsuccessful with %s: %s. " "Failed to invoke the model server: %s",
                    num_shard_env_var_name,
                    tensor_parallel_degree,
                    str(e),
                )
            except LocalModelOutOfMemoryException as e:
                logger.warning(
                    "Deployment unsuccessful with %s: %s. "
                    "Out of memory when loading the model: %s",
                    num_shard_env_var_name,
                    tensor_parallel_degree,
                    str(e),
                )
            except LocalModelInvocationException as e:
                logger.warning(
                    "Deployment unsuccessful with %s: %s. "
                    "Failed to invoke the model server: %s. "
                    "Please check that model server configurations are as expected "
                    "(Ex. serialization, deserialization, content_type, accept).",
                    num_shard_env_var_name,
                    tensor_parallel_degree,
                    str(e),
                )
            except LocalModelLoadException as e:
                logger.warning(
                    "Deployment unsuccessful with %s: %s. " "Failed to load the model: %s.",
                    num_shard_env_var_name,
                    tensor_parallel_degree,
                    str(e),
                )
            except SkipTuningComboException as e:
                logger.warning(
                    "Deployment with %s: %s "
                    "was expected to be successful. However failed with: %s. "
                    "Trying next combination.",
                    num_shard_env_var_name,
                    tensor_parallel_degree,
                    str(e),
                )
            except Exception:  # pylint: disable=W0703
                logger.exception(
                    "Deployment unsuccessful with %s: %s " "with uncovered exception",
                    num_shard_env_var_name,
                    tensor_parallel_degree,
                )

        if best_tuned_combination is not None:
            self.pysdk_model.env.update({num_shard_env_var_name: str(best_tuned_combination[1])})

            _pretty_print_results_jumpstart(benchmark_results, [num_shard_env_var_name])
            logger.info(
                "Model Configuration: %s was most performant with avg latency: %s, "
                "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
                "standard deviation of request %s",
                self.pysdk_model.env,
                best_tuned_combination[0],
                best_tuned_combination[3],
                best_tuned_combination[4],
                best_tuned_combination[5],
                best_tuned_combination[6],
            )
        else:
            # ``update`` alone would keep the shard variable set by the last trial when it
            # was not part of the original env; clear first so the defaults are truly restored
            # (done in place to keep the same dict object the loop mutated).
            self.pysdk_model.env.clear()
            self.pysdk_model.env.update(initial_env_vars)
            logger.warning(
                "Failed to gather any tuning results. "
                "Please inspect the stack trace emitted from live logging for more details. "
                "Falling back to default model configurations: %s",
                self.pysdk_model.env,
            )

        return self.pysdk_model