in src/sagemaker/jumpstart/hub/interfaces.py [0:0]
def from_json(self, json_obj: Dict[str, Any]) -> None:
    """Sets fields in object based on json.

    Reads the PascalCase keys of a hub model document and assigns the matching
    snake_case attributes on ``self``. Absent keys generally produce ``None``; the few
    fields with a different default are noted inline. Training-related attributes,
    including the estimator kwargs, are only assigned when ``TrainingSupported`` is
    truthy -- otherwise they are left unset on the object.

    Args:
        json_obj (Dict[str, Any]): Dictionary representation of hub model document.
    """

    def _true_or_none(key: str) -> Optional[bool]:
        # Truthy values map to True; falsy values -- including an explicit ``False`` --
        # map to None. NOTE(review): collapsing False to None is pre-existing behavior;
        # confirm downstream consumers never need to distinguish the two.
        return True if json_obj.get(key) else None

    self.url: str = json_obj.get("Url")
    self.min_sdk_version: str = json_obj.get("MinSdkVersion")
    self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
    self.hosting_artifact_uri: Optional[str] = json_obj.get("HostingArtifactUri")
    self.hosting_script_uri: Optional[str] = json_obj.get("HostingScriptUri")
    self.inference_dependencies: Optional[List[str]] = json_obj.get("InferenceDependencies")
    self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
        JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
        for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
    ]
    self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
    self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
    self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
    self.incremental_training_supported: bool = bool(
        json_obj.get("IncrementalTrainingSupported")
    )
    self.dynamic_container_deployment_supported: Optional[bool] = _true_or_none(
        "DynamicContainerDeploymentSupported"
    )
    self.hosting_artifact_s3_data_type: Optional[str] = json_obj.get(
        "HostingArtifactS3DataType"
    )
    self.hosting_artifact_compression_type: Optional[str] = json_obj.get(
        "HostingArtifactCompressionType"
    )
    self.hosting_prepacked_artifact_uri: Optional[str] = json_obj.get(
        "HostingPrepackedArtifactUri"
    )
    self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get(
        "HostingPrepackedArtifactVersion"
    )
    # Unlike the ``_true_or_none`` flags, an explicit False is preserved here.
    hosting_use_script_uri = json_obj.get("HostingUseScriptUri")
    self.hosting_use_script_uri: Optional[bool] = (
        bool(hosting_use_script_uri) if hosting_use_script_uri is not None else None
    )
    self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
    self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")
    self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion")
    self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")
    # Rankings and components are assigned before the configs; keep this order in case
    # ``_get_configs`` reads them back from ``self`` (not visible here -- TODO confirm).
    self.inference_config_rankings = self._get_config_rankings(json_obj)
    self.inference_config_components = self._get_config_components(json_obj)
    self.inference_configs = self._get_configs(json_obj)
    self.default_inference_instance_type: Optional[str] = json_obj.get(
        "DefaultInferenceInstanceType"
    )
    self.supported_inference_instance_types: Optional[List[str]] = json_obj.get(
        "SupportedInferenceInstanceTypes"
    )
    predictor_specs = json_obj.get("SageMakerSdkPredictorSpecifications")
    self.sage_maker_sdk_predictor_specifications: Optional[JumpStartPredictorSpecs] = (
        JumpStartPredictorSpecs(predictor_specs, is_hub_content=True)
        if predictor_specs
        else None
    )
    self.inference_volume_size: Optional[int] = json_obj.get("InferenceVolumeSize")
    # Defaults to False when the key is absent; a present value is stored as-is.
    self.inference_enable_network_isolation: Optional[bool] = json_obj.get(
        "InferenceEnableNetworkIsolation", False
    )
    self.fine_tuning_supported: Optional[bool] = _true_or_none("FineTuningSupported")
    self.validation_supported: Optional[bool] = _true_or_none("ValidationSupported")
    self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase")
    self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False))
    default_payloads = json_obj.get("DefaultPayloads")
    self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
        {
            alias: JumpStartSerializablePayload(payload, is_hub_content=True)
            for alias, payload in default_payloads.items()
        }
        if default_payloads
        else None
    )
    self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get(
        "HostingResourceRequirements"
    )
    hosting_variants = json_obj.get("HostingInstanceTypeVariants")
    self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
        JumpStartInstanceTypeVariants(hosting_variants, is_hub_content=True)
        if hosting_variants
        else None
    )
    notebook_location_uris = json_obj.get("NotebookLocationUris")
    self.notebook_location_uris: Optional[NotebookLocationUris] = (
        NotebookLocationUris(notebook_location_uris) if notebook_location_uris else None
    )
    self.model_provider_icon_uri: Optional[str] = None  # Not needed for private beta
    self.task: Optional[str] = json_obj.get("Task")
    self.framework: Optional[str] = json_obj.get("Framework")
    self.datatype: Optional[str] = json_obj.get("Datatype")
    self.license: Optional[str] = json_obj.get("License")
    self.contextual_help: Optional[str] = json_obj.get("ContextualHelp")
    self.model_dir: Optional[str] = json_obj.get("ModelDir")
    # Deploy kwargs
    self.model_data_download_timeout: Optional[str] = json_obj.get("ModelDataDownloadTimeout")
    self.container_startup_health_check_timeout: Optional[str] = json_obj.get(
        "ContainerStartupHealthCheckTimeout"
    )

    # Everything below is training-specific and is only assigned when training is
    # supported; otherwise these attributes stay unset.
    if not self.training_supported:
        return

    self.default_training_dataset_uri: Optional[str] = json_obj.get(
        "DefaultTrainingDatasetUri"
    )
    self.training_model_package_artifact_uri: Optional[str] = json_obj.get(
        "TrainingModelPackageArtifactUri"
    )
    self.training_artifact_compression_type: Optional[str] = json_obj.get(
        "TrainingArtifactCompressionType"
    )
    self.training_artifact_s3_data_type: Optional[str] = json_obj.get(
        "TrainingArtifactS3DataType"
    )
    hyperparameters: Any = json_obj.get("Hyperparameters")
    self.hyperparameters: List[JumpStartHyperparameter] = (
        [
            JumpStartHyperparameter(hyperparameter, is_hub_content=True)
            for hyperparameter in hyperparameters
        ]
        if hyperparameters is not None
        else []
    )
    self.training_script_uri: Optional[str] = json_obj.get("TrainingScriptUri")
    self.training_prepacked_script_uri: Optional[str] = json_obj.get(
        "TrainingPrepackedScriptUri"
    )
    self.training_prepacked_script_version: Optional[str] = json_obj.get(
        "TrainingPrepackedScriptVersion"
    )
    self.training_ecr_uri: Optional[str] = json_obj.get("TrainingEcrUri")
    # Hub content only provides ``training_ecr_uri``; ``training_ecr_specs`` is never
    # assigned here, so it is recorded as non-serializable. Append it at most once:
    # ``from_json`` may run more than once, and if ``_non_serializable_slots`` is a
    # class-level list (not visible here -- TODO confirm) it is shared by all
    # instances, so an unconditional append would keep accumulating duplicates.
    if "training_ecr_specs" not in self._non_serializable_slots:
        self._non_serializable_slots.append("training_ecr_specs")
    self.training_metrics: Optional[List[Dict[str, str]]] = json_obj.get("TrainingMetrics")
    self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri")
    self.training_config_rankings = self._get_config_rankings(
        json_obj, _ComponentType.TRAINING
    )
    self.training_config_components = self._get_config_components(
        json_obj, _ComponentType.TRAINING
    )
    self.training_configs = self._get_configs(json_obj, _ComponentType.TRAINING)
    self.training_dependencies: Optional[List[str]] = json_obj.get("TrainingDependencies")
    self.default_training_instance_type: Optional[str] = json_obj.get(
        "DefaultTrainingInstanceType"
    )
    self.supported_training_instance_types: Optional[List[str]] = json_obj.get(
        "SupportedTrainingInstanceTypes"
    )
    self.training_volume_size: Optional[int] = json_obj.get("TrainingVolumeSize")
    # Defaults to False when the key is absent; a present value is stored as-is.
    self.training_enable_network_isolation: Optional[bool] = json_obj.get(
        "TrainingEnableNetworkIsolation", False
    )
    training_variants = json_obj.get("TrainingInstanceTypeVariants")
    self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
        JumpStartInstanceTypeVariants(training_variants, is_hub_content=True)
        if training_variants
        else None
    )
    # Estimator kwargs
    self.encrypt_inter_container_traffic: Optional[bool] = _true_or_none(
        "EncryptInterContainerTraffic"
    )
    self.max_runtime_in_seconds: Optional[str] = json_obj.get("MaxRuntimeInSeconds")
    self.disable_output_compression: Optional[bool] = _true_or_none(
        "DisableOutputCompression"
    )