You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

342 lines
12 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# coding: utf-8
# Little utilities we use internally
from abc import ABCMeta
import os
import signal
import sys
import pathlib
from functools import wraps, update_wrapper
import typing as t
import threading
import collections
from async_generator import isasyncgen
import trio
# Equivalent to the C function raise(), which Python doesn't wrap.
#
# ``signal_raise(signum)`` is bound once, at import time, to a
# platform-specific implementation selected by ``os.name``.
if os.name == "nt":
    # On Windows, os.kill exists but is really weird.
    #
    # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver
    # those using GenerateConsoleCtrlEvent. But I found that when I tried
    # to run my test normally, it would freeze waiting... unless I added
    # print statements, in which case the test suddenly worked. So I guess
    # these signals are only delivered if/when you access the console? I
    # don't really know what was going on there. From reading the
    # GenerateConsoleCtrlEvent docs I don't know how it worked at all.
    #
    # I later spent a bunch of time trying to make GenerateConsoleCtrlEvent
    # work for creating synthetic control-C events, and... failed
    # utterly. There are lots of details in the code and comments
    # removed/added at this commit:
    #     https://github.com/python-trio/trio/commit/95843654173e3e826c34d70a90b369ba6edf2c23
    #
    # OTOH, if you pass os.kill any *other* signal number... then CPython
    # just calls TerminateProcess (wtf).
    #
    # So, anyway, os.kill is not so useful for testing purposes. Instead
    # we use raise():
    #
    #     https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx
    #
    # Have to import cffi inside the 'if os.name' block because we don't
    # depend on cffi on non-Windows platforms. (It would be easy to switch
    # this to ctypes though if we ever remove the cffi dependency.)
    #
    # Some more information:
    #     https://bugs.python.org/issue26350
    #
    # Anyway, we use this for two things:
    # - redelivering unhandled signals
    # - generating synthetic signals for tests
    # and for both of those purposes, 'raise' works fine.
    import cffi

    _ffi = cffi.FFI()
    # Declare the C prototype so cffi knows the call signature of raise().
    _ffi.cdef("int raise(int);")
    # Load the CRT runtime DLL from which raise() is looked up below.
    _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll")
    # "raise" is a Python keyword, so `_lib.raise` would be a syntax error;
    # fetch the symbol by name instead. Calling it returns raise()'s int
    # result (per the cdef above).
    signal_raise = getattr(_lib, "raise")
else:

    def signal_raise(signum):
        """Deliver *signum* to the calling thread (non-Windows platforms).

        Implemented with ``signal.pthread_kill`` aimed at
        ``threading.get_ident()``, i.e. the thread that called this function.
        """
        signal.pthread_kill(threading.get_ident(), signum)
# See: #461 as to why this is needed.
# The gist is that threading.main_thread() can lie to us if somebody else
# edits the threading ident cache to replace the main thread. That causes
# threading.current_thread() to return a _DummyThread, the C-c check to
# fail, and so on.
# Installing a signal handler outside the main thread fails, so we can use
# that to reliably check whether this is the main thread without relying on
# a potentially modified threading module.
def is_main_thread():
    """Attempt to reliably check if we are in the main thread.

    The probe re-installs the current SIGINT handler. ``signal.signal``
    raises ``ValueError`` when called from any thread other than the main
    thread of the main interpreter.

    Returns:
        True if the probe succeeded. False if it raised ``ValueError`` (not
        the main thread) or ``TypeError``. The ``TypeError`` case happens when
        ``signal.getsignal`` returns None, meaning the SIGINT handler was not
        installed from Python. In that case the probe cannot be completed
        without replacing that handler, so we conservatively report False
        instead of letting the exception escape to the caller.
    """
    try:
        signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT))
    except (TypeError, ValueError):
        return False
    return True
