in community-content/vertex_model_garden/model_oss/transformers/sam_handler.py [0:0]
def initialize(self, context: Any):
    """Initializes the handler from the TorchServe context.

    Resolves the compute device, the model id (a Hugging Face model card id
    or a GCS path), and the task from environment variables, then builds the
    transformers pipeline for the configured task.

    Args:
      context: TorchServe context carrying `system_properties` (including the
        assigned `gpu_id`) and the model `manifest`.

    Raises:
      ValueError: If the TASK environment variable names an unsupported task.
    """
    # vv-docker:google3-begin(internal)
    # TODO(b/287051908): Move handler functions to common utils for
    # everyone to use.
    # vv-docker:google3-end
    properties = context.system_properties
    # Use the GPU only when CUDA is available AND TorchServe assigned one.
    use_gpu = (
        torch.cuda.is_available() and properties.get("gpu_id") is not None
    )
    self.map_location = "cuda" if use_gpu else "cpu"
    self.device = torch.device(
        f"{self.map_location}:{properties.get('gpu_id')}"
        if use_gpu
        else self.map_location
    )
    self.manifest = context.manifest
    # The model id can be either:
    # 1) a huggingface model card id, like "Salesforce/blip", or
    # 2) a GCS path to the model files, like "gs://foo/bar".
    # If it's a model card id, the model will be loaded from huggingface.
    self.model_id = os.environ.get("MODEL_ID", DEFAULT_MODEL_ID)
    # Else it will be downloaded from GCS to local first, since the
    # transformers from_pretrained API can't read from GCS.
    if self.model_id.startswith(GCS_PREFIX):
        gcs_path = self.model_id[len(GCS_PREFIX) :]
        local_model_dir = os.path.join(DOWNLOAD_DIR, gcs_path)
        logging.info(f"Download {self.model_id} to {local_model_dir}")
        download_gcs_dir(self.model_id, local_model_dir)
        # From here on, from_pretrained reads from the local copy.
        self.model_id = local_model_dir
    self.task = os.environ.get("TASK", MASK_GENERATION)
    logging.info(
        f"Handler initializing task:{self.task}, model:{self.model_id}"
    )
    if self.task == MASK_GENERATION:
        self.pipeline = pipeline(
            task="mask-generation", model=self.model_id, device=self.device
        )
    else:
        raise ValueError(f"Invalid TASK: {self.task}")
    self.initialized = True
    logging.info("Handler initialization done.")