in neuron_explainer/models/autoencoder_context.py [0:0]
def get_autoencoder(self, layer_index: LayerIndex) -> Autoencoder:
    """Return the frozen autoencoder for *layer_index*, loading and caching it on first use.

    Lookup order: in-memory cache -> /tmp disk cache -> blob storage download.
    Raises ValueError when the config has no path registered for this layer.
    """
    path_by_layer = self.autoencoder_config.autoencoder_path_by_layer_index
    autoencoder_azure_path = path_by_layer.get(layer_index)
    if autoencoder_azure_path is None:
        raise ValueError(f"No autoencoder path for layer_index {layer_index}")

    autoencoder = self._cached_autoencoders_by_path.get(autoencoder_azure_path)
    if autoencoder is None:
        # Mirror the blob under /tmp (URL scheme stripped) so later calls and
        # later processes can skip the blob-storage read.
        disk_cache_path = os.path.join(
            "/tmp", autoencoder_azure_path.replace("https://", "")
        )
        if file_exists(disk_cache_path):
            print(f"Loading autoencoder from disk cache: {disk_cache_path}")
        else:
            print(f"Reading autoencoder from blob storage: {autoencoder_azure_path}")
            copy_to_local_cache(autoencoder_azure_path, disk_cache_path)
        state_dict = torch.load(disk_cache_path, map_location=self.device)
        # released autoencoders are saved as a dict for better compatibility
        assert isinstance(state_dict, dict)
        autoencoder = Autoencoder.from_state_dict(state_dict, strict=False).to(self.device)
        self._cached_autoencoders_by_path[autoencoder_azure_path] = autoencoder

    # freeze the autoencoder: callers only run inference through it
    for p in autoencoder.parameters():
        p.requires_grad = False
    return autoencoder