You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
4.4 KiB
Python
161 lines
4.4 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
from typing import no_type_check
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
import zmq
|
|
from jupyter_client.session import Session
|
|
from tornado.ioloop import IOLoop
|
|
from zmq.eventloop.zmqstream import ZMQStream
|
|
|
|
from ipykernel.ipkernel import IPythonKernel
|
|
from ipykernel.kernelbase import Kernel
|
|
from ipykernel.zmqshell import ZMQInteractiveShell
|
|
|
|
try:
    import resource
except ImportError:
    # Windows
    resource = None  # type:ignore


# Handle resource limit
# Ensure a minimal soft limit of DEFAULT_SOFT if the current hard limit is at
# least that much. The soft limit is only ever raised, never lowered, and the
# hard limit is left untouched.
DEFAULT_SOFT = 4096


def _raised_nofile_soft_limit(soft, hard, minimum=DEFAULT_SOFT, unlimited=None):
    """Return the soft limit to install, or ``None`` to leave the limits alone.

    The soft limit is only ever *raised* (to ``minimum``) and only when the
    hard limit allows it. A soft limit that is already ``>= minimum`` or
    unlimited is kept, and the hard limit is never modified (an unprivileged
    process cannot raise a hard limit back up once lowered).

    Args:
        soft: Current soft limit, as returned by ``resource.getrlimit``.
        hard: Current hard limit, as returned by ``resource.getrlimit``.
        minimum: Smallest soft limit wanted by the test suite.
        unlimited: The platform's ``resource.RLIM_INFINITY`` value. It may be
            a negative sentinel (e.g. ``-1`` on Linux), so it is compared by
            equality rather than by ordering. ``None`` if unknown.

    Returns:
        ``minimum`` if the soft limit should be raised, otherwise ``None``.
    """
    if soft == unlimited or soft >= minimum:
        # Already large enough -- never lower it.
        return None
    if hard != unlimited and hard < minimum:
        # The hard limit does not allow reaching ``minimum``.
        return None
    return minimum


if resource is not None:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    _new_soft = _raised_nofile_soft_limit(
        soft, hard, DEFAULT_SOFT, resource.RLIM_INFINITY
    )
    if _new_soft is not None:
        soft = _new_soft
        # Keep the existing hard limit exactly as reported.
        resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))
|
|
|
|
|
|
# Enforce selector event loop on Windows.
# NOTE(review): presumably because the default Windows (proactor) event loop
# lacks the add_reader()/add_writer() support that the ZMQ streams rely on --
# confirm.
if os.name == "nt":
    _selector_policy = asyncio.WindowsSelectorEventLoopPolicy()  # type:ignore
    asyncio.set_event_loop_policy(_selector_policy)