######
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes. Returns coroutine object.
######
def coroutine_or_error(async_fn, *args):
    """Call ``async_fn(*args)`` and return the coroutine object it produces.

    Common mistakes are turned into descriptive ``TypeError`` messages:

    * passing a coroutine object (``f()``) instead of the function (``f``);
    * passing an asyncio/tornado/twisted Future-like object, or calling
      something that returns one;
    * passing an async generator function;
    * passing a synchronous function.

    Any other ``TypeError`` raised by the call itself is re-raised unchanged.

    Args:
        async_fn: the callable to invoke.
        *args: positional arguments forwarded to *async_fn*.

    Returns:
        The coroutine object returned by ``async_fn(*args)``.

    Raises:
        TypeError: for each of the mistakes listed above.
    """

    def _return_value_looks_like_wrong_library(value):
        # Returned by legacy @asyncio.coroutine functions, which includes
        # a surprising proportion of asyncio builtins.
        if isinstance(value, collections.abc.Generator):
            return True
        # The protocol for detecting an asyncio Future-like object
        if getattr(value, "_asyncio_future_blocking", None) is not None:
            return True
        # This janky check catches tornado Futures and twisted Deferreds.
        # By the time we're calling this function, we already know
        # something has gone wrong, so a heuristic is pretty safe.
        if value.__class__.__name__ in ("Future", "Deferred"):
            return True
        return False

    try:
        coro = async_fn(*args)

    except TypeError:
        # Give good error for: nursery.start_soon(trio.sleep(1))
        if isinstance(async_fn, collections.abc.Coroutine):
            # explicitly close coroutine to avoid RuntimeWarning
            async_fn.close()

            raise TypeError(
                "Trio was expecting an async function, but instead it got "
                "a coroutine object {async_fn!r}\n"
                "\n"
                "Probably you did something like:\n"
                "\n"
                " trio.run({async_fn.__name__}(...)) # incorrect!\n"
                " nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n"
                "\n"
                "Instead, you want (notice the parentheses!):\n"
                "\n"
                " trio.run({async_fn.__name__}, ...) # correct!\n"
                " nursery.start_soon({async_fn.__name__}, ...) # correct!".format(
                    async_fn=async_fn
                )
            ) from None

        # Give good error for: nursery.start_soon(future)
        if _return_value_looks_like_wrong_library(async_fn):
            # The " - " separates the repr from the question; without it the
            # two sentences run together.
            raise TypeError(
                "Trio was expecting an async function, but instead it got "
                "{!r} - are you trying to use a library written for "
                "asyncio/twisted/tornado or similar? That won't work "
                "without some sort of compatibility shim.".format(async_fn)
            ) from None

        raise

    # We can't check iscoroutinefunction(async_fn), because that will fail
    # for things like functools.partial objects wrapping an async
    # function. So we have to just call it and then check whether the
    # return value is a coroutine object.
    if not isinstance(coro, collections.abc.Coroutine):
        # Give good error for: nursery.start_soon(func_returning_future)
        if _return_value_looks_like_wrong_library(coro):
            raise TypeError(
                "Trio got unexpected {!r} - are you trying to use a "
                "library written for asyncio/twisted/tornado or similar? "
                "That won't work without some sort of compatibility shim.".format(coro)
            )

        if isasyncgen(coro):
            raise TypeError(
                "start_soon expected an async function but got an async "
                "generator {!r}".format(coro)
            )

        # Give good error for: nursery.start_soon(some_sync_fn)
        raise TypeError(
            "Trio expected an async function, but {!r} appears to be "
            "synchronous".format(getattr(async_fn, "__qualname__", async_fn))
        )

    return coro
class ConflictDetector:
    """Guard against two tasks running conflicting operations at once.

    Use it as a synchronous context manager. Entering it while it is already
    held raises ``trio.BusyResourceError`` with the message given at
    construction. Leaving it, normally or via an exception, releases the
    hold. It is meant for pairs of code paths that *would* need a lock if
    they ever overlapped but should never overlap in practice. One example
    is making sure two different tasks don't call ``sendall`` simultaneously
    on the same stream.
    """

    def __init__(self, msg):
        self._msg = msg
        self._held = False

    def __enter__(self):
        if not self._held:
            self._held = True
            return
        raise trio.BusyResourceError(self._msg)

    def __exit__(self, *args):
        self._held = False
def async_wraps(cls, wrapped_cls, attr_name):
    """Similar to wraps, but for async wrappers of non-async functions.

    The returned decorator mutates the decorated function in place and
    returns it:

    * ``__name__`` becomes *attr_name*;
    * ``__qualname__`` becomes ``"<cls.__qualname__>.<attr_name>"``;
    * ``__doc__`` becomes a one-line Sphinx cross-reference to
      ``wrapped_cls.<attr_name>``.
    """

    def decorator(func):
        func.__name__ = attr_name
        func.__qualname__ = ".".join([cls.__qualname__, attr_name])
        target = f"{wrapped_cls.__module__}.{wrapped_cls.__qualname__}.{attr_name}"
        func.__doc__ = f"Like :meth:`~{target}`, but async.\n"
        return func

    return decorator
