in utils_cv/detection/model.py [0:0]
def load(self, name: str = None, path: str = None) -> None:
    """Loads a model saved in the format produced by the `save` function.

    The selected model directory must contain ``model.pt``, which is passed
    through ``torch.load`` into ``self.model.load_state_dict``. It must also
    contain ``meta.json``, whose ``"labels"`` entry is assigned to
    ``self.labels``.

    Args:
        name: The name of the model (a sub-directory of `path`) to load.
            If no name is specified, `path` must contain exactly one model
            sub-directory, which is then used. If several are available, an
            exception listing them is raised and you must pass `name`.
        path: The directory containing the model sub-directories, as a
            ``str`` or ``pathlib.Path``. If not given, defaults to
            ``<self.dataset.root>/models``.

    Raises:
        Exception: if no `path` is given and there is no dataset; if `path`
            does not exist or contains no model (when `name` is not given);
            if more than one model is found while `name` is not given; or
            if ``model.pt`` / ``meta.json`` is missing from the selected
            model directory.
    """
    # Resolve the directory that holds the model sub-directories.
    if not path:
        if self.dataset:
            path = Path(self.dataset.root) / "models"
        else:
            raise Exception("Specify a `path` parameter")
    # Accept plain strings too: the `/` operator below requires a Path.
    path = Path(path)

    if name:
        model_path = path / name
    else:
        # Without a name we assume there is exactly one model in `path`.
        if not path.is_dir():
            raise Exception(f"No model found in {path}.")
        # sorted() gives a deterministic order for the error listing below.
        models = sorted(p for p in path.iterdir() if p.is_dir())
        if len(models) == 0:
            raise Exception(f"No model found in {path}.")
        if len(models) > 1:
            listing = "\n".join(str(m) for m in models)
            # Raise rather than exit(): library code must not terminate
            # the caller's process, and callers can catch an Exception.
            raise Exception(
                f"Multiple models were found in {path}. Please specify which "
                f"you wish to use in the `name` argument. Available models:\n"
                f"{listing}"
            )
        model_path = models[0]

    # Both files are required regardless of how the model dir was chosen.
    pt_path = model_path / "model.pt"
    if not pt_path.exists():
        raise Exception(
            f"No model file named model.pt exists in {model_path}"
        )
    meta_path = model_path / "meta.json"
    if not meta_path.exists():
        raise Exception(
            f"No model file named meta.json exists in {model_path}"
        )

    # load into model
    # NOTE: torch.load unpickles the file; only load models from trusted
    # sources.
    self.model.load_state_dict(
        torch.load(pt_path, map_location=torch_device())
    )
    # load meta info
    with open(meta_path, "r", encoding="utf-8") as meta_file:
        meta_data = json.load(meta_file)
    self.labels = meta_data["labels"]