in src/serve.py [0:0]
def model_fn(model_dir):
    """Load the serving artifacts stored in ``model_dir``.

    ``model_dir`` must contain exactly one ``*.pt`` file, which is loaded
    with ``torch.load`` onto the device returned by ``get_device()``. It must
    also contain ``label_mapper.pkl`` and ``preprocessor.pkl``, which are
    unpickled.

    Args:
        model_dir: Directory holding the model artifacts.

    Returns:
        A ``(preprocessor, model, label_mapper)`` tuple.

    Raises:
        AssertionError: If ``model_dir`` does not contain exactly one
            ``*.pt`` file.
        FileNotFoundError: If ``label_mapper.pkl`` or ``preprocessor.pkl``
            is missing.
    """
    # Load model. Escape the directory part so glob metacharacters in the
    # path ("[", "*", "?") are matched literally rather than as a pattern.
    pattern = os.path.join(glob.escape(os.fspath(model_dir)), "*.pt")
    # Sorted so the file list in the error message is deterministic.
    model_files = sorted(glob.glob(pattern))
    if len(model_files) != 1:
        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`. AssertionError is kept for backward compatibility
        # with callers that catch it.
        raise AssertionError(
            "Expected exactly 1 model file (match pattern *.pt) in dir {}, "
            "but instead found {} files. Found: {}".format(
                model_dir, len(model_files), ",".join(model_files)
            )
        )
    model_file = model_files[0]
    device = get_device()
    # NOTE(review): torch.load unpickles arbitrary objects, so only load from
    # a trusted model_dir. On torch releases where `weights_only` defaults to
    # True, loading a fully pickled model may require weights_only=False.
    # Confirm against the deployed torch version.
    model = torch.load(model_file, map_location=torch.device(device))

    # Load label mapper. pickle executes arbitrary code: trusted input only.
    label_mapper_pickle_file = os.path.join(model_dir, "label_mapper.pkl")
    with open(label_mapper_pickle_file, "rb") as f:
        label_mapper = pickle.load(f)

    # Load preprocessor. pickle executes arbitrary code: trusted input only.
    preprocessor_pickle_file = os.path.join(model_dir, "preprocessor.pkl")
    with open(preprocessor_pickle_file, "rb") as f:
        preprocessor = pickle.load(f)

    return preprocessor, model, label_mapper