in activemri/envs/envs.py [0:0]
def _init_from_config_dict(self, cfg: Mapping[str, Any]):
    """Configure the environment from a parsed config mapping.

    Sets, on ``self``: the raw config, the data location, the torch device,
    the reward metric, the k-space mask function, the reconstructor model
    (instantiated, initialized from its checkpoint, put in eval mode and moved
    to the device) and the reconstructor's transform.

    Args:
        cfg: Config mapping. Keys read here: ``"data_location"`` (optional;
            when missing/empty or not an existing directory, the value from
            the defaults JSON is used), ``"device"``, ``"reward_metric"``,
            ``"mask"`` (with ``"function"`` and ``"args"``) and
            ``"reconstructor"`` (with ``"cls"``, ``"checkpoint_fname"``,
            ``"options"`` and ``"transform"``).

    Raises:
        RuntimeError: If ``self._has_setup`` is true and neither the config
            nor the defaults file points to an existing data directory, or if
            the reconstructor checkpoint file does not exist.
        ValueError: If ``cfg["reward_metric"]`` is not one of ``"mse"``,
            ``"nmse"``, ``"ssim"`` or ``"psnr"``.
    """
    self._cfg = cfg

    # The key is optional: the error message below tells users they may put
    # the location in the defaults file instead, so a missing key must fall
    # back to the defaults rather than raise KeyError.
    self._data_location = cfg.get("data_location")
    if not self._data_location or not os.path.isdir(self._data_location):
        default_cfg, defaults_fname = activemri.envs.util.get_defaults_json()
        self._data_location = default_cfg["data_location"]
        # Before setup (self._has_setup false) a missing directory is tolerated.
        if not os.path.isdir(self._data_location) and self._has_setup:
            raise RuntimeError(
                f"No 'data_location' key found in the given config. Please "
                f"write dataset location in your JSON config, or in file {defaults_fname} "
                f"(to use as a default)."
            )

    self._device = torch.device(cfg["device"])
    self.reward_metric = cfg["reward_metric"]
    if self.reward_metric not in ["mse", "ssim", "psnr", "nmse"]:
        raise ValueError("Reward metric must be one of mse, nmse, ssim, or psnr.")

    mask_func = activemri.envs.util.import_object_from_str(cfg["mask"]["function"])
    # The mask "args" dict is bound as the mask function's first positional
    # argument; remaining arguments are supplied by the caller later.
    self._mask_func = functools.partial(mask_func, cfg["mask"]["args"])

    # Instantiating reconstructor
    reconstructor_cfg = cfg["reconstructor"]
    reconstructor_cls = activemri.envs.util.import_object_from_str(
        reconstructor_cfg["cls"]
    )
    # Checkpoint file names are resolved relative to "saved_models_dir" from
    # the defaults file.
    checkpoint_fname = pathlib.Path(reconstructor_cfg["checkpoint_fname"])
    default_cfg, defaults_fname = activemri.envs.util.get_defaults_json()
    saved_models_dir = default_cfg["saved_models_dir"]
    checkpoint_path = pathlib.Path(saved_models_dir) / checkpoint_fname
    if self._has_setup and not checkpoint_path.is_file():
        raise RuntimeError(
            f"No checkpoint was found at {str(checkpoint_path)}. "
            f"Please make sure that both 'checkpoint_fname' (in your JSON config) "
            f"and 'saved_models_dir' (in {defaults_fname}) are configured correctly."
        )
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # host; the reconstructor is moved to self._device below anyway.
    # NOTE(review): torch.load unpickles arbitrary objects -- only point
    # 'saved_models_dir'/'checkpoint_fname' at trusted checkpoint files.
    checkpoint = (
        torch.load(str(checkpoint_path), map_location=self._device)
        if checkpoint_path.is_file()
        else None
    )

    # Options stored inside the checkpoint take precedence over the config.
    options = reconstructor_cfg["options"]
    if checkpoint and "options" in checkpoint:
        msg = (
            f"Checkpoint at {checkpoint_path.name} has an 'options' key. "
            f"This will override the options defined in configuration file."
        )
        warnings.warn(msg)
        options = checkpoint["options"]
    assert isinstance(options, dict)
    self._reconstructor = reconstructor_cls(**options)
    # checkpoint may be None here (no file and no setup); the reconstructor's
    # init_from_checkpoint receives it as-is.
    self._reconstructor.init_from_checkpoint(checkpoint)
    self._reconstructor.eval()
    self._reconstructor.to(self._device)
    self._transform = activemri.envs.util.import_object_from_str(
        reconstructor_cfg["transform"]
    )