def model_fn()

in src/serve.py [0:0]


def model_fn(model_dir):
    """Load the model, label mapper and preprocessor stored in ``model_dir``.

    Args:
        model_dir: Directory expected to contain exactly one ``*.pt`` file
            plus ``label_mapper.pkl`` and ``preprocessor.pkl``.

    Returns:
        Tuple ``(preprocessor_mapper, model, label_mapper)``.

    Raises:
        AssertionError: If the number of ``*.pt`` files directly inside
            ``model_dir`` is not exactly one.
        OSError: If either pickle file cannot be opened.
    """
    # Escape the directory so glob metacharacters in its name ("[", "*", "?")
    # are matched literally; only the "*.pt" suffix acts as a pattern.
    pattern = "{}/*.pt".format(glob.escape(str(model_dir)))
    # Sorted so the error message below is deterministic.
    model_files = sorted(glob.glob(pattern))

    # Explicit raise instead of an ``assert`` statement: asserts are stripped
    # under ``python -O``, which would let a bad directory slip through.
    # AssertionError is kept so existing callers see the same exception type.
    if len(model_files) != 1:
        raise AssertionError(
            "Expected exactly 1 model file (match pattern *.pt)in dir {}, but instead found {} files. Found.. {}".format(
                model_dir, len(model_files), ",".join(model_files)))

    # Load model
    model_file = model_files[0]
    device = get_device()
    # NOTE(review): torch.load unpickles the file, so the artifact must come
    # from a trusted source. Depending on the installed torch version the
    # ``weights_only`` default may differ - confirm against the runtime image.
    model = torch.load(model_file, map_location=torch.device(device))

    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # both pickles below must only ever be read from trusted model artifacts.

    # Load label mapper
    label_mapper_pickle_file = os.path.join(model_dir, "label_mapper.pkl")
    with open(label_mapper_pickle_file, "rb") as f:
        label_mapper = pickle.load(f)

    # Load preprocessor
    preprocessor_pickle_file = os.path.join(model_dir, "preprocessor.pkl")
    with open(preprocessor_pickle_file, "rb") as f:
        preprocessor_mapper = pickle.load(f)

    return preprocessor_mapper, model, label_mapper