in src/neo_loader/mxnet_model_loader.py [0:0]
def __get_symbol_file_from_model_artifact(self) -> Path:
    """Pick the single MXNet symbol (.json) file out of the model artifacts.

    The candidate list is obtained from
    ``_get_files_from_model_artifacts_with_extensions`` for the ``json``
    extension, with the SageMaker auxiliary, Ambarella config and Autopilot
    metadata json file lists passed as ``exclude_files``.

    Returns:
        The only candidate when exactly one exists; otherwise the only
        candidate whose file name starts with ``AP_CROSS_VALIDATION_PREFIX``.

    Raises:
        RuntimeError: if there is no candidate at all, or if there are several
            candidates and the number of them carrying the cross-validation
            prefix is not exactly one.
    """
    # json files that are known not to be the symbol file.
    ignored_json_files = (
        self.SAGEMAKER_AUXILIARY_JSON_FILES
        + self.AMBARELLA_CONFIG_JSON_FILES
        + self.METADATA_AUTOPILOT_JSON_FILES
    )
    candidates = self._get_files_from_model_artifacts_with_extensions(
        ["json"], exclude_files=ignored_json_files
    )

    # Common case: exactly one json file left, that is the symbol file.
    if len(candidates) == 1:
        return candidates[0]

    if not candidates:
        raise RuntimeError("InputConfiguration: No symbol file found for MXNet model. "
                           "Please make sure the framework you select is correct.")

    # More than one candidate remains. SageMaker Autopilot cross-validation
    # ships multiple models, so narrow down to the files whose name carries the
    # cross-validation prefix; the result is only accepted when it is unique.
    prefixed = [
        candidate
        for candidate in candidates
        if candidate.name.startswith(AP_CROSS_VALIDATION_PREFIX)
    ]
    if len(prefixed) != 1:
        raise RuntimeError("InputConfiguration: Only one symbol file is allowed for MXNet model. "
                           "Please make sure the framework you select is correct.")
    return prefixed[0]