You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
401 lines
12 KiB
Python
401 lines
12 KiB
Python
# This should eventually be cleaned up and become public, but for right now I'm just
# implementing enough to test DTLS.

# TODO:
# - user-defined routers
# - TCP
# - UDP broadcast
|
|
|
|
import enum
import errno
import ipaddress
import os
from collections import deque
from contextlib import contextmanager
from typing import Dict, List, Optional, Union

import attr
import trio
from trio._util import Final, NoPublicConstructor
|
|
|
|
# Type alias for either kind of address object returned by ipaddress.ip_address().
IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
|
|
|
|
|
def _family_for(ip: IPAddress) -> int:
    """Return the trio.socket address family (AF_INET / AF_INET6) for *ip*."""
    if isinstance(ip, ipaddress.IPv6Address):
        return trio.socket.AF_INET6
    if isinstance(ip, ipaddress.IPv4Address):
        return trio.socket.AF_INET
    assert False  # pragma: no cover
|
|
|
|
|
|
def _wildcard_ip_for(family: int) -> IPAddress:
    """Return the "any address" IP for *family* (0.0.0.0 or ::).

    Only AF_INET and AF_INET6 are supported; anything else trips an assert.
    """
    if family == trio.socket.AF_INET6:
        return ipaddress.ip_address("::")
    if family == trio.socket.AF_INET:
        return ipaddress.ip_address("0.0.0.0")
    assert False
|
|
|
|
|
|
def _localhost_ip_for(family: int) -> IPAddress:
    """Return the loopback IP for *family* (127.0.0.1 or ::1).

    Only AF_INET and AF_INET6 are supported; anything else trips an assert.
    """
    if family == trio.socket.AF_INET6:
        return ipaddress.ip_address("::1")
    if family == trio.socket.AF_INET:
        return ipaddress.ip_address("127.0.0.1")
    assert False
|
|
|
|
|
|
def _fake_err(code):
|
|
raise OSError(code, os.strerror(code))
|
|
|
|
|
|
def _scatter(data, buffers):
|
|
written = 0
|
|
for buf in buffers:
|
|
next_piece = data[written : written + len(buf)]
|
|
with memoryview(buf) as mbuf:
|
|
mbuf[: len(next_piece)] = next_piece
|
|
written += len(next_piece)
|
|
if written == len(data):
|
|
break
|
|
return written
|
|
|
|
|
|
@attr.frozen
class UDPEndpoint:
    """One end of a fake UDP flow: an IP address object plus a port."""

    ip: IPAddress
    port: int

    def as_python_sockaddr(self):
        """Return this endpoint as a stdlib-style sockaddr tuple.

        IPv4 gives (host, port); IPv6 gives (host, port, 0, 0), with
        flowinfo and scope_id always zero.
        """
        host = self.ip.compressed
        if isinstance(self.ip, ipaddress.IPv6Address):
            return (host, self.port, 0, 0)
        return (host, self.port)

    @classmethod
    def from_python_sockaddr(cls, sockaddr):
        """Build an endpoint from a sockaddr tuple; items past the port are ignored."""
        host, port = sockaddr[:2]
        address = ipaddress.ip_address(host)
        return cls(ip=address, port=port)
|
|
|
|
|
|
@attr.frozen
class UDPBinding:
    """The local endpoint a FakeSocket is bound to.

    Instances are hashable and serve as the keys of FakeNet._bound.
    """

    local: UDPEndpoint = attr.ib()
|
|
|
|
|
|
@attr.frozen
class UDPPacket:
    """A datagram in flight on a FakeNet (payload shown as hex in repr)."""

    source: UDPEndpoint
    destination: UDPEndpoint
    payload: bytes = attr.ib(repr=lambda p: p.hex())

    def reply(self, payload):
        """Return a new packet carrying *payload* with source and destination swapped."""
        back_to = self.source
        sent_from = self.destination
        return UDPPacket(source=sent_from, destination=back_to, payload=payload)
|
|
|
|
|
|
@attr.frozen
class FakeSocketFactory(trio.abc.SocketFactory):
    """Socket factory installed by FakeNet.enable().

    It is installed via trio.socket.set_custom_socket_factory, and each
    socket() call returns a FakeSocket attached to ``fake_net``.
    """

    fake_net: "FakeNet"

    def socket(self, family: int, type: int, proto: int) -> "FakeSocket":
        net = self.fake_net
        return FakeSocket._create(net, family, type, proto)
