community-content/vertex_model_garden/model_oss/pic2word/handler.py (122 lines of code) (raw):
"""Custom handler for Pic2Word."""
from argparse import Namespace # pylint: disable=g-importing-member
import os
from typing import Any, List
from absl import logging
from data import CustomFolder
from eval_utils import visualize_results
from model.clip import load
from model.model import convert_weights
from model.model import IM2TEXT
from params import get_project_root
import torch
from torch.utils.data import DataLoader
from ts.torch_handler.base_handler import BaseHandler
from util import fileutils
# The COCO dataset is stored in a publicly accessible bucket.
_COCO_STORAGE_DIR = "gs://pic2word-bucket/data/coco/"
_COCO_LOCAL_DIR = "/home/model-server/composed_image_retrieval/data/coco/"
_COCO_VAL2017_PATH = "coco/val2017"
_COCO_DATASET_NAME = "coco"
_MODEL_NAME = "ViT-L/14"
_LOCAL_QUERY_PATH = "./query/"
_IMAGE_OUTPUT_LOCAL_DIR = "demo_out/images"
_OUTPUT_LOCAL_DIR = "./demo_out/"
_DATA_DIR = "data"
_CHECKPOINT_DIR = "checkpoint/pic2word_model.pt"
_REQUEST_PROMPTS = "prompts"
_REQUEST_OUTPUT_STORAGE_DIR = "output_storage_dir"
_REQUEST_IMAGE_PATH = "image_path"
_REQUEST_IMAGE_FILE_NAME = "image_file_name"
_RESPONSE_MSG = "Successfully retrieved images."
_PICKLE_DIR_PATH = "gs://pic2word-bucket/pickle/"
class ModelHandler(BaseHandler):
"""A custom model handler implementation."""
def __init__(self):
self.initialized = False
self.gpu = 0
self.model = None
self.dataloader = None
self.prompt = None
self.output_storage_dir = None
def initialize(self, context: Any):
"""Initialize."""
logging.info("Initializing pic2word.")
# Download pickle file for COCO
fileutils.download_gcs_dir_to_local(_PICKLE_DIR_PATH, "./data")
# Download COCO dataset. The model looks for this folder specifically
# during image retrieval to generate a response for each request.
# This is a publicly accessible bucket.
fileutils.download_gcs_dir_to_local(
_COCO_STORAGE_DIR,
_COCO_LOCAL_DIR,
)
# Load the model.
self.initialized = True
torch.cuda.set_device(self.gpu)
model, _, preprocess_val = load(_MODEL_NAME, jit=False)
img2text = IM2TEXT(
embed_dim=model.embed_dim,
output_dim=model.token_embedding.weight.shape[1],
)
model.cuda(self.gpu)
img2text.cuda(self.gpu)
convert_weights(model)
convert_weights(img2text)
self.model = model
self.img2text = img2text
# Load the dataset
logging.info("Loading dataset.")
root_project = os.path.join(get_project_root(), _DATA_DIR)
dataset = CustomFolder(
os.path.join(root_project, _COCO_VAL2017_PATH), transform=preprocess_val
)
# Initialize the dataloader. This is used to create the pickle file from
# the dataset.
dataloader = DataLoader(
dataset,
batch_size=64,
shuffle=False,
num_workers=1,
pin_memory=True,
drop_last=False,
)
self.dataloader = dataloader
logging.info("Finished initializing Pic2Word server.")
def preprocess(self, data: Any) -> str:
"""Preprocess input data."""
logging.info("Preprocessing Pic2Word inference request.")
query = data[0]
self.output_storage_dir = query[_REQUEST_OUTPUT_STORAGE_DIR]
prompts = query[_REQUEST_PROMPTS]
prompts = prompts.split(",")
self.prompt = prompts
image_path = query[_REQUEST_IMAGE_PATH]
# The query image is only supported via GCS bucket upload.
fileutils.download_gcs_dir_to_local(image_path, _LOCAL_QUERY_PATH)
image_file_name = query[_REQUEST_IMAGE_FILE_NAME]
query_file = f"./query/{image_file_name}"
logging.info("Setting model args.")
args = {
"openai-pretrained": True,
"resume": _CHECKPOINT_DIR,
"retrieval_data": _COCO_DATASET_NAME,
"query_file": query_file,
"demo_out": _OUTPUT_LOCAL_DIR,
"prompts": prompts,
"distributed": False,
"dp": False,
"gpu": 0,
"model": _MODEL_NAME,
"world_size": 1,
}
model_input = Namespace(**args)
logging.info("Finished preprocessing Pic2Word inference request.")
return model_input
def inference(self, model_input: Any):
"""Runs inference."""
logging.info("Running model-inference.")
visualize_results(
model=self.model,
img2text=self.img2text,
args=model_input,
prompt=self.prompt,
dataloader=self.dataloader,
)
def postprocess(self):
"""Upload the output images to the bucket."""
logging.info("Running request postprocess.")
fileutils.upload_local_dir_to_gcs(
_IMAGE_OUTPUT_LOCAL_DIR, self.output_storage_dir
)
def handle(self, data: Any, context: Any) -> List[str]: # pylint: disable=unused-argument
"""Runs preprocess, inference, and post-processing."""
logging.info("Received Pic2Word inference request")
model_input = self.preprocess(data)
self.inference(model_input)
self.postprocess()
logging.info("Done handling input.")
return [_RESPONSE_MSG]