|
|
|
|
|
|
class KernelMixin:
    """Test-harness mixin giving a kernel in-process ZMQ sockets and streams.

    Messages sent on the shell/control streams are captured by ``_on_send``
    so tests can await and deserialize the kernel's reply.
    """

    log = logging.getLogger()

    # Last multipart message captured by ``_on_send``; ``None`` until a reply
    # arrives. A class-level default so ``_wait_for_msg`` cannot hit an
    # AttributeError before ``_prep_msg`` has run.
    _reply = None

    def _initialize(self):
        """Create the ZMQ context, sockets, streams and session used by tests.

        Every socket created here is recorded in ``test_sockets`` (and every
        stream in ``test_streams``) so ``destroy`` closes it explicitly.
        """
        self.context = context = zmq.Context()
        self.iopub_socket = context.socket(zmq.PUB)
        self.stdin_socket = context.socket(zmq.ROUTER)
        self.session = Session()
        # stdin_socket is tracked too, rather than relying on context.destroy()
        # to close it implicitly.
        self.test_sockets = [self.iopub_socket, self.stdin_socket]
        self.test_streams = []

        for name in ("shell", "control"):
            socket = context.socket(zmq.ROUTER)
            stream = ZMQStream(socket)
            # Capture whatever the kernel sends on this stream.
            stream.on_send(self._on_send)
            self.test_sockets.append(socket)
            self.test_streams.append(stream)
            setattr(self, f"{name}_stream", stream)

    async def do_debug_request(self, msg):
        """Stub debug handler: always returns an empty dict."""
        return {}

    def destroy(self):
        """Close all test streams and sockets, then destroy the ZMQ context."""
        for stream in self.test_streams:
            stream.close()
        for socket in self.test_sockets:
            socket.close()
        self.context.destroy()

    @no_type_check
    async def test_shell_message(self, *args, **kwargs):
        """Send ``Session.msg(*args, **kwargs)`` through ``dispatch_shell``.

        Returns the deserialized message captured from ``shell_stream``.
        """
        return await self._send_test_message(
            self.dispatch_shell, "shell", *args, **kwargs
        )

    @no_type_check
    async def test_control_message(self, *args, **kwargs):
        """Send ``Session.msg(*args, **kwargs)`` through ``process_control``.

        Returns the deserialized message captured from ``control_stream``.
        """
        return await self._send_test_message(
            self.process_control, "control", *args, **kwargs
        )

    async def _send_test_message(self, handler, stream_name, *args, **kwargs):
        """Build a message, await ``handler`` on it, then await the reply.

        ``stream_name`` selects the ``<name>_stream`` attribute that is flushed
        after the handler returns (presumably so any queued send, and thus the
        ``_on_send`` callback, is processed before waiting).
        """
        msg_list = self._prep_msg(*args, **kwargs)
        await handler(msg_list)
        getattr(self, f"{stream_name}_stream").flush()
        return await self._wait_for_msg()

    def _on_send(self, msg, *args, **kwargs):
        """``ZMQStream.on_send`` callback: remember the last message sent."""
        self._reply = msg

    def _prep_msg(self, *args, **kwargs):
        """Clear the captured reply and serialize a new message to ZMQ frames."""
        self._reply = None
        raw_msg = self.session.msg(*args, **kwargs)
        msg = self.session.serialize(raw_msg)
        return [zmq.Message(m) for m in msg]

    async def _wait_for_msg(self):
        """Poll until ``_on_send`` captured a message, then deserialize it.

        NOTE(review): there is no timeout; this waits forever if nothing is
        ever sent.
        """
        while not self._reply:
            await asyncio.sleep(0.1)
        _, msg = self.session.feed_identities(self._reply)
        return self.session.deserialize(msg)

    def _send_interupt_children(self):
        # Override to prevent deadlock. The spelling presumably matches the
        # base-class hook being overridden -- do not "fix" it here.
        pass
|
|
|
|
|
|
class MockKernel(KernelMixin, Kernel):  # type:ignore
    """Minimal kernel on top of ``Kernel`` using the test ``KernelMixin``.

    Executing code just echoes it to iopub as a stdout ``stream`` message.
    """

    # Kernel identification metadata.
    implementation = "test"
    implementation_version = "1.0"
    language = "no-op"
    language_version = "0.1"
    language_info = {
        "name": "test",
        "mimetype": "text/plain",
        "file_extension": ".txt",
    }
    banner = "test kernel"

    def __init__(self, *args, **kwargs):
        # Test sockets/streams and a MagicMock shell are put in place before
        # the base initializer runs.
        self._initialize()
        self.shell = MagicMock()
        super().__init__(*args, **kwargs)

    def do_execute(
        self,
        code,
        silent,
        store_history=True,
        user_expressions=None,
        allow_stdin=False,
    ):
        """Echo ``code`` on iopub (unless ``silent``) and report success."""
        if not silent:
            self.send_response(
                self.iopub_socket, "stream", {"name": "stdout", "text": code}
            )

        reply = {
            "status": "ok",
            # The base class increments the execution count
            "execution_count": self.execution_count,
            "payload": [],
            "user_expressions": {},
        }
        return reply
|
|
|
|
|
|
class MockIPyKernel(KernelMixin, IPythonKernel):  # type:ignore
    """``IPythonKernel`` wired to the in-process test sockets of ``KernelMixin``."""

    def __init__(self, *args, **kwargs):
        # Test sockets/streams are created before the IPythonKernel initializer.
        self._initialize()
        super(MockIPyKernel, self).__init__(*args, **kwargs)
|
|
|
|
|
|
@pytest.fixture
async def kernel():
    """Yield a ``MockKernel`` bound to the current IOLoop; destroy it after."""
    mock_kernel = MockKernel()
    mock_kernel.io_loop = IOLoop.current()
    yield mock_kernel
    mock_kernel.destroy()
|
|
|
|
|
|
@pytest.fixture
async def ipkernel():
    """Yield a ``MockIPyKernel`` bound to the current IOLoop.

    Teardown destroys the kernel's streams, sockets and context, and always
    calls ``ZMQInteractiveShell.clear_instance()``. This happens even if setup
    after construction or ``destroy()`` raises, so the shell instance does not
    leak into later tests.
    """
    ipy_kernel = MockIPyKernel()
    try:
        ipy_kernel.io_loop = IOLoop.current()
        yield ipy_kernel
    finally:
        try:
            ipy_kernel.destroy()
        finally:
            ZMQInteractiveShell.clear_instance()
|