community-content/vertex_model_garden/model_oss/peft/handler.py
def initialize(self, context: Any):
"""Initializes the handler."""
logging.info("Start to initialize the PEFT handler.")
properties = context.system_properties
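# Use CUDA when a GPU is available and a gpu_id was assigned; otherwise fall back to CPU.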
self.map_location = (
"cuda"
if torch.cuda.is_available() and properties.get("gpu_id") is not None
else "cpu"
)
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
if torch.cuda.is_available() and properties.get("gpu_id") is not None
else self.map_location
)
self.manifest = context.manifest
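# PRECISION_LOADING_MODE controls the dtype/quantization used when loading weights; defaults to 16-bit.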
self.precision_mode = os.environ.get(
"PRECISION_LOADING_MODE", constants.PRECISION_MODE_16
)
self.task = os.environ.get("TASK", CAUSAL_LANGUAGE_MODELING_LORA)
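# TRUST_REMOTE_CODE defaults to True unless the env var is explicitly set to "false".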
trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", None)
if trust_remote_code == "false":
self.trust_remote_code = False
else:
self.trust_remote_code = True
# If present, the path of the model in the container.
aip_storage_dir = os.environ.get("AIP_STORAGE_DIR", None)
# If present, the URI of the model in a Google-owned GCS bucket.
aip_storage_uri = os.environ.get("AIP_STORAGE_URI", None)
model_id = os.environ.get("MODEL_ID", None)
base_model_id = os.environ.get("BASE_MODEL_ID", None)
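# Resolve the base model location, preferring AIP_STORAGE_DIR, then AIP_STORAGE_URI,
# then MODEL_ID, and finally the legacy BASE_MODEL_ID.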
self.model_id = None
if aip_storage_dir:
self.model_id = aip_storage_dir
logging.info(f"Loaded base model from AIP_STORAGE_DIR: {self.model_id}.")
elif aip_storage_uri:
self.model_id = aip_storage_uri
logging.info(f"Loaded base model from AIP_STORAGE_URI: {self.model_id}.")
elif model_id:
self.model_id = model_id
logging.info(f"Loaded base model from MODEL_ID: {self.model_id}.")
elif base_model_id:
# Note: BASE_MODEL_ID has been unified with MODEL_ID.
# MODEL_ID should be used whenever possible.
self.model_id = base_model_id
logging.info(f"Loaded base model from BASE_MODEL_ID: {self.model_id}.")
self.quantization = os.environ.get("QUANTIZATION", None)
if not self.model_id:
raise ValueError("Base model id is must be set.")
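# If the base model lives in GCS, download it to a local directory first (skipping HF model .bin files).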
if fileutils.is_gcs_path(self.model_id):
fileutils.download_gcs_dir_to_local(
self.model_id,
constants.LOCAL_BASE_MODEL_DIR,
skip_hf_model_bin=True,
)
self.model_id = constants.LOCAL_BASE_MODEL_DIR
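# Optional finetuned LoRA adapter weights; download from GCS to a local directory if needed.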
self.finetuned_lora_model_path = os.environ.get(
"FINETUNED_LORA_MODEL_PATH", ""
)
if fileutils.is_gcs_path(self.finetuned_lora_model_path):
fileutils.download_gcs_dir_to_local(
self.finetuned_lora_model_path, constants.LOCAL_MODEL_DIR
)
self.finetuned_lora_model_path = constants.LOCAL_MODEL_DIR
logging.info(
f"Using task:{self.task}, base model:{self.model_id}, lora model:"
f" {self.finetuned_lora_model_path}, precision {self.precision_mode}."
)
self.pipeline = None
self.model = None
self.tokenizer = None
start_time = time.perf_counter()
logging.info("Started PEFT handler initialization at: %s", start_time)
if self.task == TEXT_TO_IMAGE_LORA:
pipeline = StableDiffusionPipeline.from_pretrained(
self.model_id, torch_dtype=torch.float16
)
logging.debug("Initialized the base model for text to image.")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config
)
logging.debug("Initialized the scheduler for text to image.")
# This is to reduce GPU memory requirements.
pipeline.enable_xformers_memory_efficient_attention()
pipeline = pipeline.to(self.map_location)
# Reduces memory footprint.
pipeline.enable_attention_slicing()
if self.finetuned_lora_model_path:
pipeline.load_lora_weights(self.finetuned_lora_model_path)
logging.debug("Initialized the LoRA model for text to image.")
self.pipeline = pipeline
logging.info("Initialized the text to image pipelines.")
elif self.task == SEQUENCE_CLASSIFICATION_LORA:
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
logging.debug("Initialized the tokenizer for sequence classification.")
model = AutoModelForSequenceClassification.from_pretrained(
self.model_id, torch_dtype=torch.float16
)
logging.debug("Initialized the base model for sequence classification.")
if self.finetuned_lora_model_path:
model = PeftModel.from_pretrained(model, self.finetuned_lora_model_path)
logging.debug("Initialized the LoRA model for sequence classification.")
model.to(self.map_location)
self.model = model
self.tokenizer = tokenizer
elif self.task in (CAUSAL_LANGUAGE_MODELING_LORA, INSTRUCT_LORA):
tokenizer = AutoTokenizer.from_pretrained(
self.model_id,
trust_remote_code=self.trust_remote_code,
)
logging.debug("Initialized the tokenizer.")
if self.task == CAUSAL_LANGUAGE_MODELING_LORA:
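# Choose the weight-loading path: AWQ quantized checkpoints, GPTQ or unquantized
# weights at the configured precision, or an error for unknown QUANTIZATION values.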
if self.quantization == constants.AWQ:
model = AutoAWQForCausalLM.from_quantized(
self.model_id,
trust_remote_code=self.trust_remote_code,
)
elif self.quantization == constants.GPTQ or not self.quantization:
if self.precision_mode == constants.PRECISION_MODE_32:
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
return_dict=True,
torch_dtype=torch.float32,
device_map="auto",
trust_remote_code=self.trust_remote_code,
)
elif self.precision_mode == constants.PRECISION_MODE_16B:
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
return_dict=True,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=self.trust_remote_code,
)
elif self.precision_mode == constants.PRECISION_MODE_16:
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
return_dict=True,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=self.trust_remote_code,
)
elif self.precision_mode == constants.PRECISION_MODE_8:
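# 8-bit loading via bitsandbytes.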
quantization_config = BitsAndBytesConfig(
load_in_8bit=True, llm_int8_threshold=0.0
)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
return_dict=True,
torch_dtype=torch.float16,
device_map="auto",
quantization_config=quantization_config,
trust_remote_code=self.trust_remote_code,
)
else:
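# Any other precision mode falls back to 4-bit NF4 quantization with bfloat16 compute.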
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
return_dict=True,
device_map="auto",
torch_dtype=torch.bfloat16,
quantization_config=quantization_config,
trust_remote_code=self.trust_remote_code,
)
else:
raise ValueError(f"Invalid QUANTIZATION value: {self.quantization}")
else:
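# INSTRUCT_LORA: load the base model in bfloat16; the bare except retries the load once with identical arguments.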
try:
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
torch_dtype=torch.bfloat16,
trust_remote_code=self.trust_remote_code,
device_map="auto",
)
except: # pylint: disable=bare-except
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
torch_dtype=torch.bfloat16,
trust_remote_code=self.trust_remote_code,
device_map="auto",
)
logging.debug("Initialized the base model.")
if self.finetuned_lora_model_path:
model = PeftModel.from_pretrained(model, self.finetuned_lora_model_path)
logging.debug("Initialized the LoRA model.")
pipeline = transformers.pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
)
self.tokenizer = tokenizer
self.pipeline = pipeline
else:
raise ValueError(f"Invalid TASK: {self.task}")
self.initialized = True
end_time = time.perf_counter()
logging.info("The PEFT handler was initialize at: %s", end_time)
logging.info("Handler initiation took %s seconds", end_time - start_time)