def get_autoencoder()

in neuron_explainer/models/autoencoder_context.py [0:0]


    def get_autoencoder(self, layer_index: LayerIndex) -> Autoencoder:
        """Return the frozen autoencoder configured for ``layer_index``.

        Loaded autoencoders are memoized in ``self._cached_autoencoders_by_path``,
        keyed by their configured path. On an in-memory miss the checkpoint is read
        from a disk cache under ``/tmp``, which is populated with
        ``copy_to_local_cache`` when the file is not already there. The model is then
        rebuilt with ``Autoencoder.from_state_dict`` and moved to ``self.device``.

        All parameters of the returned autoencoder have ``requires_grad`` set to
        False. This happens on every call, including in-memory cache hits.

        Args:
            layer_index: Layer whose autoencoder should be returned.

        Returns:
            The (possibly cached) autoencoder for ``layer_index``.

        Raises:
            ValueError: If no autoencoder path is configured for ``layer_index``.
            TypeError: If the loaded checkpoint is not a dict-style state dict.
        """
        autoencoder_azure_path = self.autoencoder_config.autoencoder_path_by_layer_index.get(
            layer_index
        )
        if autoencoder_azure_path is None:
            raise ValueError(f"No autoencoder path for layer_index {layer_index}")

        if autoencoder_azure_path in self._cached_autoencoders_by_path:
            autoencoder = self._cached_autoencoders_by_path[autoencoder_azure_path]
        else:
            # Mirror the remote path (minus the "https://" scheme) under /tmp so the
            # download only happens when the file is not already on disk.
            disk_cache_path = os.path.join(
                "/tmp", autoencoder_azure_path.replace("https://", "")
            )
            if file_exists(disk_cache_path):
                print(f"Loading autoencoder from disk cache: {disk_cache_path}")
            else:
                print(f"Reading autoencoder from blob storage: {autoencoder_azure_path}")
                copy_to_local_cache(autoencoder_azure_path, disk_cache_path)

            # NOTE(review): torch.load unpickles the file, so only point this at
            # trusted storage. Consider weights_only=True if the installed torch
            # version and the checkpoint contents allow it.
            state_dict = torch.load(disk_cache_path, map_location=self.device)
            # Released autoencoders are saved as a plain dict for better compatibility.
            # Validate with an explicit raise: an `assert` is stripped under `python -O`.
            if not isinstance(state_dict, dict):
                raise TypeError(
                    f"Expected autoencoder checkpoint at {disk_cache_path} to be a dict "
                    f"state dict, got {type(state_dict).__name__}"
                )
            autoencoder = Autoencoder.from_state_dict(state_dict, strict=False).to(self.device)
            self._cached_autoencoders_by_path[autoencoder_azure_path] = autoencoder

        # Freeze the autoencoder. This runs on cache hits too, so the object returned
        # always has gradients disabled on every parameter.
        for p in autoencoder.parameters():
            p.requires_grad = False
        return autoencoder