You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

267 lines
8.3 KiB
Python

2 years ago
"""Base classes to manage a Client's interaction with a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import atexit
import time
import typing as t
from queue import Empty
from threading import Event
from threading import Thread
import zmq.asyncio
from .channelsabc import HBChannelABC
from .session import Session
from jupyter_client import protocol_version_info
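# NOTE: the comment below described a top-level ``ZMQError`` import that is not
# present in the import block above:
#   "import ZMQError in top-level namespace, to avoid ugly attribute-error
#    messages during garbage collection of threads at exit"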
# -----------------------------------------------------------------------------
# Constants and exceptions
# -----------------------------------------------------------------------------
# First element of ``protocol_version_info`` (the major version component,
# going by the name).
major_protocol_version = protocol_version_info[0]
class InvalidPortNumber(Exception):
    """Raised when a channel is given an address whose port number is 0."""
class HBChannel(Thread):
    """The heartbeat channel which monitors the kernel heartbeat.

    A daemon thread that repeatedly sends ``b"ping"`` on a REQ socket connected
    to ``address`` and waits up to ``time_to_dead`` seconds for a reply.  When
    no reply arrives in time, :meth:`call_handlers` is invoked with the elapsed
    time and the socket is recreated.

    Pausing: ``__init__`` sets ``_pause`` to ``False``, so a freshly created
    channel is *not* paused; use :meth:`pause` / :meth:`unpause` to control it.
    NOTE(review): an earlier version of this docstring stated that the channel
    is paused by default and that the kernel manager pauses/un-pauses it as
    appropriate -- confirm against the callers.
    """

    session = None
    socket = None
    address = None
    # set to True by the atexit hook below once the interpreter is shutting down
    _exiting = False

    # seconds to wait for a heartbeat reply (also the sleep interval while paused)
    time_to_dead: float = 1.0

    _running = None
    _pause = None
    _beating = None

    def __init__(
        self,
        context: t.Optional[zmq.asyncio.Context] = None,
        session: t.Optional[Session] = None,
        address: t.Union[t.Tuple[str, int], str] = "",
    ) -> None:
        """Create the heartbeat monitor thread.

        Parameters
        ----------
        context : :class:`zmq.asyncio.Context`
            The ZMQ context to use.
        session : :class:`session.Session`
            The session to use.
        address : zmq url
            Either a ready-made zmq url string, or a standard (ip, port) tuple
            that the kernel is listening on (converted to ``tcp://ip:port``).

        Raises
        ------
        InvalidPortNumber
            If ``address`` is a tuple whose port is 0.
        """
        super().__init__()
        self.daemon = True

        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                message = "The port number for a channel cannot be 0."
                raise InvalidPortNumber(message)
            address_str = "tcp://%s:%i" % address
        else:
            address_str = address
        self.address = address_str

        # running is False until `.start()` is called
        self._running = False
        # set by stop() to wake the thread out of its timed waits
        self._exit = Event()
        # don't start paused
        self._pause = False
        self.poller = zmq.Poller()

    @staticmethod
    @atexit.register
    def _notice_exit() -> None:
        """atexit hook: record on the class that the interpreter is exiting."""
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if HBChannel is not None:
            HBChannel._exiting = True

    def _create_socket(self) -> None:
        """(Re)create the REQ socket, connect it and register it with the poller."""
        if self.socket is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(self.socket)
            self.socket.close()
        assert self.context is not None
        self.socket = self.context.socket(zmq.REQ)
        self.socket.linger = 1000
        assert self.address is not None
        self.socket.connect(self.address)

        self.poller.register(self.socket, zmq.POLLIN)

    def run(self) -> None:
        """Thread entry point: run :meth:`_async_run` on a private event loop.

        The loop is always closed, even if ``_async_run`` raises; the
        exception still propagates.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_run())
        finally:
            loop.close()

    async def _async_run(self) -> None:
        """The thread's main activity.  Call start() instead."""
        self._create_socket()
        self._running = True
        self._beating = True
        assert self.socket is not None

        while self._running:
            if self._pause:
                # just sleep, and skip the rest of the loop
                # (blocking wait; the event loop created in run() runs only
                # this coroutine)
                self._exit.wait(self.time_to_dead)
                continue

            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM
            await self.socket.send(b"ping")
            request_time = time.time()
            # Wait until timeout (returns early once stop() sets the event)
            self._exit.wait(self.time_to_dead)
            # poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
            self._beating = bool(self.poller.poll(0))
            if self._beating:
                # the poll above guarantees we have something to recv
                await self.socket.recv()
            elif self._running:
                # nothing was received within the time limit, signal heart failure
                since_last_heartbeat = time.time() - request_time
                self.call_handlers(since_last_heartbeat)
                # and close/reopen the socket, because the REQ/REP cycle has been broken
                self._create_socket()

    def pause(self) -> None:
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self) -> None:
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self) -> bool:
        """Is the heartbeat running and responsive (and not paused)."""
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self) -> None:
        """Stop the channel's event loop, join its thread and close the socket.

        Safe to call on a channel whose thread was never started: the join is
        skipped in that case (``Thread.join`` would raise ``RuntimeError``),
        and the socket is still closed.
        """
        self._running = False
        self._exit.set()
        if self.is_alive():
            self.join()
        self.close()

    def close(self) -> None:
        """Close the socket (best effort, no linger) and drop the reference."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # deliberate best-effort cleanup: errors while closing are ignored
                pass
            self.socket = None

    def call_handlers(self, since_last_heartbeat: float) -> None:
        """Hook called from the heartbeat thread when a heartbeat is missed.

        Invoked by the run loop when no reply to a ping was available after
        waiting ``time_to_dead`` seconds.  Subclasses should override this
        method to react to the failure.

        It is important to remember that this method is called in the
        heartbeat thread, so some logic must be done to ensure that
        application level handlers are called in the application thread.

        Parameters
        ----------
        since_last_heartbeat : float
            Seconds elapsed since the unanswered ping was sent.
        """
        pass
# Register HBChannel with the HBChannelABC interface imported from .channelsabc
# (presumably an abc-based class, making HBChannel a virtual subclass -- confirm there).
HBChannelABC.register(HBChannel)
class ZMQSocketChannel:
    """A ZMQ socket in an async API"""

    def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.asyncio.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations
        """
        super().__init__()

        # becomes None once the channel has been closed
        self.socket: t.Optional[zmq.asyncio.Socket] = socket
        self.session = session

    async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Receive one multipart message and deserialize it via the session.

        ``kwargs`` are forwarded to ``socket.recv_multipart``.
        """
        assert self.socket is not None
        msg = await self.socket.recv_multipart(**kwargs)
        # the identities part is not needed here, only the message frames
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    async def get_msg(self, timeout: t.Optional[float] = None) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message.  ``None`` is passed to the socket's
            ``poll`` unchanged (no timeout); ``0`` polls without waiting.

        Raises
        ------
        queue.Empty
            If the poll reports that no message is ready.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else timeout * 1000  # seconds to ms
        ready = await self.socket.poll(timeout_ms)
        if not ready:
            raise Empty
        return await self._recv()

    async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
        """Get all messages that are currently ready."""
        msgs = []
        while True:
            try:
                # timeout=0: only take what is already there.  The default
                # timeout of None would wait for the next message instead of
                # raising Empty, so this loop would never finish.
                msgs.append(await self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    async def msg_ready(self) -> bool:
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(await self.socket.poll(timeout=0))

    def close(self) -> None:
        """Close the socket (best effort, no linger) and drop the reference."""
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                # deliberate best-effort cleanup: errors while closing are ignored
                pass
            self.socket = None

    # stopping a channel is the same as closing it
    stop = close

    def is_alive(self) -> bool:
        """Return True until the channel has been closed."""
        return self.socket is not None

    def send(self, msg: t.Dict[str, t.Any]) -> None:
        """Pass a message to the ZMQ socket to send"""
        assert self.socket is not None
        self.session.send(self.socket, msg)

    def start(self) -> None:
        """No-op."""
        pass