in community-content/vertex_model_garden/model_oss/transformers/handler.py [0:0]
def initialize(self, context: Any):
  """Custom initialize: selects the device and loads the model for serving.

  Steps:
    1. Picks the torch device from CUDA availability and the worker's
       assigned ``gpu_id``.
    2. Resolves the model id from the ``MODEL_ID`` environment variable
       (falling back to ``DEFAULT_MODEL_ID``), downloading it from GCS to
       local disk first when it is a GCS path.
    3. Reads the task from the ``TASK`` environment variable (falling back
       to ``ZERO_CLASSIFICATION``).
    4. Builds a BLIP model, a CLIP model/tokenizer/processor, or a
       transformers ``pipeline`` depending on the model id and task.

  Args:
    context: The serving context. ``context.system_properties`` is read for
      ``gpu_id``, and ``context.manifest`` is stored on ``self.manifest``.
  """
  properties = context.system_properties
  gpu_id = properties.get("gpu_id")
  # Run on GPU only when CUDA exists AND this worker was assigned a GPU.
  use_gpu = torch.cuda.is_available() and gpu_id is not None
  self.map_location = "cuda" if use_gpu else "cpu"
  # Pin to the specific assigned GPU (e.g. "cuda:1"); otherwise plain "cpu".
  self.device = torch.device(
      self.map_location + ":" + str(gpu_id) if use_gpu else self.map_location
  )
  self.manifest = context.manifest

  # The model id can be either:
  #   1) a Hugging Face model card id, like "Salesforce/blip", which is
  #      loaded from Hugging Face directly; or
  #   2) a GCS path to the model files, like "gs://foo/bar", which is first
  #      downloaded to local disk, since the transformers from_pretrained API
  #      can't read from GCS.
  self.model_id = os.environ.get("MODEL_ID", DEFAULT_MODEL_ID)
  if self.model_id.startswith(constants.GCS_URI_PREFIX):
    gcs_path = self.model_id[len(constants.GCS_URI_PREFIX) :]
    local_model_dir = os.path.join(constants.LOCAL_MODEL_DIR, gcs_path)
    logging.info("Download %s to %s", self.model_id, local_model_dir)
    fileutils.download_gcs_dir_to_local(self.model_id, local_model_dir)
    # From here on, load from the local copy.
    self.model_id = local_model_dir

  self.task = os.environ.get("TASK", ZERO_CLASSIFICATION)
  logging.info(
      "Handler initializing task:%s, model:%s", self.task, self.model_id
  )

  # Substring match, so it also applies to a local dir downloaded from a GCS
  # path that contains the BLIP model name.
  if SALESFORCE_BLIP in self.model_id:
    # pipeline() does not support Salesforce/blip models yet, so they are
    # built by a dedicated helper.
    self.salesforce_blip = True
    self._create_blip_model()
  else:
    self.salesforce_blip = False
    if self.task == FEATURE_EMBEDDING:
      # NOTE(review): the CLIP model is moved to self.map_location ("cuda" or
      # "cpu") rather than self.device ("cuda:<gpu_id>"), so with several GPUs
      # it lands on the current default CUDA device, not the assigned one.
      # Left unchanged because the inference code (not visible here) may place
      # inputs on the same location -- confirm before switching to self.device.
      self.model = CLIPModel.from_pretrained(self.model_id).to(
          self.map_location
      )
      self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
      self.processor = AutoProcessor.from_pretrained(self.model_id)
    elif self.task == SUMMARIZATION and FLAN_T5 in self.model_id:
      # FLAN-T5 is loaded in bfloat16, and device placement is delegated to
      # transformers via device_map="auto" instead of pinning to self.device.
      self.pipeline = pipeline(
          task=self.task,
          model=self.model_id,
          torch_dtype=torch.bfloat16,
          device_map="auto",
      )
    else:
      self.pipeline = pipeline(
          task=self.task, model=self.model_id, device=self.device
      )
  self.initialized = True
  logging.info("Handler initialization done.")