def fixup_module_metadata(module_name, namespace):
    """Make public objects report *module_name* as their home module.

    For every name in *namespace* that does not start with ``_``, if the
    object's ``__module__`` starts with ``"trio."``:

    * ``__module__`` is set to *module_name*;
    * if the object has a ``__name__`` without a dot, ``__name__`` and
      ``__qualname__`` are reset to the name it is exported under (dotted
      names are left alone, since modules store fully-qualified names there);
    * if the object is a class, the same treatment is applied recursively to
      everything in its ``__dict__``. Nested qualnames are built from the
      parent's qualname, e.g. ``Outer.Inner.method``.

    Each object is visited at most once (tracked by ``id``).

    Args:
        module_name: the module name to assign to matching objects.
        namespace: mapping of exported name -> object to process.
    """
    seen_ids = set()

    def fix_one(qualname, name, obj):
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is not None and mod.startswith("trio."):
            obj.__module__ = module_name
            # Modules, unlike everything else in Python, put fully-qualified
            # names into their __name__ attribute. We check for "." to avoid
            # rewriting these.
            if hasattr(obj, "__name__") and "." not in obj.__name__:
                obj.__name__ = name
                obj.__qualname__ = qualname
            if isinstance(obj, type):
                for attr_name, attr_value in obj.__dict__.items():
                    # Extend *this* object's qualname (not the top-level
                    # export name) so that members of nested classes get the
                    # full dotted path.
                    fix_one(qualname + "." + attr_name, attr_name, attr_value)

    for objname, obj in namespace.items():
        if not objname.startswith("_"):  # ignore private attributes
            fix_one(objname, objname, obj)
class generic_function:
    """Make a function subscriptable so generic type parameters can be spelled.

    Static type checkers cannot always infer the generic parameters of a
    function's return type. Decorating the function with this class makes
    ``fn[SomeType](...)`` legal at runtime. For example, after::

        @generic_function
        def open_memory_channel(max_buffer_size: int) -> Tuple[
            SendChannel[T], ReceiveChannel[T]
        ]: ...

    you can write ``open_memory_channel[bytes](5)``. At runtime that is
    exactly ``open_memory_channel(5)``: subscripting returns the wrapper
    itself and ignores the index. Type-checking it still needs a mypy
    plugin or clever stubs, but at least those become possible to write.
    """

    def __init__(self, fn):
        # Copy __name__, __doc__, __wrapped__, etc. from fn onto the wrapper.
        update_wrapper(self, fn)
        self._fn = fn

    def __getitem__(self, _type_params):
        # The subscript exists purely for type checkers; ignore it.
        return self

    def __call__(self, *args, **kwargs):
        target = self._fn
        return target(*args, **kwargs)
class Final(ABCMeta):
    """Metaclass that makes a class final, i.e. forbids subclassing it.

    Usage::

        class SomeClass(metaclass=Final):
            pass

    Any later class that lists ``SomeClass`` among its bases is rejected.

    Raises
    ------
    - TypeError if a sub class is created
    """

    def __new__(cls, name, bases, cls_namespace):
        # The first base that was itself created by this metaclass, if any.
        final_base = next((base for base in bases if isinstance(base, Final)), None)
        if final_base is not None:
            raise TypeError(
                f"{final_base.__module__}.{final_base.__qualname__} does not support subclassing"
            )
        return super().__new__(cls, name, bases, cls_namespace)
T = t.TypeVar("T")


class NoPublicConstructor(Final):
    """Metaclass for final classes that cannot be instantiated directly.

    Usage::

        class SomeClass(metaclass=NoPublicConstructor):
            pass

    As with ``Final``, subclassing ``SomeClass`` is rejected. In addition,
    calling ``SomeClass()`` raises ``TypeError``. Instances can only be
    built through the private ``SomeClass._create(...)`` hook.

    Raises
    ------
    - TypeError if a sub class or an instance is created.
    """

    def __call__(cls, *_args, **_kwargs):
        # Every public construction attempt is refused, whatever the arguments.
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor"
        )

    def _create(cls: t.Type[T], *ctor_args: t.Any, **ctor_kwargs: t.Any) -> T:
        # Skip the refusing __call__ above and run the ordinary class-call
        # machinery inherited from the parent metaclass.
        return super().__call__(*ctor_args, **ctor_kwargs)  # type: ignore
def name_asyncgen(agen):
    """Return the fully-qualified name of the async generator function
    that produced the async generator iterator *agen*.

    The result has the form ``"<module>.<qualname>"``:

    * module is ``agen.ag_frame.f_globals["__name__"]``. If that lookup fails
      with AttributeError or KeyError (for example, ``ag_frame`` is None),
      ``"<co_filename>"`` is used instead;
    * qualname is ``agen.__qualname__``, falling back to
      ``agen.ag_code.co_name``.

    Objects without an ``ag_code`` attribute are returned as ``repr(agen)``.
    """
    if not hasattr(agen, "ag_code"):  # pragma: no cover
        return repr(agen)

    try:
        module_name = agen.ag_frame.f_globals["__name__"]
    except (AttributeError, KeyError):
        module_name = f"<{agen.ag_code.co_filename}>"

    try:
        func_qualname = agen.__qualname__
    except AttributeError:
        func_qualname = agen.ag_code.co_name

    return "{}.{}".format(module_name, func_qualname)