def model_fn()

in source/inference.py [0:0]


def model_fn(model_dir):
    """
    Deserialize and load the D2 model. This method is called automatically by SageMaker.

    Args:
        model_dir: Directory where the trained model artifacts were downloaded.
            It is scanned for a weights file ending in ``.pth`` or ``.pkl``.

    Returns:
        The predictor returned by ``_get_predictor`` for the config file at
        ``MODEL_EXTRACT_LOC/config.yaml`` and the discovered weights file.

    Raises:
        FileNotFoundError: If ``model_dir`` contains no ``.pth``/``.pkl`` file.
        Exception: Any other error raised while locating the files or building
            the predictor is logged (with traceback) and re-raised, so the real
            cause surfaces instead of a misleading ``UnboundLocalError``.
    """
    try:
        # NOTE: the config is read from MODEL_EXTRACT_LOC, not from model_dir.
        config_path = os.path.join(MODEL_EXTRACT_LOC, 'config.yaml')

        model_path = None
        # os.listdir order is arbitrary; sort so the choice is deterministic
        # when several weight files exist (the last matching name wins).
        for file_name in sorted(os.listdir(model_dir)):
            logger.info(file_name)
            if file_name.endswith((".pth", ".pkl")):
                model_path = os.path.join(model_dir, file_name)

        if model_path is None:
            raise FileNotFoundError(
                f"No model weights (*.pth or *.pkl) found in {model_dir}"
            )

        logger.info(f"Using config file {config_path}")
        logger.info(f"Using model weights from {model_path}")

        pred = _get_predictor(config_path, model_path)

    except Exception:
        # Log with traceback and propagate: returning here would leave `pred`
        # unbound and mask the original failure.
        logger.exception("Model deserialization failed...")
        raise

    logger.info("Deserialization completed ...")

    return pred