kernels-mixer/kernels_mixer/kernels.py (114 lines of code) (raw):

# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy from jupyter_client.kernelspec import KernelSpecManager from jupyter_client.manager import in_pending_state from jupyter_core.utils import ensure_async, run_sync from jupyter_server.gateway.managers import GatewayMappingKernelManager from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager from jupyter_server.services.kernels.kernelmanager import ServerKernelManager from traitlets import Instance, default, observe from .kernelspecs import MixingKernelSpecManager class MixingMappingKernelManager(AsyncMappingKernelManager): kernel_spec_manager = Instance(KernelSpecManager) @default("kernel_spec_manager") def _default_kernel_spec_manager(self): return "kernels_mixer.kernelspecs.MixingKernelSpecManager" @observe("kernel_spec_manager") def _observe_kernel_spec_manager(self, change): self.log.debug(f"Configured kernel spec manager: {change.new}") if isinstance(change.new, MixingKernelSpecManager): return self.kernel_spec_manager = MixingKernelSpecManager(parent=change.new.parent) self.kernel_spec_manager.local_manager = change.new self.parent.kernel_spec_manager = self.kernel_spec_manager @default("kernel_manager_class") def _default_kernel_manager_class(self): return "kernels_mixer.kernels.MixingKernelManager" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.log.debug(f"Kernel spec manager: {self.kernel_spec_manager}") # Set up the local kernel management. self.local_manager = AsyncMappingKernelManager( parent=self.parent, log=self.log, connection_dir=self.connection_dir, kernel_spec_manager=self.kernel_spec_manager.local_manager) # Set up the remote kernel management. self.remote_manager = GatewayMappingKernelManager( parent=self.parent, log=self.log, connection_dir=self.connection_dir, kernel_spec_manager=self.kernel_spec_manager.remote_manager) def has_remote_kernels(self): for kid in self._kernels: if self._kernels[kid].is_remote: return True return False async def list_kernels(self): if self.has_remote_kernels(): # We have remote kernels, so we must call `list_kernels` on the # Gateway kernel manager to update our kernel models. try: await ensure_async(self.remote_manager.list_kernels()) except Exception as ex: self.log.exception('Failure listing remote kernels: %s', ex) # Ignore the exception listing remote kernels, so that local kernels are still usable. return super().list_kernels() def kernel_model(self, kernel_id): self._check_kernel_id(kernel_id) kernel = self._kernels[kernel_id] # Normally, calls to `run_sync` pose a danger of locking up Tornado's # single-threaded event loop. # # However, the call below should be fine because it cannot block for an # arbitrary amount of time. # # This call blocks on the `model` method defined below, which in turn # blocks on the `GatewayMappingKernelManager`'s `kernel_model` method # (https://github.com/jupyter-server/jupyter_server/blob/547f7a244d89f79dd09fa7d382322d1c40890a3f/jupyter_server/gateway/managers.py#L94). # # That will only take a small, deterministic amount of time to complete # because that `kernel_model` only operates on existing, in-memory data # and does not block on any outgoing network requests. return run_sync(kernel.model)() class MixingKernelManager(ServerKernelManager): _kernel_id_map = {} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @property def is_remote(self): if not self.kernel_name or not self.kernel_id: return False return self.parent.kernel_spec_manager.is_remote(self.kernel_name) @property def delegate_kernel_id(self): if not self.kernel_id: return None return MixingKernelManager._kernel_id_map.get(self.kernel_id, None) @property def delegate_multi_kernel_manager(self): if self.is_remote: return self.parent.remote_manager return self.parent.local_manager @property def delegate(self): if not self.kernel_name or not self.kernel_id: return None return self.delegate_multi_kernel_manager.get_kernel(self.delegate_kernel_id) @property def has_kernel(self): delegate = self.delegate if not delegate: return false return delegate.has_kernel def client(self, *args, **kwargs): delegate = self.delegate if not delegate: return None return delegate.client(*args, **kwargs) @in_pending_state async def start_kernel(self, *args, **kwargs): self.kernel_name = kwargs.get("kernel_name", self.kernel_name) kernel_id = kwargs.pop("kernel_id", self.kernel_id) if kernel_id: self.kernel_id = kernel_id created_kernel_id = await ensure_async(self.delegate_multi_kernel_manager.start_kernel( kernel_name=self.kernel_name, **kwargs)) MixingKernelManager._kernel_id_map[self.kernel_id] = created_kernel_id async def shutdown_kernel(self, *args, **kwargs): await ensure_async(self.delegate_multi_kernel_manager.shutdown_kernel( self.delegate_kernel_id, *args, **kwargs)) MixingKernelManager._kernel_id_map.pop(self.kernel_id) async def interrupt_kernel(self): await ensure_async(self.delegate_multi_kernel_manager.interrupt_kernel( self.delegate_kernel_id)) async def restart_kernel(self, *args, **kwargs): await ensure_async(self.delegate_multi_kernel_manager.restart_kernel( self.delegate_kernel_id, *args, **kwargs)) async def model(self): delegate_model = await ensure_async( self.delegate_multi_kernel_manager.kernel_model(self.delegate_kernel_id)) model = copy.deepcopy(delegate_model) model["id"] = self.kernel_id return model