in community-content/vertex_model_garden/model_oss/diffusers/handler.py [0:0]
def initialize(self, context: Any):
  """Custom initialize.

  Reads `MODEL_ID` and `TASK` from the environment, downloads the model
  from GCS to a local directory when `MODEL_ID` is a GCS URI, then builds
  the diffusers pipeline appropriate for the task and stores it on
  `self.pipeline`.

  Args:
    context: TorchServe context carrying `system_properties` (used to pick
      the device) and the model `manifest`.

  Raises:
    ValueError: If the `TASK` environment variable names an unsupported
      task.
    KeyError: If the `MODEL_ID` environment variable is not set.
  """
  properties = context.system_properties
  # Use the GPU only when CUDA is available AND TorchServe assigned a gpu_id.
  use_gpu = torch.cuda.is_available() and properties.get("gpu_id") is not None
  self.map_location = "cuda" if use_gpu else "cpu"
  self.device = torch.device(
      self.map_location + ":" + str(properties.get("gpu_id"))
      if use_gpu
      else self.map_location
  )
  self.manifest = context.manifest
  self.model_id = os.environ["MODEL_ID"]
  if self.model_id.startswith(constants.GCS_URI_PREFIX):
    # MODEL_ID is a gs:// URI: mirror it locally and point model_id at the
    # local copy so from_pretrained() can load it.
    gcs_path = self.model_id[len(constants.GCS_URI_PREFIX) :]
    local_model_dir = os.path.join(constants.LOCAL_MODEL_DIR, gcs_path)
    logging.info(f"Download {self.model_id} to {local_model_dir}")
    fileutils.download_gcs_dir_to_local(self.model_id, local_model_dir)
    self.model_id = local_model_dir
  self.task = os.environ.get("TASK", TEXT_TO_IMAGE)
  logging.info(f"Using task:{self.task}, model:{self.model_id}")

  # These four tasks share an identical setup (fp16 load, Euler-ancestral
  # scheduler, move to device, attention slicing); dispatch on the class
  # instead of duplicating the branch four times.
  euler_pipeline_classes = {
      TEXT_TO_IMAGE: StableDiffusionPipeline,
      IMAGE_TO_IMAGE: StableDiffusionImg2ImgPipeline,
      IMAGE_INPAINTING: StableDiffusionInpaintPipeline,
      INSTRUCT_PIX2PIX: StableDiffusionInstructPix2PixPipeline,
  }
  if self.task in euler_pipeline_classes:
    pipeline = euler_pipeline_classes[self.task].from_pretrained(
        self.model_id, torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config
    )
    pipeline = pipeline.to(self.map_location)
    # Reduce memory footprint.
    pipeline.enable_attention_slicing()
  elif self.task == CONTROLNET:
    # MODEL_ID refers to the ControlNet weights; the base Stable Diffusion
    # checkpoint is fixed by STABLE_DIFFUSION_MODEL.
    controlnet = ControlNetModel.from_pretrained(
        self.model_id, torch_dtype=torch.float16
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        STABLE_DIFFUSION_MODEL,
        controlnet=controlnet,
        torch_dtype=torch.float16,
    )
    pipeline.scheduler = UniPCMultistepScheduler.from_config(
        pipeline.scheduler.config
    )
    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.enable_model_cpu_offload()
    # NOTE(review): calling .to() after enable_model_cpu_offload() is
    # suspect — diffusers docs advise not moving an offloaded pipeline to a
    # device manually. Preserved as-is; confirm before changing.
    pipeline = pipeline.to(self.map_location)
    # Reduce memory footprint.
    pipeline.enable_attention_slicing()
  elif self.task == CONDITIONED_SUPER_RES:
    pipeline = StableDiffusionUpscalePipeline.from_pretrained(
        self.model_id, torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config
    )
    # This is necessary to 4x upscale >=256x256 input images with V100.
    logging.info("Enable xformers memory efficient attention for inference.")
    pipeline.enable_xformers_memory_efficient_attention()
    pipeline = pipeline.to(self.map_location)
    # Reduce memory footprint.
    pipeline.enable_attention_slicing()
  elif self.task == TEXT_TO_VIDEO_ZERO_SHOT:
    # Zero-shot video uses the fixed base SD checkpoint, not MODEL_ID.
    pipeline = TextToVideoZeroPipeline.from_pretrained(
        STABLE_DIFFUSION_MODEL, torch_dtype=torch.float16
    )
    # Memory optimization.
    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.enable_model_cpu_offload()
    # NOTE(review): same offload-then-.to() ordering concern as CONTROLNET;
    # preserved as-is.
    pipeline = pipeline.to(self.map_location)
  elif self.task == TEXT_TO_VIDEO:
    pipeline = DiffusionPipeline.from_pretrained(
        self.model_id, torch_dtype=torch.float16, variant="fp16"
    )
    pipeline.enable_model_cpu_offload()
    # Memory optimization.
    pipeline.enable_vae_slicing()
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config
    )
  else:
    raise ValueError(f"Invalid TASK: {self.task}")
  self.pipeline = pipeline
  self.initialized = True
  logging.info("Handler initialization done.")