in src/sagemaker/jumpstart/types.py [0:0]
def from_json(self, json_obj: Dict[str, Any]) -> None:
    """Populate this object's fields from a JSON-like dictionary.

    When ``self._is_hub_content`` is true, ``json_obj`` is first rewritten with
    ``walk_and_apply_json(json_obj, camel_to_snake)`` (presumably camelCase ->
    snake_case keys), and the hosting/training image location is read as a
    plain ``*_ecr_uri`` string instead of a ``JumpStartECRSpecs`` object
    (``*_ecr_specs``). Whichever of the two attributes is not populated is
    recorded in ``self._non_serializable_slots``.

    Training-only fields are populated only when ``training_supported`` is
    truthy in ``json_obj``. If this method is called more than once on the same
    object, the fields it sets are overwritten and no duplicate entries are
    added to ``self._non_serializable_slots``.

    Args:
        json_obj (Dict[str, Any]): Dictionary representation of spec.

    Raises:
        KeyError: if ``training_supported`` is truthy but
            ``training_artifact_key`` or ``training_script_key`` is missing.
    """
    if self._is_hub_content:
        # Normalize hub keys so the snake_case lookups below serve both sources.
        json_obj = walk_and_apply_json(json_obj, camel_to_snake)

    def _mark_non_serializable(slot_name: str) -> None:
        """Record ``slot_name`` as non-serializable, at most once."""
        # NOTE(review): if ``_non_serializable_slots`` resolves to a class-level
        # list rather than a per-instance one, this mutates state shared by all
        # instances -- confirm against the class definition.
        if slot_name not in self._non_serializable_slots:
            self._non_serializable_slots.append(slot_name)

    self.model_id: str = json_obj.get("model_id")
    self.url: str = json_obj.get("url")
    self.version: str = json_obj.get("version")
    self.min_sdk_version: str = json_obj.get("min_sdk_version")
    self.incremental_training_supported: bool = bool(
        json_obj.get("incremental_training_supported", False)
    )

    # Hub content exposes the hosting image as ``hosting_ecr_uri``; other specs
    # use a ``hosting_ecr_specs`` object. The unused attribute is never set and
    # is therefore marked non-serializable.
    if self._is_hub_content:
        self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
        self.model_types: Optional[List[str]] = json_obj.get("model_types")
        self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
        _mark_non_serializable("hosting_ecr_specs")
    else:
        self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = (
            JumpStartECRSpecs(
                json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content
            )
            if "hosting_ecr_specs" in json_obj
            else None
        )
        _mark_non_serializable("hosting_ecr_uri")

    self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key")
    self.hosting_artifact_uri: Optional[str] = json_obj.get("hosting_artifact_uri")
    self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key")
    self.training_supported: bool = bool(json_obj.get("training_supported", False))
    self.inference_environment_variables = [
        JumpStartEnvironmentVariable(env_variable, is_hub_content=self._is_hub_content)
        for env_variable in json_obj.get("inference_environment_variables", [])
    ]
    self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False))
    self.inference_dependencies: List[str] = json_obj.get("inference_dependencies", [])
    self.inference_vulnerabilities: List[str] = json_obj.get("inference_vulnerabilities", [])
    self.training_vulnerable: bool = bool(json_obj.get("training_vulnerable", False))
    self.training_dependencies: List[str] = json_obj.get("training_dependencies", [])
    self.training_vulnerabilities: List[str] = json_obj.get("training_vulnerabilities", [])
    self.deprecated: bool = bool(json_obj.get("deprecated", False))
    self.deprecated_message: Optional[str] = json_obj.get("deprecated_message")
    self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message")
    self.usage_info_message: Optional[str] = json_obj.get("usage_info_message")
    self.default_inference_instance_type: Optional[str] = json_obj.get(
        "default_inference_instance_type"
    )
    self.default_training_instance_type: Optional[str] = json_obj.get(
        "default_training_instance_type"
    )
    self.supported_inference_instance_types: Optional[List[str]] = json_obj.get(
        "supported_inference_instance_types"
    )
    self.supported_training_instance_types: Optional[List[str]] = json_obj.get(
        "supported_training_instance_types"
    )
    self.dynamic_container_deployment_supported: bool = bool(
        json_obj.get("dynamic_container_deployment_supported")
    )
    self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get(
        "hosting_resource_requirements", None
    )
    self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
    self.training_prepacked_script_key: Optional[str] = json_obj.get(
        "training_prepacked_script_key", None
    )
    self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
        "hosting_prepacked_artifact_key", None
    )
    # New fields required for Hub model.
    if self._is_hub_content:
        self.training_prepacked_script_version: Optional[str] = json_obj.get(
            "training_prepacked_script_version"
        )
        self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get(
            "hosting_prepacked_artifact_version"
        )
    # Deep copies so later mutation of these kwargs cannot leak into json_obj.
    self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {}))
    self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {}))
    self.predictor_specs: Optional[JumpStartPredictorSpecs] = (
        JumpStartPredictorSpecs(
            json_obj.get("predictor_specs"),
            is_hub_content=self._is_hub_content,
        )
        if json_obj.get("predictor_specs")
        else None
    )
    self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
        {
            alias: JumpStartSerializablePayload(payload, is_hub_content=self._is_hub_content)
            for alias, payload in json_obj["default_payloads"].items()
        }
        if json_obj.get("default_payloads")
        else None
    )
    self.gated_bucket = json_obj.get("gated_bucket", False)
    self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size")
    self.inference_enable_network_isolation: bool = json_obj.get(
        "inference_enable_network_isolation", False
    )
    # Read without a default, so it may be None; appears to hold a name prefix
    # (string) rather than a bool -- annotation corrected accordingly.
    self.resource_name_base: Optional[str] = json_obj.get("resource_name_base")
    self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")
    # An explicit ``null`` in the JSON is normalized to an empty dict too.
    model_package_arns = json_obj.get("hosting_model_package_arns")
    self.hosting_model_package_arns: Optional[Dict] = (
        model_package_arns if model_package_arns is not None else {}
    )
    self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True)
    self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
        JumpStartInstanceTypeVariants(
            json_obj["hosting_instance_type_variants"], is_hub_content=self._is_hub_content
        )
        if json_obj.get("hosting_instance_type_variants")
        else None
    )
    self.hosting_additional_data_sources: Optional[JumpStartAdditionalDataSources] = (
        JumpStartAdditionalDataSources(json_obj["hosting_additional_data_sources"])
        if json_obj.get("hosting_additional_data_sources")
        else None
    )
    self.hosting_neuron_model_id: Optional[str] = json_obj.get("hosting_neuron_model_id")
    self.hosting_neuron_model_version: Optional[str] = json_obj.get(
        "hosting_neuron_model_version"
    )

    if self.training_supported:
        # Same URI-vs-specs split as for hosting above.
        if self._is_hub_content:
            self.training_ecr_uri: Optional[str] = json_obj.get("training_ecr_uri")
            _mark_non_serializable("training_ecr_specs")
        else:
            self.training_ecr_specs: Optional[JumpStartECRSpecs] = (
                JumpStartECRSpecs(
                    json_obj["training_ecr_specs"], is_hub_content=self._is_hub_content
                )
                if "training_ecr_specs" in json_obj
                else None
            )
            _mark_non_serializable("training_ecr_uri")
        # Required when training is supported: a missing key raises KeyError.
        self.training_artifact_key: str = json_obj["training_artifact_key"]
        self.training_script_key: str = json_obj["training_script_key"]
        hyperparameters: Any = json_obj.get("hyperparameters")
        self.hyperparameters: List[JumpStartHyperparameter] = [
            JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content)
            for hyperparameter in (hyperparameters if hyperparameters is not None else [])
        ]
        self.estimator_kwargs = deepcopy(json_obj.get("estimator_kwargs", {}))
        self.fit_kwargs = deepcopy(json_obj.get("fit_kwargs", {}))
        self.training_volume_size: Optional[int] = json_obj.get("training_volume_size")
        self.training_enable_network_isolation: bool = json_obj.get(
            "training_enable_network_isolation", False
        )
        self.training_model_package_artifact_uris: Optional[Dict] = json_obj.get(
            "training_model_package_artifact_uris"
        )
        self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
            JumpStartInstanceTypeVariants(
                json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content
            )
            if json_obj.get("training_instance_type_variants")
            else None
        )
        self.model_subscription_link = json_obj.get("model_subscription_link")
        self.default_training_dataset_key: Optional[str] = json_obj.get(
            "default_training_dataset_key"
        )
        self.default_training_dataset_uri: Optional[str] = json_obj.get(
            "default_training_dataset_uri"
        )