import asyncio
import logging
import os
from typing import no_type_check
from unittest.mock import MagicMock

import pytest
import zmq
from jupyter_client.session import Session
from tornado.ioloop import IOLoop
from zmq.eventloop.zmqstream import ZMQStream

from ipykernel.ipkernel import IPythonKernel
from ipykernel.kernelbase import Kernel
from ipykernel.zmqshell import ZMQInteractiveShell

try:
    import resource
except ImportError:
    # Windows
    resource = None  # type:ignore


# Handle resource limit
# Ensure a soft limit of at least DEFAULT_SOFT open files when the hard limit
# allows it. Existing limits are never lowered: a soft limit already above
# DEFAULT_SOFT is left alone, and an unlimited hard limit stays unlimited.
if resource is not None:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

    DEFAULT_SOFT = 4096
    # RLIM_INFINITY may be a negative sentinel on some platforms, so it has to
    # be recognised explicitly before any numeric comparison is made.
    soft_unlimited = soft == resource.RLIM_INFINITY
    hard_allows_default = hard == resource.RLIM_INFINITY or hard >= DEFAULT_SOFT

    if not soft_unlimited and soft < DEFAULT_SOFT and hard_allows_default:
        # Only raise the soft limit; pass the current hard limit through
        # unchanged so it is never reduced.
        resource.setrlimit(resource.RLIMIT_NOFILE, (DEFAULT_SOFT, hard))

# Enforce selector event loop on Windows.
if os.name == "nt":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())  # type:ignore


class KernelMixin:
    """Test scaffolding mixed into a kernel class.

    Supplies in-process ZMQ sockets and streams in place of a real kernel
    connection, plus helpers that push one message through the kernel's
    shell or control handler and return the deserialized reply.
    """

    log = logging.getLogger()

    def _initialize(self):
        """Create the ZMQ context, sockets, streams and session used by tests.

        The subclasses in this module call this before ``Kernel.__init__``.
        It sets ``iopub_socket`` (PUB), ``stdin_socket`` (ROUTER), ``session``,
        and ``shell_stream`` / ``control_stream`` (ROUTER sockets wrapped in
        ``ZMQStream``) whose outgoing messages are recorded by ``_on_send``.
        """
        self.context = context = zmq.Context()
        self.iopub_socket = context.socket(zmq.PUB)
        self.stdin_socket = context.socket(zmq.ROUTER)
        self.session = Session()
        # Track every socket created here so destroy() closes each one explicitly.
        self.test_sockets = [self.iopub_socket, self.stdin_socket]
        self.test_streams = []

        for name in ["shell", "control"]:
            socket = context.socket(zmq.ROUTER)
            stream = ZMQStream(socket)
            # Capture whatever the kernel sends on this stream so that
            # _wait_for_msg can hand it back to the test.
            stream.on_send(self._on_send)
            self.test_sockets.append(socket)
            self.test_streams.append(stream)
            setattr(self, f"{name}_stream", stream)

    async def do_debug_request(self, msg):
        """Stub debug handler: always replies with an empty dict."""
        return {}

    def destroy(self):
        """Close all test streams and sockets, then destroy the ZMQ context."""
        for stream in self.test_streams:
            stream.close()
        for socket in self.test_sockets:
            socket.close()
        self.context.destroy()

    @no_type_check
    async def test_shell_message(self, *args, **kwargs):
        """Send a message through ``dispatch_shell`` and return the reply.

        ``args``/``kwargs`` are forwarded to ``Session.msg`` to build the
        request. Returns the deserialized message the kernel sent back on the
        shell stream.
        """
        msg_list = self._prep_msg(*args, **kwargs)
        await self.dispatch_shell(msg_list)
        self.shell_stream.flush()
        return await self._wait_for_msg()

    @no_type_check
    async def test_control_message(self, *args, **kwargs):
        """Send a message through ``process_control`` and return the reply.

        ``args``/``kwargs`` are forwarded to ``Session.msg`` to build the
        request. Returns the deserialized message the kernel sent back on the
        control stream.
        """
        msg_list = self._prep_msg(*args, **kwargs)
        await self.process_control(msg_list)
        self.control_stream.flush()
        return await self._wait_for_msg()

    def _on_send(self, msg, *args, **kwargs):
        """Stream ``on_send`` callback: remember the most recently sent message."""
        self._reply = msg

    def _prep_msg(self, *args, **kwargs):
        """Reset the captured reply and build a serialized request.

        Returns the serialized frames of ``Session.msg(*args, **kwargs)``,
        each wrapped in a ``zmq.Message``.
        """
        # Clear any previous reply so _wait_for_msg waits for a fresh one.
        self._reply = None
        raw_msg = self.session.msg(*args, **kwargs)
        msg = self.session.serialize(raw_msg)
        return [zmq.Message(m) for m in msg]

    async def _wait_for_msg(self):
        """Poll until ``_on_send`` has captured a reply, then deserialize it.

        NOTE(review): there is no timeout here; if the kernel never replies,
        this loops until the test runner itself times out.
        """
        while not self._reply:
            await asyncio.sleep(0.1)
        # Strip the routing identities before deserializing the message body.
        _, msg = self.session.feed_identities(self._reply)
        return self.session.deserialize(msg)

    def _send_interupt_children(self):
        """No-op override of the base-class hook of the same (misspelled) name.

        The override exists to prevent a deadlock in tests; the spelling must
        match the overridden method, so do not "correct" it.
        """
        # override to prevent deadlock
        pass


class MockKernel(KernelMixin, Kernel):  # type:ignore
    """Minimal kernel whose ``do_execute`` echoes the submitted code to stdout."""

    implementation = "test"
    implementation_version = "1.0"
    language = "no-op"
    language_version = "0.1"
    language_info = {
        "name": "test",
        "mimetype": "text/plain",
        "file_extension": ".txt",
    }
    banner = "test kernel"

    def __init__(self, *args, **kwargs):
        # Sockets/streams must exist before the base kernel is constructed.
        self._initialize()
        self.shell = MagicMock()
        super().__init__(*args, **kwargs)

    def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        """Echo ``code`` as a stdout stream message unless ``silent``; reply ok."""
        if not silent:
            stream_content = {"name": "stdout", "text": code}
            self.send_response(self.iopub_socket, "stream", stream_content)

        return {
            "status": "ok",
            # The base class increments the execution count
            "execution_count": self.execution_count,
            "payload": [],
            "user_expressions": {},
        }


class MockIPyKernel(KernelMixin, IPythonKernel):  # type:ignore
    """A real ``IPythonKernel`` wired to the in-process test sockets."""

    def __init__(self, *args, **kwargs):
        # Sockets/streams must exist before the base kernel is constructed.
        self._initialize()
        super().__init__(*args, **kwargs)


@pytest.fixture
async def kernel():
    """Yield a ``MockKernel`` bound to the current IOLoop; destroy it afterwards."""
    kernel = MockKernel()
    kernel.io_loop = IOLoop.current()
    yield kernel
    kernel.destroy()


@pytest.fixture
async def ipkernel():
    """Yield a ``MockIPyKernel``; afterwards destroy it and clear the shell singleton."""
    kernel = MockIPyKernel()
    kernel.io_loop = IOLoop.current()
    yield kernel
    kernel.destroy()
    ZMQInteractiveShell.clear_instance()