in assets/training/model_evaluation/src/utils_load.py [0:0]
def load_model(model_uri, device, task):
    """Load an MLflow model and configure its device placement for evaluation.

    Detects the model flavor from its MLmodel metadata and builds the
    ``model_hf_load_kwargs`` forwarded to the AzureML loader. These are the
    device map and, on CPU, ``torch_dtype=torch.float32``. It also sets the
    MLflow environment variables that control device placement. If the first
    load attempt raises, the model is loaded once more with fallback device
    settings.

    Args:
        model_uri: URI/path of the MLflow model to load.
        device: Requested device. One of ``constants.DEVICE.AUTO``,
            ``constants.DEVICE.GPU``, or a non-negative device index. Any
            other value, or AUTO/GPU on a machine without CUDA, loads on CPU.
        task: Task name, used as a key into ``constants.MLFLOW_MODEL_TYPE_MAP``.

    Returns:
        tuple: ``(model, model_flavor)``. ``model_flavor`` is one of
        ``MODEL_FLAVOR.HFTRANSFORMERS``, ``MODEL_FLAVOR.HFTRANSFORMERSV2``
        or ``MODEL_FLAVOR.TRANSFORMERS``, or ``""`` when none of these
        flavors is present in the model metadata.

    Raises:
        KeyError: Presumably, if ``task`` is missing from
            ``constants.MLFLOW_MODEL_TYPE_MAP`` (it is indexed with ``[]``).
        Exception: Whatever the loader raises if the fallback load also fails.

    Side effects:
        May set ``MLFLOW_DEFAULT_PREDICTION_DEVICE`` and
        ``MLFLOW_HUGGINGFACE_USE_DEVICE_MAP`` in ``os.environ``. These
        values persist after the call returns.
    """
    # Every load path needs this lookup. Doing it up front makes an unknown
    # task fail fast instead of triggering a pointless reload inside the try.
    model_type = constants.MLFLOW_MODEL_TYPE_MAP[task]

    flavors = Model.load(model_uri).flavors
    # Only the hftransformers / hftransformersv2 flavor configs are read for
    # ``model_hf_load_kwargs``. The transformers flavor is loaded through
    # mlflow.pyfunc, which never receives these kwargs.
    flavor_conf = {}
    model_flavor = ""
    if "hftransformers" in flavors:
        flavor_conf = flavors.get("hftransformers") or {}
        model_flavor = MODEL_FLAVOR.HFTRANSFORMERS
    elif "hftransformersv2" in flavors:
        flavor_conf = flavors.get("hftransformersv2") or {}
        model_flavor = MODEL_FLAVOR.HFTRANSFORMERSV2
    elif "transformers" in flavors:
        model_flavor = MODEL_FLAVOR.TRANSFORMERS

    # Copy so the device settings below do not mutate the loaded flavor
    # metadata. ``or {}`` also covers an explicit null value in the MLmodel file.
    hf_load_kwargs = dict(flavor_conf.get("model_hf_load_kwargs") or {})
    aml_args = {"model_hf_load_kwargs": hf_load_kwargs}

    # Todo: Remove this once we have a fix for the issue
    # Workaround: disable device_map loading for summarization models saved
    # in the mlflow transformers flavor.
    if model_flavor == MODEL_FLAVOR.TRANSFORMERS and model_type == "summarization":
        logger.info("setting OS for text-summarization task")
        os.environ["MLFLOW_HUGGINGFACE_USE_DEVICE_MAP"] = "False"

    if device == constants.DEVICE.AUTO and torch.cuda.is_available():
        hf_load_kwargs["device_map"] = constants.DEVICE.AUTO
    elif device == constants.DEVICE.GPU and torch.cuda.is_available():
        cuda_device = torch.cuda.current_device()
        hf_load_kwargs["device_map"] = cuda_device
        os.environ["MLFLOW_DEFAULT_PREDICTION_DEVICE"] = str(cuda_device)
    elif isinstance(device, int) and device >= 0:
        # Explicit device index. Note that CUDA availability is not checked here.
        hf_load_kwargs["device_map"] = device
        os.environ["MLFLOW_DEFAULT_PREDICTION_DEVICE"] = str(device)
    else:
        if device == constants.DEVICE.GPU:
            logger.warning("Device_map set as GPU, but the compute doesn't have a GPU.")
        logger.info("Loading model on CPU with f32 dtype")
        hf_load_kwargs["device_map"] = constants.DEVICE.CPU
        # Todo: Add equivalent for transformers flavor as well
        # Check what case fails here
        hf_load_kwargs["torch_dtype"] = torch.float32
        # -1 presumably denotes CPU as the default prediction device.
        os.environ["MLFLOW_DEFAULT_PREDICTION_DEVICE"] = str(-1)

    try:
        logger.info(f"aml args: {aml_args}")
        if model_flavor != MODEL_FLAVOR.TRANSFORMERS:
            # Also taken when no known flavor was detected (model_flavor == "").
            logger.info("Loading model in hftransformers flavor")
            model = aml_mlflow.aml.load_model(
                model_uri=model_uri, model_type=model_type, **aml_args
            )
        else:
            logger.info("Loading model in mlflow transformers flavor")
            model = mlflow.pyfunc.load_model(model_uri)
    except Exception as e:
        # Best-effort fallback: retry once with relaxed device placement.
        # Log the original error so the cause of the retry is not lost.
        logger.warning(
            f"Model load failed ({type(e).__name__}: {e}); retrying with fallback device settings"
        )
        logger.info("Reloading model with device_map NA")
        if model_flavor != MODEL_FLAVOR.TRANSFORMERS:
            logger.info("Loading model in hftransformers flavor")
            # "eval_na" is interpreted by the AzureML loader. Presumably it
            # means "no explicit device_map" (TODO confirm).
            hf_load_kwargs["device_map"] = "eval_na"
            logger.info(f"aml args: {aml_args}")
            model = aml_mlflow.aml.load_model(
                model_uri=model_uri, model_type=model_type, **aml_args
            )
        else:
            os.environ["MLFLOW_HUGGINGFACE_USE_DEVICE_MAP"] = "False"
            logger.info("Loading model in mlflow transformers flavor")
            model = mlflow.pyfunc.load_model(model_uri)
    return model, model_flavor