def generate_embeddings()

in sdk/python/foundation-models/healthcare-ai/medimageinsight/classification_demo/MedImageInsight.py [0:0]


    def generate_embeddings(self, data):
        """
        Runs inference to generate embeddings on the model.

        Parameters (Must provide at least one of 'image' or 'text'):
        - data (dict):
            - 'image': The path or data of the image(s).
            - 'text': The text data.
            - 'params' (optional): Extra parameters forwarded to the inference
              call whenever image data is present.

        Returns:
        - tuple (embeddings_dict, scale_factor):
            - embeddings_dict (dict): A dictionary where each key is the name, and
              the value is another dictionary containing 'image_feature' and/or
              'text_feature'.
            - scale_factor: The second value returned by the inference call
              (``run_from_endpoint`` or ``run_from_mlflow``).

        Raises:
        - ValueError: If ``self.option`` is neither 'run_from_endpoint' nor
          'run_local', or if ``data`` has no non-None 'image' or 'text' entry.
        """
        # Select the inference backend according to how this instance was configured.
        if self.option == "run_from_endpoint":
            run_function = self.run_from_endpoint
        elif self.option == "run_local":
            run_function = self.run_from_mlflow
        else:
            raise ValueError(
                f"Invalid option '{self.option}'. Expected 'run_from_endpoint' or 'run_local'."
            )

        # Flags to check if image and/or text data are provided
        has_image = data.get("image") is not None
        has_text = data.get("text") is not None
        params = data.get("params")

        # Fail loudly up front; otherwise `scale_factor` would never be bound and
        # the caller would get an opaque UnboundLocalError.
        if not (has_image or has_text):
            raise ValueError(
                "Expected 'data' to contain a non-None 'image' and/or 'text' entry."
            )

        embeddings_dict = {}

        if has_image and has_text:
            # One call carries both modalities; each returned entry holds both features.
            embedding_dict, scale_factor = run_function(
                image=data["image"], text=data["text"], params=params
            )
            for name, feat in embedding_dict.items():
                entry = embeddings_dict.setdefault(name, {})
                entry["image_feature"] = feat["image_feature"]
                entry["text_feature"] = feat["text_feature"]
        elif has_image:
            image_embedding_dict, scale_factor = run_function(
                image=data["image"], params=params
            )
            for name, img_feat in image_embedding_dict.items():
                embeddings_dict.setdefault(name, {})["image_feature"] = img_feat[
                    "image_feature"
                ]
        else:
            # Text-only input.
            # NOTE(review): 'params' is not forwarded here (same as before); confirm
            # whether text-only inference should receive it too.
            text_embedding_dict, scale_factor = run_function(text=data["text"])
            for name, txt_feat in text_embedding_dict.items():
                embeddings_dict.setdefault(name, {})["text_feature"] = txt_feat[
                    "text_feature"
                ]

        return embeddings_dict, scale_factor