in src/sagemaker_huggingface_inference_toolkit/handler_service.py [0:0]
def validate_and_initialize_user_module(self):
    """Locate the user's inference module and install any handler overrides it defines.

    The module named by ``self.environment.module_name`` may define functions
    named by ``MODEL_FN``, ``INPUT_FN``, ``PREDICT_FN``, ``OUTPUT_FN`` and
    ``TRANSFORM_FN``. For each one that is present:

    * ``self.function_extra_arg(<class default handler>, <user function>)`` is
      stored on the matching ``*_extra_arg`` attribute, and
    * the user function is assigned over the corresponding handler attribute
      (``load``, ``preprocess``, ``predict``, ``postprocess``, ``transform_fn``).

    If the module cannot be found, nothing is overridden and an info message is
    logged instead.

    Raises:
        ValueError: if the module defines ``TRANSFORM_FN`` together with any of
            ``INPUT_FN``, ``PREDICT_FN`` or ``OUTPUT_FN``.
    """
    module_name = self.environment.module_name

    # Guard clause: no user module means the built-in handlers stay in place.
    if importlib.util.find_spec(module_name) is None:
        logger.info(
            "No inference script implementation was found at `{}`. Default implementation of all functions will be used.".format(
                module_name
            )
        )
        return

    logger.info("Inference script implementation found at `{}`.".format(module_name))
    user_module = importlib.import_module(module_name)

    # Each row: (function name looked up on the user module,
    #            handler attribute it overrides on this instance,
    #            attribute that receives the function_extra_arg result).
    overrides = (
        (MODEL_FN, "load", "load_extra_arg"),
        (INPUT_FN, "preprocess", "preprocess_extra_arg"),
        (PREDICT_FN, "predict", "predict_extra_arg"),
        (OUTPUT_FN, "postprocess", "postprocess_extra_arg"),
        (TRANSFORM_FN, "transform_fn", "transform_extra_arg"),
    )
    # Resolve every candidate up front so validation happens before any override
    # is installed; missing functions resolve to None.
    resolved = [getattr(user_module, fn_name, None) for fn_name, _, _ in overrides]
    _, preprocess_fn, predict_fn, postprocess_fn, transform_fn = resolved

    # transform_fn may not be combined with the individual input/predict/output
    # overrides (a MODEL_FN override is still allowed alongside it).
    if transform_fn and any((preprocess_fn, predict_fn, postprocess_fn)):
        raise ValueError(
            "Cannot use {} implementation in conjunction with {}, {}, and/or {} implementation".format(
                TRANSFORM_FN, INPUT_FN, PREDICT_FN, OUTPUT_FN
            )
        )

    for (fn_name, handler_attr, extra_arg_attr), user_fn in zip(overrides, resolved):
        self.log_func_implementation_found_or_not(user_fn, fn_name)
        if user_fn is None:
            continue
        # Compare against the class-level default, not a possibly-overridden
        # instance attribute.
        default_handler = getattr(HuggingFaceHandlerService, handler_attr)
        setattr(self, extra_arg_attr, self.function_extra_arg(default_handler, user_fn))
        setattr(self, handler_attr, user_fn)