in src/neo_loader/pytorch_model_loader.py [0:0]
def load_model(self) -> None:
    """Convert the PyTorch model in the model artifact into a Relay IR module.

    Extracts the ``.pth`` file from the model artifact, obtains a PyTorch
    module for it via ``__get_pytorch_trace_from_model_artifact`` and converts
    it with ``relay.frontend.from_pytorch(trace, self.data_shape)``. On
    success the Relay module and parameters are stored in
    ``self._relay_module_object`` / ``self._params`` and
    ``self.update_missing_metadata()`` is called.

    If obtaining the trace fails, the ``.pth`` file is retried as a
    TorchScript archive with ``torch.jit.load`` on CPU (the fallback path
    used for FCOS models) and that module is converted instead.

    Raises:
        RuntimeError: ``'InputConfiguration: Framework cannot load PyTorch
            model. ...'`` when both the primary load and the TorchScript
            fallback (load or conversion) fail; or ``'InputConfiguration: TVM
            cannot convert the PyTorch model. ...'`` when the primary load
            succeeded but conversion / metadata update failed. The underlying
            exception is chained as ``__cause__``.
    """
    logger.info("Generating relay IR for pytorch model!")
    self.__extract_pth_file_from_model_artifact()
    try:
        trace = self.__get_pytorch_trace_from_model_artifact()
    except Exception as load_error:
        logger.warning("Failed to load pytorch model. %r", load_error)
        # The user-facing message reports the primary load failure; the
        # fallback failure (if any) is logged and chained as the cause.
        msg = 'InputConfiguration: Framework cannot load PyTorch model. {}'.format(load_error)
        try:
            # Fallback for FCOS models: load the .pth as a TorchScript archive.
            trace = torch.jit.load(self.__pth_file.as_posix(), map_location='cpu').float().eval()
            self._relay_module_object, self._params = relay.frontend.from_pytorch(
                trace, self.data_shape
            )
            self.update_missing_metadata()
        except Exception as fallback_error:
            logger.exception("Failed to load pytorch model. %r", fallback_error)
            raise RuntimeError(msg) from fallback_error
    else:
        try:
            self._relay_module_object, self._params = relay.frontend.from_pytorch(
                trace, self.data_shape
            )
            self.update_missing_metadata()
        except Exception as convert_error:
            logger.exception("Failed to convert pytorch model. %r", convert_error)
            msg = 'InputConfiguration: TVM cannot convert the PyTorch model. Invalid model or ' \
                  'input-shape mismatch. Make sure that inputs are lexically ordered and of ' \
                  'the correct dimensionality. {}'.format(convert_error)
            raise RuntimeError(msg) from convert_error
    # Reached on both the primary and the FCOS-fallback success paths.
    logger.info("Successfully generated relay IR for pytorch model!")