import asyncio
import logging
import os
from typing import no_type_check
from unittest.mock import MagicMock

import pytest
import zmq
from jupyter_client.session import Session
from tornado.ioloop import IOLoop
from zmq.eventloop.zmqstream import ZMQStream

from ipykernel.ipkernel import IPythonKernel
from ipykernel.kernelbase import Kernel
from ipykernel.zmqshell import ZMQInteractiveShell

try:
    import resource
except ImportError:
    # Windows
    resource = None  # type:ignore

# Handle resource limit:
# ensure a minimal soft limit of DEFAULT_SOFT if the current hard limit is at least that much.
if resource is not None:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

    DEFAULT_SOFT = 4096
    if hard >= DEFAULT_SOFT:
        soft = DEFAULT_SOFT

    if hard < soft:
        hard = soft

    resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))

# Enforce selector event loop on Windows.
if os.name == "nt":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())  # type:ignore
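# Tornado's IOLoop and pyzmq's ZMQStream need add_reader/add_writer support, which the
# proactor loop that Python defaults to on Windows does not provide, hence the selector
# policy above.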


class KernelMixin:
    """Mixin that replaces a kernel's ZMQ plumbing with in-process test sockets and streams."""

    log = logging.getLogger()

    def _initialize(self):
        self.context = context = zmq.Context()
        self.iopub_socket = context.socket(zmq.PUB)
        self.stdin_socket = context.socket(zmq.ROUTER)
        self.session = Session()

        self.test_sockets = [self.iopub_socket]

        self.test_streams = []
        for name in ["shell", "control"]:
            socket = context.socket(zmq.ROUTER)
            stream = ZMQStream(socket)
            stream.on_send(self._on_send)
            self.test_sockets.append(socket)
            self.test_streams.append(stream)
            setattr(self, f"{name}_stream", stream)

    async def do_debug_request(self, msg):
        return {}

    def destroy(self):
        for stream in self.test_streams:
            stream.close()
        for socket in self.test_sockets:
            socket.close()
        self.context.destroy()

    @no_type_check
    async def test_shell_message(self, *args, **kwargs):
        """Dispatch a message on the shell channel and return the captured reply."""
        msg_list = self._prep_msg(*args, **kwargs)
        await self.dispatch_shell(msg_list)
        self.shell_stream.flush()
        return await self._wait_for_msg()

    @no_type_check
    async def test_control_message(self, *args, **kwargs):
        """Dispatch a message on the control channel and return the captured reply."""
        msg_list = self._prep_msg(*args, **kwargs)
        await self.process_control(msg_list)
        self.control_stream.flush()
        return await self._wait_for_msg()

    def _on_send(self, msg, *args, **kwargs):
        # ZMQStream.on_send callback: record the outgoing reply for _wait_for_msg.
        self._reply = msg

    def _prep_msg(self, *args, **kwargs):
        self._reply = None
        raw_msg = self.session.msg(*args, **kwargs)
        msg = self.session.serialize(raw_msg)
        return [zmq.Message(m) for m in msg]

    async def _wait_for_msg(self):
        while not self._reply:
            await asyncio.sleep(0.1)
        _, msg = self.session.feed_identities(self._reply)
        return self.session.deserialize(msg)

    def _send_interupt_children(self):
        # override to prevent deadlock
        pass
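
# Illustrative only (not used by the fixtures below): tests built on this mixin
# round-trip Jupyter protocol messages through the test streams, e.g.
#
#   reply = await kernel.test_shell_message("kernel_info_request", {})
#   assert reply["msg_type"] == "kernel_info_reply"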


class MockKernel(KernelMixin, Kernel):  # type:ignore
    implementation = "test"
    implementation_version = "1.0"
    language = "no-op"
    language_version = "0.1"
    language_info = {
        "name": "test",
        "mimetype": "text/plain",
        "file_extension": ".txt",
    }
    banner = "test kernel"

    def __init__(self, *args, **kwargs):
        self._initialize()
        self.shell = MagicMock()
        super().__init__(*args, **kwargs)

    def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        if not silent:
            stream_content = {"name": "stdout", "text": code}
            self.send_response(self.iopub_socket, "stream", stream_content)

        return {
            "status": "ok",
            # The base class increments the execution count
            "execution_count": self.execution_count,
            "payload": [],
            "user_expressions": {},
        }


class MockIPyKernel(KernelMixin, IPythonKernel):  # type:ignore
    def __init__(self, *args, **kwargs):
        self._initialize()
        super().__init__(*args, **kwargs)


@pytest.fixture
async def kernel():
    kernel = MockKernel()
    kernel.io_loop = IOLoop.current()
    yield kernel
    kernel.destroy()


@pytest.fixture
async def ipkernel():
    kernel = MockIPyKernel()
    kernel.io_loop = IOLoop.current()
    yield kernel
    kernel.destroy()
    ZMQInteractiveShell.clear_instance()
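

# Example usage (illustrative sketch, assuming an async pytest plugin such as
# pytest-asyncio in auto mode): a test can request the fixture and exercise the
# mock kernel over the shell channel, e.g.
#
#   async def test_execute(kernel):
#       reply = await kernel.test_shell_message(
#           "execute_request", content={"code": "hello", "silent": False}
#       )
#       assert reply["content"]["status"] == "ok"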