|
|
|
|
|
|
@attr.frozen
class FakeHostnameResolver(trio.abc.HostnameResolver):
    """Hostname resolver installed by FakeNet.enable().

    DNS is not faked yet, so both lookup methods raise NotImplementedError.
    """

    fake_net: "FakeNet"

    async def getaddrinfo(
        self,
        host: str,
        port: Union[int, str],
        family=0,
        type=0,
        proto=0,
        flags=0,
    ):
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")

    async def getnameinfo(self, sockaddr, flags: int):
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")
|
|
|
|
|
|
class FakeNet(metaclass=Final):
    """An in-memory UDP network that FakeSocket instances send packets over.

    Call enable() to make trio.socket create FakeSockets and use the fake
    hostname resolver.
    """

    def __init__(self):
        # When we need to pick an arbitrary unique ip address/port, use these:
        self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts()
        # Separate iterator for IPv6; it must not overwrite the IPv4 one.
        self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts()
        self._auto_port_iter = iter(range(50000, 65535))

        # Maps each bound local endpoint to the socket that owns it.
        self._bound: Dict[UDPBinding, FakeSocket] = {}

        # Optional hook: when set, send_packet() hands every packet to this
        # callable instead of delivering it directly.
        self.route_packet = None

    def _bind(self, binding: UDPBinding, socket: "FakeSocket") -> None:
        """Register *socket* as the owner of *binding*.

        Raises:
            OSError: with errno EADDRINUSE if *binding* is already taken.
        """
        if binding in self._bound:
            _fake_err(errno.EADDRINUSE)
        self._bound[binding] = socket

    def enable(self) -> None:
        """Route trio.socket socket creation and name resolution through this net."""
        trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
        trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))

    def send_packet(self, packet) -> None:
        """Send *packet* through route_packet if set, else deliver it directly."""
        if self.route_packet is None:
            self.deliver_packet(packet)
        else:
            self.route_packet(packet)

    def deliver_packet(self, packet) -> None:
        """Hand *packet* to the socket bound to its destination, if any."""
        target = self._bound.get(UDPBinding(local=packet.destination))
        if target is None:
            # No valid destination, so drop it (as UDP would).
            return
        target._deliver_packet(packet)
