You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
575 lines
18 KiB
Python
575 lines
18 KiB
Python
import pytest
|
|
import sys
|
|
import socket
|
|
|
|
import attr
|
|
|
|
import trio
|
|
from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP
|
|
from trio._highlevel_open_tcp_stream import (
|
|
reorder_for_rfc_6555_section_5_4,
|
|
close_all,
|
|
open_tcp_stream,
|
|
format_host_port,
|
|
)
|
|
|
|
if sys.version_info < (3, 11):
|
|
from exceptiongroup import BaseExceptionGroup
|
|
|
|
|
|
def test_close_all():
|
|
class CloseMe:
|
|
closed = False
|
|
|
|
def close(self):
|
|
self.closed = True
|
|
|
|
class CloseKiller:
|
|
def close(self):
|
|
raise OSError
|
|
|
|
c = CloseMe()
|
|
with close_all() as to_close:
|
|
to_close.add(c)
|
|
assert c.closed
|
|
|
|
c = CloseMe()
|
|
with pytest.raises(RuntimeError):
|
|
with close_all() as to_close:
|
|
to_close.add(c)
|
|
raise RuntimeError
|
|
assert c.closed
|
|
|
|
c = CloseMe()
|
|
with pytest.raises(OSError):
|
|
with close_all() as to_close:
|
|
to_close.add(CloseKiller())
|
|
to_close.add(c)
|
|
assert c.closed
|
|
|
|
|
|
def test_reorder_for_rfc_6555_section_5_4():
|
|
def fake4(i):
|
|
return (
|
|
AF_INET,
|
|
SOCK_STREAM,
|
|
IPPROTO_TCP,
|
|
"",
|
|
("10.0.0.{}".format(i), 80),
|
|
)
|
|
|
|
def fake6(i):
|
|
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", ("::{}".format(i), 80))
|
|
|
|
for fake in fake4, fake6:
|
|
# No effect on homogeneous lists
|
|
targets = [fake(0), fake(1), fake(2)]
|
|
reorder_for_rfc_6555_section_5_4(targets)
|
|
assert targets == [fake(0), fake(1), fake(2)]
|
|
|
|
# Single item lists also OK
|
|
targets = [fake(0)]
|
|
reorder_for_rfc_6555_section_5_4(targets)
|
|
assert targets == [fake(0)]
|
|
|
|
# If the list starts out with different families in positions 0 and 1,
|
|
# then it's left alone
|
|
orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
|
|
targets = list(orig)
|
|
reorder_for_rfc_6555_section_5_4(targets)
|
|
assert targets == orig
|
|
|
|
# If not, it's reordered
|
|
targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
|
|
reorder_for_rfc_6555_section_5_4(targets)
|
|
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
|
|
|
|
|
|
def test_format_host_port():
|
|
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
|
|
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
|
|
assert format_host_port("example.com", 443) == "example.com:443"
|
|
assert format_host_port(b"example.com", 443) == "example.com:443"
|
|
assert format_host_port("::1", "http") == "[::1]:http"
|
|
assert format_host_port(b"::1", "http") == "[::1]:http"
|
|
|
|
|
|
# Make sure we can connect to localhost using real kernel sockets
|
|
async def test_open_tcp_stream_real_socket_smoketest():
|
|
listen_sock = trio.socket.socket()
|
|
await listen_sock.bind(("127.0.0.1", 0))
|
|
_, listen_port = listen_sock.getsockname()
|
|
listen_sock.listen(1)
|
|
client_stream = await open_tcp_stream("127.0.0.1", listen_port)
|
|
server_sock, _ = await listen_sock.accept()
|
|
await client_stream.send_all(b"x")
|
|
assert await server_sock.recv(1) == b"x"
|
|
await client_stream.aclose()
|
|
server_sock.close()
|
|
|
|
listen_sock.close()
|
|
|
|
|
|
async def test_open_tcp_stream_input_validation():
|
|
with pytest.raises(ValueError):
|
|
await open_tcp_stream(None, 80)
|
|
with pytest.raises(TypeError):
|
|
await open_tcp_stream("127.0.0.1", b"80")
|
|
|
|
|
|
def can_bind_127_0_0_2():
|
|
with socket.socket() as s:
|
|
try:
|
|
s.bind(("127.0.0.2", 0))
|
|
except OSError:
|
|
return False
|
|
return s.getsockname()[0] == "127.0.0.2"
|
|
|
|
|
|
async def test_local_address_real():
|
|
with trio.socket.socket() as listener:
|
|
await listener.bind(("127.0.0.1", 0))
|
|
listener.listen()
|
|
|
|
# It's hard to test local_address properly, because you need multiple
|
|
# local addresses that you can bind to. Fortunately, on most Linux
|
|
# systems, you can bind to any 127.*.*.* address, and they all go
|
|
# through the loopback interface. So we can use a non-standard
|
|
# loopback address. On other systems, the only address we know for
|
|
# certain we have is 127.0.0.1, so we can't really test local_address=
|
|
# properly -- passing local_address=127.0.0.1 is indistinguishable
|
|
# from not passing local_address= at all. But, we can still do a smoke
|
|
# test to make sure the local_address= code doesn't crash.
|
|
if can_bind_127_0_0_2():
|
|
local_address = "127.0.0.2"
|
|
else:
|
|
local_address = "127.0.0.1"
|
|
|
|
async with await open_tcp_stream(
|
|
*listener.getsockname(), local_address=local_address
|
|
) as client_stream:
|
|
assert client_stream.socket.getsockname()[0] == local_address
|
|
if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
|
|
assert client_stream.socket.getsockopt(
|
|
trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT
|
|
)
|
|
|
|
server_sock, remote_addr = await listener.accept()
|
|
await client_stream.aclose()
|
|
server_sock.close()
|
|
assert remote_addr[0] == local_address
|
|
|
|
# Trying to connect to an ipv4 address with the ipv6 wildcard
|
|
# local_address should fail
|
|
with pytest.raises(OSError):
|
|
await open_tcp_stream(*listener.getsockname(), local_address="::")
|
|
|
|
# But the ipv4 wildcard address should work
|
|
async with await open_tcp_stream(
|
|
*listener.getsockname(), local_address="0.0.0.0"
|
|
) as client_stream:
|
|
server_sock, remote_addr = await listener.accept()
|
|
server_sock.close()
|
|
assert remote_addr == client_stream.socket.getsockname()
|
|
|
|
|
|
# Now, thorough tests using fake sockets
|
|
|
|
|
|
@attr.s(eq=False)
|
|
class FakeSocket(trio.socket.SocketType):
|
|
scenario = attr.ib()
|
|
family = attr.ib()
|
|
type = attr.ib()
|
|
proto = attr.ib()
|
|
|
|
ip = attr.ib(default=None)
|
|
port = attr.ib(default=None)
|
|
succeeded = attr.ib(default=False)
|
|
closed = attr.ib(default=False)
|
|
failing = attr.ib(default=False)
|
|
|
|
async def connect(self, sockaddr):
|
|
self.ip = sockaddr[0]
|
|
self.port = sockaddr[1]
|
|
assert self.ip not in self.scenario.sockets
|
|
self.scenario.sockets[self.ip] = self
|
|
self.scenario.connect_times[self.ip] = trio.current_time()
|
|
delay, result = self.scenario.ip_dict[self.ip]
|
|
await trio.sleep(delay)
|
|
if result == "error":
|
|
raise OSError("sorry")
|
|
if result == "postconnect_fail":
|
|
self.failing = True
|
|
self.succeeded = True
|
|
|
|
def close(self):
|
|
self.closed = True
|
|
|
|
# called when SocketStream is constructed
|
|
def setsockopt(self, *args, **kwargs):
|
|
if self.failing:
|
|
# raise something that isn't OSError as SocketStream
|
|
# ignores those
|
|
raise KeyboardInterrupt
|
|
|
|
|
|
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
|
|
def __init__(self, port, ip_list, supported_families):
|
|
# ip_list have to be unique
|
|
ip_order = [ip for (ip, _, _) in ip_list]
|
|
assert len(set(ip_order)) == len(ip_list)
|
|
ip_dict = {}
|
|
for ip, delay, result in ip_list:
|
|
assert 0 <= delay
|
|
assert result in ["error", "success", "postconnect_fail"]
|
|
ip_dict[ip] = (delay, result)
|
|
|
|
self.port = port
|
|
self.ip_order = ip_order
|
|
self.ip_dict = ip_dict
|
|
self.supported_families = supported_families
|
|
self.socket_count = 0
|
|
self.sockets = {}
|
|
self.connect_times = {}
|
|
|
|
def socket(self, family, type, proto):
|
|
if family not in self.supported_families:
|
|
raise OSError("pretending not to support this family")
|
|
self.socket_count += 1
|
|
return FakeSocket(self, family, type, proto)
|
|
|
|
def _ip_to_gai_entry(self, ip):
|
|
if ":" in ip:
|
|
family = trio.socket.AF_INET6
|
|
sockaddr = (ip, self.port, 0, 0)
|
|
else:
|
|
family = trio.socket.AF_INET
|
|
sockaddr = (ip, self.port)
|
|
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
|
|
|
|
async def getaddrinfo(self, host, port, family, type, proto, flags):
|
|
assert host == b"test.example.com"
|
|
assert port == self.port
|
|
assert family == trio.socket.AF_UNSPEC
|
|
assert type == trio.socket.SOCK_STREAM
|
|
assert proto == 0
|
|
assert flags == 0
|
|
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
|
|
|
|
async def getnameinfo(self, sockaddr, flags): # pragma: no cover
|
|
raise NotImplementedError
|
|
|
|
def check(self, succeeded):
|
|
# sockets only go into self.sockets when connect is called; make sure
|
|
# all the sockets that were created did in fact go in there.
|
|
assert self.socket_count == len(self.sockets)
|
|
|
|
for ip, socket in self.sockets.items():
|
|
assert ip in self.ip_dict
|
|
if socket is not succeeded:
|
|
assert socket.closed
|
|
assert socket.port == self.port
|
|
|
|
|
|
async def run_scenario(
|
|
# The port to connect to
|
|
port,
|
|
# A list of
|
|
# (ip, delay, result)
|
|
# tuples, where delay is in seconds and result is "success" or "error"
|
|
# The ip's will be returned from getaddrinfo in this order, and then
|
|
# connect() calls to them will have the given result.
|
|
ip_list,
|
|
*,
|
|
# If False, AF_INET4/6 sockets error out on creation, before connect is
|
|
# even called.
|
|
ipv4_supported=True,
|
|
ipv6_supported=True,
|
|
# Normally, we return (winning_sock, scenario object)
|
|
# If this is True, we require there to be an exception, and return
|
|
# (exception, scenario object)
|
|
expect_error=(),
|
|
**kwargs,
|
|
):
|
|
supported_families = set()
|
|
if ipv4_supported:
|
|
supported_families.add(trio.socket.AF_INET)
|
|
if ipv6_supported:
|
|
supported_families.add(trio.socket.AF_INET6)
|
|
scenario = Scenario(port, ip_list, supported_families)
|
|
trio.socket.set_custom_hostname_resolver(scenario)
|
|
trio.socket.set_custom_socket_factory(scenario)
|
|
|
|
try:
|
|
stream = await open_tcp_stream("test.example.com", port, **kwargs)
|
|
assert expect_error == ()
|
|
scenario.check(stream.socket)
|
|
return (stream.socket, scenario)
|
|
except AssertionError: # pragma: no cover
|
|
raise
|
|
except expect_error as exc:
|
|
scenario.check(None)
|
|
return (exc, scenario)
|
|
|
|
|
|
async def test_one_host_quick_success(autojump_clock):
|
|
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
|
|
assert sock.ip == "1.2.3.4"
|
|
assert trio.current_time() == 0.123
|
|
|
|
|
|
async def test_one_host_slow_success(autojump_clock):
|
|
sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
|
|
assert sock.ip == "1.2.3.4"
|
|
assert trio.current_time() == 100
|
|
|
|
|
|
async def test_one_host_quick_fail(autojump_clock):
|
|
exc, scenario = await run_scenario(
|
|
82, [("1.2.3.4", 0.123, "error")], expect_error=OSError
|
|
)
|
|
assert isinstance(exc, OSError)
|
|
assert trio.current_time() == 0.123
|
|
|
|
|
|
async def test_one_host_slow_fail(autojump_clock):
|
|
exc, scenario = await run_scenario(
|
|
83, [("1.2.3.4", 100, "error")], expect_error=OSError
|
|
)
|
|
assert isinstance(exc, OSError)
|
|
assert trio.current_time() == 100
|
|
|
|
|
|
async def test_one_host_failed_after_connect(autojump_clock):
|
|
exc, scenario = await run_scenario(
|
|
83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt
|
|
)
|
|
assert isinstance(exc, KeyboardInterrupt)
|
|
|
|
|
|
# With the default 0.250 second delay, the third attempt will win
|
|
async def test_basic_fallthrough(autojump_clock):
|
|
sock, scenario = await run_scenario(
|
|
80,
|
|
[
|
|
("1.1.1.1", 1, "success"),
|
|
("2.2.2.2", 1, "success"),
|
|
("3.3.3.3", 0.2, "success"),
|
|
],
|
|
)
|
|
assert sock.ip == "3.3.3.3"
|
|
# current time is default time + default time + connection time
|
|
assert trio.current_time() == (0.250 + 0.250 + 0.2)
|
|
assert scenario.connect_times == {
|
|
"1.1.1.1": 0,
|
|
"2.2.2.2": 0.250,
|
|
"3.3.3.3": 0.500,
|
|
}
|
|
|
|
|
|
async def test_early_success(autojump_clock):
|
|
sock, scenario = await run_scenario(
|
|
80,
|
|
[
|
|
("1.1.1.1", 1, "success"),
|
|
("2.2.2.2", 0.1, "success"),
|
|
("3.3.3.3", 0.2, "success"),
|
|
],
|
|
)
|
|
assert sock.ip == "2.2.2.2"
|
|
assert trio.current_time() == (0.250 + 0.1)
|
|
assert scenario.connect_times == {
|
|
"1.1.1.1": 0,
|
|
"2.2.2.2": 0.250,
|
|
# 3.3.3.3 was never even started
|
|
}
|
|
|
|
|
|
# With a 0.450 second delay, the first attempt will win
|
|
async def test_custom_delay(autojump_clock):
|
|
sock, scenario = await run_scenario(
|
|
80,
|
|
[
|
|
("1.1.1.1", 1, "success"),
|
|
("2.2.2.2", 1, "success"),
|
|
("3.3.3.3", 0.2, "success"),
|
|
],
|
|
happy_eyeballs_delay=0.450,
|
|
)
|
|
assert sock.ip == "1.1.1.1"
|
|
assert trio.current_time() == 1
|
|
assert scenario.connect_times == {
|
|
"1.1.1.1": 0,
|
|
"2.2.2.2": 0.450,
|
|
"3.3.3.3": 0.900,
|
|
}
|
|
|
|
|
|
async def test_custom_errors_expedite(autojump_clock):
|
|
sock, scenario = await run_scenario(
|
|
80,
|
|
[
|
|
("1.1.1.1", 0.1, "error"),
|
|
("2.2.2.2", 0.2, "error"),
|
|
("3.3.3.3", 10, "success"),
|
|
# .25 is the default timeout
|
|
("4.4.4.4", 0.25, "success"),
|
|
],
|
|
)
|
|
assert sock.ip == "4.4.4.4"
|
|
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
|
|
assert scenario.connect_times == {
|
|
"1.1.1.1": 0,
|
|
"2.2.2.2": 0.1,
|
|
"3.3.3.3": 0.1 + 0.2,
|
|
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
|
}
|
|
|
|
|
|
async def test_all_fail(autojump_clock):
|
|
exc, scenario = await run_scenario(
|
|
80,
|
|
[
|
|
("1.1.1.1", 0.1, "error"),
|
|
("2.2.2.2", 0.2, "error"),
|
|
("3.3.3.3", 10, "error"),
|
|
("4.4.4.4", 0.250, "error"),
|
|
],
|
|
expect_error=OSError,
|
|
)
|
|
assert isinstance(exc, OSError)
|
|
assert isinstance(exc.__cause__, BaseExceptionGroup)
|
|
assert len(exc.__cause__.exceptions) == 4
|
|
assert trio.current_time() == (0.1 + 0.2 + 10)
|
|
assert scenario.connect_times == {
|
|
"1.1.1.1": 0,
|
|
"2.2.2.2": 0.1,
|
|
"3.3.3.3": 0.1 + 0.2,
|
|
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
|
}
|
|
|
|
|
|
async def test_multi_success(autojump_clock):
|
|
sock, scenario = await run_scenario(
|
|
80,
|
|
[
|
|
("1.1.1.1", 0.5, "error"),
|
|
("2.2.2.2", 10, "success"),
|
|
("3.3.3.3", 10 - 1, "success"),
|
|
("4.4.4.4", 10 - 2, "success"),
|
|
("5.5.5.5", 0.5, "error"),
|
|
],
|
|
happy_eyeballs_delay=1,
|
|
)
|
|
assert not scenario.sockets["1.1.1.1"].succeeded
|
|
assert (
|
|
scenario.sockets["2.2.2.2"].succeeded
|
|
or scenario.sockets["3.3.3.3"].succeeded
|
|
or scenario.sockets["4.4.4.4"].succeeded
|
|
)
|
|
assert not scenario.sockets["5.5.5.5"].succeeded
|
|
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
|
|
assert trio.current_time() == (0.5 + 10)
|
|
assert scenario.connect_times == {
|
|
"1.1.1.1": 0,
|
|
"2.2.2.2": 0.5,
|
|
"3.3.3.3": 1.5,
|
|
"4.4.4.4": 2.5,
|
|
"5.5.5.5": 3.5,
|
|
}
|
|
|
|
|
|
async def test_does_reorder(autojump_clock):
|
|
sock, scenario = await run_scenario(
|
|
80,
|
|
[
|
|
("1.1.1.1", 10, "error"),
|
|
# This would win if we tried it first...
|
|
("2.2.2.2", 1, "success"),
|
|
# But in fact we try this first, because of section 5.4
|
|
("::3", 0.5, "success"),
|
|
],
|
|
happy_eyeballs_delay=1,
|
|
)
|
|
assert sock.ip == "::3"
|
|
assert trio.current_time() == 1 + 0.5
|
|
assert scenario.connect_times == {
|
|
"1.1.1.1": 0,
|
|
"::3": 1,
|
|
}
|
|
|
|
|
|
async def test_handles_no_ipv4(autojump_clock):
|
|
sock, scenario = await run_scenario(
|
|
80,
|
|
# Here the ipv6 addresses fail at socket creation time, so the connect
|
|
# configuration doesn't matter
|
|
[
|
|
("::1", 10, "success"),
|
|
("2.2.2.2", 0, "success"),
|
|
("::3", 0.1, "success"),
|
|
("4.4.4.4", 0, "success"),
|
|
],
|
|
happy_eyeballs_delay=1,
|
|
ipv4_supported=False,
|
|
)
|
|
assert sock.ip == "::3"
|
|
assert trio.current_time() == 1 + 0.1
|
|
assert scenario.connect_times == {
|
|
"::1": 0,
|
|
"::3": 1.0,
|
|
}
|
|
|
|
|
|
async def test_handles_no_ipv6(autojump_clock):
|
|
sock, scenario = await run_scenario(
|
|
80,
|
|
# Here the ipv6 addresses fail at socket creation time, so the connect
|
|
# configuration doesn't matter
|
|
[
|
|
("::1", 0, "success"),
|
|
("2.2.2.2", 10, "success"),
|
|
("::3", 0, "success"),
|
|
("4.4.4.4", 0.1, "success"),
|
|
],
|
|
happy_eyeballs_delay=1,
|
|
ipv6_supported=False,
|
|
)
|
|
assert sock.ip == "4.4.4.4"
|
|
assert trio.current_time() == 1 + 0.1
|
|
assert scenario.connect_times == {
|
|
"2.2.2.2": 0,
|
|
"4.4.4.4": 1.0,
|
|
}
|
|
|
|
|
|
async def test_no_hosts(autojump_clock):
|
|
exc, scenario = await run_scenario(80, [], expect_error=OSError)
|
|
assert "no results found" in str(exc)
|
|
|
|
|
|
async def test_cancel(autojump_clock):
|
|
with trio.move_on_after(5) as cancel_scope:
|
|
exc, scenario = await run_scenario(
|
|
80,
|
|
[
|
|
("1.1.1.1", 10, "success"),
|
|
("2.2.2.2", 10, "success"),
|
|
("3.3.3.3", 10, "success"),
|
|
("4.4.4.4", 10, "success"),
|
|
],
|
|
expect_error=BaseExceptionGroup,
|
|
)
|
|
# What comes out should be 1 or more Cancelled errors that all belong
|
|
# to this cancel_scope; this is the easiest way to check that
|
|
raise exc
|
|
assert cancel_scope.cancelled_caught
|
|
|
|
assert trio.current_time() == 5
|
|
|
|
# This should have been called already, but just to make sure, since the
|
|
# exception-handling logic in run_scenario is a bit complicated and the
|
|
# main thing we care about here is that all the sockets were cleaned up.
|
|
scenario.check(succeeded=False)
|