in src/neo_loader/mxnet_model_loader.py [0:0]
def __get_param_file_from_model_artifact(self) -> Path:
    """Return the MXNet parameter (``.params``) file to load for this model.

    Parameter files are collected from the model artifacts (excluding
    ``self.SAGEMAKER_AUXILIARY_JSON_FILES``) and narrowed to those whose
    file name starts with the prefix of the model's symbol file. If several
    candidates remain, the one with the highest checkpoint number (the
    trailing digits between the symbol prefix and the extension) is chosen
    and a warning is logged; ties keep the first candidate encountered.

    Returns:
        Path of the selected parameter file.

    Raises:
        RuntimeError: if no ``.params`` file is found at all, or if none of
            them starts with the symbol file's prefix.
    """
    import re  # function-local so this method needs no module-level change

    param_files = self._get_files_from_model_artifacts_with_extensions(
        ["params"], exclude_files=self.SAGEMAKER_AUXILIARY_JSON_FILES
    )
    symbol_file = self.__get_symbol_file_from_model_artifact()
    symbol_prefix = self.__get_symbol_file_prefix(symbol_file)
    target_params = [file for file in param_files if file.name.startswith(symbol_prefix)]

    if not param_files:
        raise RuntimeError("InputConfiguration: No parameter file found for MXNet model. "
                           "Please make sure the framework you select is correct.")
    if not target_params:
        raise RuntimeError(f"InputConfiguration: No parameter file found for MXNet model: {symbol_file.as_posix()} "
                           "Please make sure the prefix of params file match the prefix of symbol file.")
    if len(target_params) == 1:
        return target_params[0]

    def _checkpoint_epoch(param_file: Path) -> int:
        # Parse the epoch from the part of the stem after the symbol prefix
        # (e.g. "-0010" -> 10). A fixed-width slice would truncate epochs with
        # more than four digits and read the separator as a minus sign for
        # shorter ones. Files without a trailing number rank lowest.
        match = re.search(r"(\d+)$", param_file.stem[len(symbol_prefix):])
        return int(match.group(1)) if match else -1

    # max() returns the first maximal element, matching a strict ">" scan.
    select_param_file = max(target_params, key=_checkpoint_epoch)
    logger.warning(f"InputConfiguration: Multiple parameter files found for MXNet model. "
                   f"Parameter file: {select_param_file.as_posix()} will be used.")
    return select_param_file