def _init_from_config_dict()

in activemri/envs/envs.py [0:0]


    def _init_from_config_dict(self, cfg: Mapping[str, Any]):
        """Initialize the environment's state from a configuration mapping.

        Sets ``self._cfg``, ``self._data_location``, ``self._device``,
        ``self.reward_metric``, ``self._mask_func``, ``self._reconstructor`` and
        ``self._transform``. When ``cfg`` has no ``data_location`` (or it is not an
        existing directory), the value from the defaults JSON returned by
        ``activemri.envs.util.get_defaults_json()`` is used instead. The
        reconstructor checkpoint is looked up under that file's
        ``saved_models_dir``.

        Args:
            cfg: Configuration mapping. Keys read here: ``data_location``
                (optional), ``device``, ``reward_metric``, ``mask`` (with
                ``function`` and ``args``) and ``reconstructor`` (with ``cls``,
                ``checkpoint_fname``, ``options`` and ``transform``).

        Raises:
            RuntimeError: If ``self._has_setup`` is true and either no valid data
                directory or no checkpoint file can be found.
            ValueError: If ``reward_metric`` is not one of mse, nmse, ssim, psnr.
            TypeError: If a checkpoint's ``options`` entry is not a dict.
            KeyError: If any other required config key is missing.
        """
        self._cfg = cfg
        # Read once; used both for the data-location fallback and to locate the
        # reconstructor checkpoint below.
        default_cfg, defaults_fname = activemri.envs.util.get_defaults_json()

        def _is_dir(path) -> bool:
            # os.path.isdir(None) raises TypeError instead of returning False.
            return path is not None and os.path.isdir(path)

        # Prefer the config's data location; fall back to the defaults file when
        # it is absent or does not point to an existing directory.
        self._data_location = cfg.get("data_location")
        if not _is_dir(self._data_location):
            self._data_location = default_cfg["data_location"]
            if self._has_setup and not _is_dir(self._data_location):
                raise RuntimeError(
                    f"No valid 'data_location' directory found in the given config "
                    f"or in file {defaults_fname}. Please write dataset location in "
                    f"your JSON config, or in file {defaults_fname} "
                    f"(to use as a default)."
                )

        self._device = torch.device(cfg["device"])
        self.reward_metric = cfg["reward_metric"]
        if self.reward_metric not in ("mse", "ssim", "psnr", "nmse"):
            raise ValueError("Reward metric must be one of mse, nmse, ssim, or psnr.")
        mask_func = activemri.envs.util.import_object_from_str(cfg["mask"]["function"])
        # The whole "args" entry is bound as the mask function's first argument.
        self._mask_func = functools.partial(mask_func, cfg["mask"]["args"])

        # Instantiating reconstructor
        reconstructor_cfg = cfg["reconstructor"]
        reconstructor_cls = activemri.envs.util.import_object_from_str(
            reconstructor_cfg["cls"]
        )

        checkpoint_fname = pathlib.Path(reconstructor_cfg["checkpoint_fname"])
        saved_models_dir = default_cfg["saved_models_dir"]
        checkpoint_path = pathlib.Path(saved_models_dir) / checkpoint_fname
        has_checkpoint = checkpoint_path.is_file()
        if self._has_setup and not has_checkpoint:
            raise RuntimeError(
                f"No checkpoint was found at {str(checkpoint_path)}. "
                f"Please make sure that both 'checkpoint_fname' (in your JSON config) "
                f"and 'saved_models_dir' (in {defaults_fname}) are configured correctly."
            )

        # map_location remaps storages onto the configured device, so a
        # checkpoint saved on a GPU can still be loaded on a CPU-only host.
        checkpoint = (
            torch.load(str(checkpoint_path), map_location=self._device)
            if has_checkpoint
            else None
        )
        options = reconstructor_cfg["options"]
        if checkpoint is not None and "options" in checkpoint:
            warnings.warn(
                f"Checkpoint at {checkpoint_path.name} has an 'options' key. "
                f"This will override the options defined in configuration file."
            )
            options = checkpoint["options"]
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(options, dict):
                raise TypeError(
                    f"The 'options' stored in checkpoint {checkpoint_path.name} "
                    f"must be a dict, got {type(options).__name__}."
                )
        self._reconstructor = reconstructor_cls(**options)
        self._reconstructor.init_from_checkpoint(checkpoint)
        self._reconstructor.eval()
        self._reconstructor.to(self._device)
        self._transform = activemri.envs.util.import_object_from_str(
            reconstructor_cfg["transform"]
        )