|
|
|
|
|
|
class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
    """A UDP socket that lives on a FakeNet instead of the real network.

    Instances are created through FakeSocket._create() (see
    FakeSocketFactory). Only AF_INET / AF_INET6 datagram sockets are
    supported.
    """

    def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int):
        self._fake_net = fake_net

        # Treat 0 as "use the default" for family and type.
        if not family:
            family = trio.socket.AF_INET
        if not type:
            type = trio.socket.SOCK_STREAM

        if family not in (trio.socket.AF_INET, trio.socket.AF_INET6):
            raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}")
        if type != trio.socket.SOCK_DGRAM:
            raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")

        self.family = family
        self.type = type
        self.proto = proto

        self._closed = False

        # Unbounded channel, so FakeNet's send_nowait never blocks on delivery.
        self._packet_sender, self._packet_receiver = trio.open_memory_channel(
            float("inf")
        )

        # This is the source-of-truth for what port etc. this socket is bound to
        self._binding: Optional[UDPBinding] = None

    def _check_closed(self):
        """Raise OSError(EBADF) if this socket has been closed."""
        if self._closed:
            _fake_err(errno.EBADF)

    def close(self):
        """Close the socket, releasing its binding; calling it again is a no-op."""
        if self._closed:
            return
        self._closed = True
        if self._binding is not None:
            del self._fake_net._bound[self._binding]
        self._packet_receiver.close()

    async def _resolve_address_nocp(self, address, *, local):
        """Normalize *address* with trio's resolver helper (no checkpoint)."""
        return await trio._socket._resolve_address_nocp(
            self.type,
            self.family,
            self.proto,
            address=address,
            ipv6_v6only=False,
            local=local,
        )

    def _deliver_packet(self, packet: UDPPacket):
        """Queue an incoming *packet* for recvmsg_into()."""
        try:
            self._packet_sender.send_nowait(packet)
        except trio.BrokenResourceError:
            # sending to a closed socket -- UDP packets get dropped
            pass

    ################################################################
    # Actual IO operation implementations
    ################################################################

    async def bind(self, addr):
        """Bind to *addr*.

        Wildcard IPs are mapped to loopback, and port 0 picks an automatic
        port.

        Raises:
            OSError: EBADF if closed, EINVAL if already bound, EADDRINUSE
                if the address is taken.
        """
        self._check_closed()
        if self._binding is not None:
            _fake_err(errno.EINVAL)
        await trio.lowlevel.checkpoint()
        # AF_INET6 addresses resolve to (host, port, flowinfo, scope_id),
        # so only the first two fields are unpacked.
        ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True)
        ip = ipaddress.ip_address(ip_str)
        assert _family_for(ip) == self.family
        # We convert binds to INET_ANY into binds to localhost
        if ip == _wildcard_ip_for(self.family):
            ip = _localhost_ip_for(self.family)
        if port == 0:
            port = next(self._fake_net._auto_port_iter)
        binding = UDPBinding(local=UDPEndpoint(ip, port))
        self._fake_net._bind(binding, self)
        self._binding = binding

    async def connect(self, peer):
        raise NotImplementedError("FakeNet does not (yet) support connected sockets")

    async def sendmsg(self, *args):
        """sendmsg(buffers[, ancdata[, flags]], address): send one datagram.

        The socket is auto-bound to the wildcard address if it is not bound
        yet. Returns the payload length.
        """
        self._check_closed()
        ancdata = []
        flags = 0
        address = None
        if len(args) == 1:
            (buffers,) = args
        elif len(args) == 2:
            buffers, address = args
        elif len(args) == 3:
            buffers, flags, address = args
        elif len(args) == 4:
            buffers, ancdata, flags, address = args
        else:
            raise TypeError("wrong number of arguments")

        await trio.lowlevel.checkpoint()

        if address is not None:
            address = await self._resolve_address_nocp(address, local=False)
        if ancdata:
            raise NotImplementedError("FakeNet doesn't support ancillary data")
        if flags:
            raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}")

        if address is None:
            _fake_err(errno.ENOTCONN)

        destination = UDPEndpoint.from_python_sockaddr(address)

        if self._binding is None:
            await self.bind((_wildcard_ip_for(self.family).compressed, 0))

        payload = b"".join(buffers)

        packet = UDPPacket(
            source=self._binding.local,
            destination=destination,
            payload=payload,
        )

        self._fake_net.send_packet(packet)

        return len(payload)

    async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
        """Receive one datagram into *buffers*.

        MSG_TRUNC is set in msg_flags if the payload did not fit.
        Returns (nbytes, ancdata, msg_flags, address).
        """
        if ancbufsize != 0:
            raise NotImplementedError("FakeNet doesn't support ancillary data")
        if flags != 0:
            raise NotImplementedError("FakeNet doesn't support any recv flags")

        self._check_closed()

        ancdata = []
        msg_flags = 0

        packet = await self._packet_receiver.receive()
        address = packet.source.as_python_sockaddr()
        written = _scatter(packet.payload, buffers)
        if written < len(packet.payload):
            msg_flags |= trio.socket.MSG_TRUNC
        return written, ancdata, msg_flags, address

    ################################################################
    # Simple state query stuff
    ################################################################

    def getsockname(self):
        self._check_closed()
        if self._binding is not None:
            return self._binding.local.as_python_sockaddr()
        elif self.family == trio.socket.AF_INET:
            return ("0.0.0.0", 0)
        else:
            assert self.family == trio.socket.AF_INET6
            return ("::", 0)

    def getpeername(self):
        """Return the connected peer's address; raises OSError(ENOTCONN) if none.

        UDPBinding does not necessarily record a remote endpoint (connected
        sockets aren't supported), so the attribute is looked up defensively
        instead of raising AttributeError.
        """
        self._check_closed()
        remote = getattr(self._binding, "remote", None)
        if remote is not None:
            return remote.as_python_sockaddr()
        _fake_err(errno.ENOTCONN)

    def getsockopt(self, level, item):
        self._check_closed()
        raise OSError(f"FakeNet doesn't implement getsockopt({level}, {item})")

    def setsockopt(self, level, item, value):
        """Accept only IPV6_V6ONLY=True (a no-op); reject every other option."""
        self._check_closed()

        if (level, item) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY):
            if not value:
                raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")
            # Requesting the only supported value succeeds without effect.
            return

        raise OSError(f"FakeNet doesn't implement setsockopt({level}, {item}, ...)")

    ################################################################
    # Various boilerplate and trivial stubs
    ################################################################

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()

    async def send(self, data, flags=0):
        return await self.sendto(data, flags, None)

    async def sendto(self, *args):
        """sendto(data[, flags], address), implemented via sendmsg()."""
        if len(args) == 2:
            data, address = args
            flags = 0
        elif len(args) == 3:
            data, flags, address = args
        else:
            raise TypeError("wrong number of arguments")
        return await self.sendmsg([data], [], flags, address)

    async def recv(self, bufsize, flags=0):
        data, _ = await self.recvfrom(bufsize, flags)
        return data

    async def recv_into(self, buf, nbytes=0, flags=0):
        got_bytes, _ = await self.recvfrom_into(buf, nbytes, flags)
        return got_bytes

    async def recvfrom(self, bufsize, flags=0):
        data, _, _, address = await self.recvmsg(bufsize, flags)
        return data, address

    async def recvfrom_into(self, buf, nbytes=0, flags=0):
        if nbytes != 0 and nbytes != len(buf):
            raise NotImplementedError("partial recvfrom_into")
        got_nbytes, _, _, address = await self.recvmsg_into([buf], 0, flags)
        return got_nbytes, address

    async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
        buf = bytearray(bufsize)
        got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into(
            [buf], ancbufsize, flags
        )
        return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address)

    def fileno(self):
        raise NotImplementedError("can't get fileno() for FakeNet sockets")

    def detach(self):
        raise NotImplementedError("can't detach() a FakeNet socket")

    def get_inheritable(self):
        return False

    def set_inheritable(self, inheritable):
        if inheritable:
            raise NotImplementedError("FakeNet can't make inheritable sockets")

    def share(self, process_id):
        raise NotImplementedError("FakeNet can't share sockets")
|