in source/inference.py [0:0]
def model_fn(model_dir):
    """
    Deserialize and load the D2 model. This method is called automatically by SageMaker.

    model_dir is the location where the trained model artifacts were downloaded.
    The config is read from ``os.path.join(MODEL_EXTRACT_LOC, 'config.yaml')``;
    the weights are the first ``.pth`` / ``.pkl`` file in ``model_dir`` in
    sorted-name order.

    Returns:
        Whatever ``_get_predictor(config_path, model_path)`` returns.

    Raises:
        FileNotFoundError: if ``model_dir`` contains no ``.pth`` / ``.pkl`` file.
        Exception: any error raised while listing ``model_dir`` or building the
            predictor is logged (with traceback) and re-raised unchanged, so the
            caller sees the real cause instead of an UnboundLocalError.
    """
    try:
        config_path = os.path.join(MODEL_EXTRACT_LOC, 'config.yaml')
        weight_files = []
        # os.listdir order is arbitrary; sort so the chosen weights are deterministic.
        for file in sorted(os.listdir(model_dir)):
            logger.info(file)
            if file.endswith((".pth", ".pkl")):
                weight_files.append(file)
        if not weight_files:
            raise FileNotFoundError(
                f"No .pth or .pkl model weights found in {model_dir}"
            )
        if len(weight_files) > 1:
            logger.warning(
                f"Multiple weight files found {weight_files}; using {weight_files[0]}"
            )
        model_path = os.path.join(model_dir, weight_files[0])
        logger.info(f"Using config file {config_path}")
        logger.info(f"Using model weights from {model_path}")
        pred = _get_predictor(config_path, model_path)
    except Exception:
        # Log with traceback, then propagate: falling through would reach
        # `return pred` with `pred` unbound and hide the original error.
        logger.exception("Model deserialization failed...")
        raise
    logger.info("Deserialization completed ...")
    return pred