You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

575 lines
18 KiB
Python

import pytest
import sys
import socket
import attr
import trio
from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP
from trio._highlevel_open_tcp_stream import (
reorder_for_rfc_6555_section_5_4,
close_all,
open_tcp_stream,
format_host_port,
)
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
def test_close_all():
    """Check that close_all() closes registered objects in every exit path."""

    class Closeable:
        # Records whether close() has been invoked.
        closed = False

        def close(self):
            self.closed = True

    class ExplodingCloseable:
        # close() always fails, so we can check the other objects still get
        # closed.
        def close(self):
            raise OSError

    # Clean exit from the block: the registered object gets closed.
    target = Closeable()
    with close_all() as pending:
        pending.add(target)
    assert target.closed

    # An exception raised inside the block propagates, and the object is
    # closed anyway.
    target = Closeable()
    with pytest.raises(RuntimeError):
        with close_all() as pending:
            pending.add(target)
            raise RuntimeError
    assert target.closed

    # A close() that raises surfaces as OSError, but the well-behaved object
    # is closed regardless.
    target = Closeable()
    with pytest.raises(OSError):
        with close_all() as pending:
            pending.add(ExplodingCloseable())
            pending.add(target)
    assert target.closed
def test_reorder_for_rfc_6555_section_5_4():
    """Check the in-place reordering of getaddrinfo-style target lists."""

    def make_entry(family, address):
        # Same 5-tuple layout that getaddrinfo returns.
        return (family, SOCK_STREAM, IPPROTO_TCP, "", (address, 80))

    def fake4(i):
        return make_entry(AF_INET, "10.0.0.{}".format(i))

    def fake6(i):
        return make_entry(AF_INET6, "::{}".format(i))

    for fake in (fake4, fake6):
        # Lists containing a single family come back unchanged
        same_family = [fake(0), fake(1), fake(2)]
        reorder_for_rfc_6555_section_5_4(same_family)
        assert same_family == [fake(0), fake(1), fake(2)]

        # ...and so do lists with just one entry
        single = [fake(0)]
        reorder_for_rfc_6555_section_5_4(single)
        assert single == [fake(0)]

    # When positions 0 and 1 already hold different families, nothing moves
    alternating = [fake4(0), fake6(0), fake4(1), fake6(1)]
    candidates = list(alternating)
    reorder_for_rfc_6555_section_5_4(candidates)
    assert candidates == alternating

    # Otherwise the first entry of the other family is pulled up to
    # position 1
    candidates = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
    reorder_for_rfc_6555_section_5_4(candidates)
    expected = [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
    assert candidates == expected
def test_format_host_port():
    """format_host_port takes str or bytes hosts; the "::1" host is bracketed."""
    cases = [
        (("127.0.0.1", 80), "127.0.0.1:80"),
        ((b"127.0.0.1", 80), "127.0.0.1:80"),
        (("example.com", 443), "example.com:443"),
        ((b"example.com", 443), "example.com:443"),
        (("::1", "http"), "[::1]:http"),
        ((b"::1", "http"), "[::1]:http"),
    ]
    for (host, port), expected in cases:
        assert format_host_port(host, port) == expected
# Make sure we can connect to localhost using real kernel sockets
async def test_open_tcp_stream_real_socket_smoketest():
    """Connect to a real loopback listener and push one byte through.

    Every socket is managed by a context manager so that a failing assertion
    or await cannot leak the listening socket, the accepted socket, or the
    client stream.
    """
    with trio.socket.socket() as listen_sock:
        await listen_sock.bind(("127.0.0.1", 0))
        # Port 0 asks the kernel for a free port; read back what we got.
        _, listen_port = listen_sock.getsockname()
        listen_sock.listen(1)

        async with await open_tcp_stream(
            "127.0.0.1", listen_port
        ) as client_stream:
            server_sock, _ = await listen_sock.accept()
            with server_sock:
                await client_stream.send_all(b"x")
                assert await server_sock.recv(1) == b"x"
async def test_open_tcp_stream_input_validation():
    """Bad host/port arguments are rejected with the expected exception type."""
    bad_calls = [
        # host=None
        (ValueError, (None, 80)),
        # port given as bytes
        (TypeError, ("127.0.0.1", b"80")),
    ]
    for expected_exc, (host, port) in bad_calls:
        with pytest.raises(expected_exc):
            await open_tcp_stream(host, port)
def can_bind_127_0_0_2():
    """Report whether this machine lets us bind a socket to 127.0.0.2.

    Returns False if the bind raises OSError; otherwise returns whether the
    bound socket actually reports 127.0.0.2 as its local address.
    """
    probe = socket.socket()
    with probe:
        try:
            probe.bind(("127.0.0.2", 0))
        except OSError:
            return False
        bound_host = probe.getsockname()[0]
        return bound_host == "127.0.0.2"
async def test_local_address_real():
    """Exercise the local_address= argument against a real loopback listener."""
    with trio.socket.socket() as listener:
        await listener.bind(("127.0.0.1", 0))
        listener.listen()

        # Testing local_address properly needs several bindable local
        # addresses. On most Linux systems any 127.*.*.* address can be bound
        # and is routed through loopback, so a non-standard loopback address
        # works there. Elsewhere the only address we can count on is
        # 127.0.0.1, where passing local_address= is indistinguishable from
        # omitting it -- but it still serves as a smoke test that the
        # local_address= code path doesn't crash.
        local_address = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1"

        async with await open_tcp_stream(
            *listener.getsockname(), local_address=local_address
        ) as stream:
            client_sock = stream.socket
            assert client_sock.getsockname()[0] == local_address
            if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
                assert client_sock.getsockopt(
                    trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT
                )

            accepted, peer_addr = await listener.accept()
            await stream.aclose()
            accepted.close()
            assert peer_addr[0] == local_address

        # Connecting to an ipv4 address while binding to the ipv6 wildcard
        # must fail
        with pytest.raises(OSError):
            await open_tcp_stream(*listener.getsockname(), local_address="::")

        # ...whereas the ipv4 wildcard is fine
        async with await open_tcp_stream(
            *listener.getsockname(), local_address="0.0.0.0"
        ) as stream:
            accepted, peer_addr = await listener.accept()
            accepted.close()
            assert peer_addr == stream.socket.getsockname()
# Now, thorough tests using fake sockets
@attr.s(eq=False)
class FakeSocket(trio.socket.SocketType):
    """Scripted stand-in for a socket, driven by the owning Scenario.

    connect() looks up its (delay, result) pair in ``scenario.ip_dict``,
    records itself in the scenario's bookkeeping, sleeps for the delay and
    then succeeds or fails according to the result.
    """

    # Supplied by Scenario.socket()
    scenario = attr.ib()
    family = attr.ib()
    type = attr.ib()
    proto = attr.ib()

    # Filled in as the socket is used
    ip = attr.ib(default=None)
    port = attr.ib(default=None)
    succeeded = attr.ib(default=False)
    closed = attr.ib(default=False)
    failing = attr.ib(default=False)

    async def connect(self, sockaddr):
        host = sockaddr[0]
        self.ip = host
        self.port = sockaddr[1]

        # Each address may only be attempted once per scenario.
        registry = self.scenario.sockets
        assert host not in registry
        registry[host] = self
        self.scenario.connect_times[host] = trio.current_time()

        delay, outcome = self.scenario.ip_dict[host]
        await trio.sleep(delay)
        if outcome == "error":
            raise OSError("sorry")
        if outcome == "postconnect_fail":
            # The connect itself works; setsockopt() will blow up later.
            self.failing = True
        self.succeeded = True

    def close(self):
        self.closed = True

    # called when SocketStream is constructed
    def setsockopt(self, *args, **kwargs):
        if not self.failing:
            return
        # SocketStream ignores OSError here, so raise something that
        # isn't one
        raise KeyboardInterrupt
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
    """Fake socket factory and hostname resolver for one scripted test run.

    ``ip_list`` is a list of ``(ip, delay, result)`` tuples. getaddrinfo()
    returns the addresses in that order, and each FakeSocket looks up its
    delay and result here when connect() is called.
    """

    def __init__(self, port, ip_list, supported_families):
        # Every ip in ip_list must be unique
        addresses = [ip for (ip, _, _) in ip_list]
        assert len(set(addresses)) == len(ip_list)

        outcomes = {}
        for ip, delay, result in ip_list:
            assert 0 <= delay
            assert result in ("error", "success", "postconnect_fail")
            outcomes[ip] = (delay, result)

        self.port = port
        self.ip_order = addresses
        self.ip_dict = outcomes
        self.supported_families = supported_families
        # Bookkeeping, updated by socket() and FakeSocket.connect()
        self.socket_count = 0
        self.sockets = {}
        self.connect_times = {}

    def socket(self, family, type, proto):
        if family not in self.supported_families:
            raise OSError("pretending not to support this family")
        self.socket_count += 1
        return FakeSocket(self, family, type, proto)

    def _ip_to_gai_entry(self, ip):
        # A ":" in the address marks it as IPv6, which gets a 4-tuple
        # sockaddr
        if ":" in ip:
            return (
                trio.socket.AF_INET6,
                SOCK_STREAM,
                IPPROTO_TCP,
                "",
                (ip, self.port, 0, 0),
            )
        return (
            trio.socket.AF_INET,
            SOCK_STREAM,
            IPPROTO_TCP,
            "",
            (ip, self.port),
        )

    async def getaddrinfo(self, host, port, family, type, proto, flags):
        # Pin down exactly how open_tcp_stream calls the resolver
        assert host == b"test.example.com"
        assert port == self.port
        assert family == trio.socket.AF_UNSPEC
        assert type == trio.socket.SOCK_STREAM
        assert proto == 0
        assert flags == 0
        return [self._ip_to_gai_entry(ip) for ip in self.ip_order]

    async def getnameinfo(self, sockaddr, flags):  # pragma: no cover
        raise NotImplementedError

    def check(self, succeeded):
        # Sockets are only registered in self.sockets once connect() is
        # called, so this verifies that every socket that was created was
        # also connected.
        assert self.socket_count == len(self.sockets)

        for ip, sock in self.sockets.items():
            assert ip in self.ip_dict
            # Everything except the winning socket must have been closed
            if sock is not succeeded:
                assert sock.closed
            assert sock.port == self.port
async def run_scenario(
    port,
    ip_list,
    *,
    ipv4_supported=True,
    ipv6_supported=True,
    expect_error=(),
    **kwargs,
):
    """Run open_tcp_stream against a scripted Scenario.

    Args:
      port: The port to connect to.
      ip_list: A list of ``(ip, delay, result)`` tuples, where delay is in
        seconds and result is "success", "error" or "postconnect_fail". The
        ips are returned from getaddrinfo in this order, and connect() calls
        to them then have the given result.
      ipv4_supported: If False, AF_INET sockets error out on creation, before
        connect is even called.
      ipv6_supported: Same, for AF_INET6 sockets.
      expect_error: An exception type (or tuple of types) that
        open_tcp_stream is required to raise. The default, an empty tuple,
        means the connection is required to succeed.
      **kwargs: Forwarded to open_tcp_stream.

    Returns:
      ``(winning_socket, scenario)`` on success, or ``(exception, scenario)``
      when an exception matching ``expect_error`` was raised.
    """
    family_switches = (
        (trio.socket.AF_INET, ipv4_supported),
        (trio.socket.AF_INET6, ipv6_supported),
    )
    supported_families = {
        family for family, enabled in family_switches if enabled
    }

    scenario = Scenario(port, ip_list, supported_families)
    trio.socket.set_custom_hostname_resolver(scenario)
    trio.socket.set_custom_socket_factory(scenario)

    try:
        stream = await open_tcp_stream("test.example.com", port, **kwargs)
        # Reaching this point means no error occurred, so none may have been
        # expected.
        assert expect_error == ()
        winner = stream.socket
        scenario.check(winner)
        return (winner, scenario)
    except AssertionError:  # pragma: no cover
        # Never let a failed check be mistaken for the expected error
        raise
    except expect_error as exc:
        scenario.check(None)
        return (exc, scenario)
async def test_one_host_quick_success(autojump_clock):
    """A single address that connects after 0.123 seconds wins."""
    attempts = [("1.2.3.4", 0.123, "success")]
    winner, _ = await run_scenario(80, attempts)
    assert winner.ip == "1.2.3.4"
    assert trio.current_time() == 0.123
async def test_one_host_slow_success(autojump_clock):
    """A single address that takes 100 seconds to connect still wins."""
    attempts = [("1.2.3.4", 100, "success")]
    winner, _ = await run_scenario(81, attempts)
    assert winner.ip == "1.2.3.4"
    assert trio.current_time() == 100
async def test_one_host_quick_fail(autojump_clock):
    """A single address failing after 0.123 seconds produces an OSError."""
    attempts = [("1.2.3.4", 0.123, "error")]
    error, _ = await run_scenario(82, attempts, expect_error=OSError)
    assert isinstance(error, OSError)
    assert trio.current_time() == 0.123
async def test_one_host_slow_fail(autojump_clock):
    """A single address failing after 100 seconds produces an OSError."""
    attempts = [("1.2.3.4", 100, "error")]
    error, _ = await run_scenario(83, attempts, expect_error=OSError)
    assert isinstance(error, OSError)
    assert trio.current_time() == 100
async def test_one_host_failed_after_connect(autojump_clock):
    """An error raised after connect() has succeeded propagates to the caller."""
    attempts = [("1.2.3.4", 1, "postconnect_fail")]
    error, _ = await run_scenario(
        83, attempts, expect_error=KeyboardInterrupt
    )
    assert isinstance(error, KeyboardInterrupt)
# With the default 0.250 second delay, the third attempt will win
async def test_basic_fallthrough(autojump_clock):
    """Attempts start 0.250 seconds apart; the fastest one overall wins."""
    attempts = [
        ("1.1.1.1", 1, "success"),
        ("2.2.2.2", 1, "success"),
        ("3.3.3.3", 0.2, "success"),
    ]
    winner, scenario = await run_scenario(80, attempts)

    assert winner.ip == "3.3.3.3"
    # elapsed time is default delay + default delay + connection time
    assert trio.current_time() == (0.250 + 0.250 + 0.2)
    expected_starts = {
        "1.1.1.1": 0,
        "2.2.2.2": 0.250,
        "3.3.3.3": 0.500,
    }
    assert scenario.connect_times == expected_starts
async def test_early_success(autojump_clock):
    """Once an attempt succeeds, later addresses are never tried."""
    attempts = [
        ("1.1.1.1", 1, "success"),
        ("2.2.2.2", 0.1, "success"),
        ("3.3.3.3", 0.2, "success"),
    ]
    winner, scenario = await run_scenario(80, attempts)

    assert winner.ip == "2.2.2.2"
    assert trio.current_time() == (0.250 + 0.1)
    # No entry for 3.3.3.3: it was never even started
    expected_starts = {
        "1.1.1.1": 0,
        "2.2.2.2": 0.250,
    }
    assert scenario.connect_times == expected_starts
# With a 0.450 second delay, the first attempt will win
async def test_custom_delay(autojump_clock):
    """happy_eyeballs_delay= controls the spacing between attempts."""
    attempts = [
        ("1.1.1.1", 1, "success"),
        ("2.2.2.2", 1, "success"),
        ("3.3.3.3", 0.2, "success"),
    ]
    winner, scenario = await run_scenario(
        80, attempts, happy_eyeballs_delay=0.450
    )

    assert winner.ip == "1.1.1.1"
    assert trio.current_time() == 1
    expected_starts = {
        "1.1.1.1": 0,
        "2.2.2.2": 0.450,
        "3.3.3.3": 0.900,
    }
    assert scenario.connect_times == expected_starts
async def test_custom_errors_expedite(autojump_clock):
    """A failed attempt starts the next one right away, without the delay."""
    attempts = [
        ("1.1.1.1", 0.1, "error"),
        ("2.2.2.2", 0.2, "error"),
        ("3.3.3.3", 10, "success"),
        # .25 is the default timeout
        ("4.4.4.4", 0.25, "success"),
    ]
    winner, scenario = await run_scenario(80, attempts)

    assert winner.ip == "4.4.4.4"
    assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
    expected_starts = {
        "1.1.1.1": 0,
        "2.2.2.2": 0.1,
        "3.3.3.3": 0.1 + 0.2,
        "4.4.4.4": 0.1 + 0.2 + 0.25,
    }
    assert scenario.connect_times == expected_starts
async def test_all_fail(autojump_clock):
    """When every attempt fails, the OSError's cause groups all the errors."""
    attempts = [
        ("1.1.1.1", 0.1, "error"),
        ("2.2.2.2", 0.2, "error"),
        ("3.3.3.3", 10, "error"),
        ("4.4.4.4", 0.250, "error"),
    ]
    error, scenario = await run_scenario(80, attempts, expect_error=OSError)

    assert isinstance(error, OSError)
    cause = error.__cause__
    assert isinstance(cause, BaseExceptionGroup)
    assert len(cause.exceptions) == 4
    assert trio.current_time() == (0.1 + 0.2 + 10)
    expected_starts = {
        "1.1.1.1": 0,
        "2.2.2.2": 0.1,
        "3.3.3.3": 0.1 + 0.2,
        "4.4.4.4": 0.1 + 0.2 + 0.25,
    }
    assert scenario.connect_times == expected_starts
async def test_multi_success(autojump_clock):
    """Several attempts finishing at the same instant yield one winner."""
    attempts = [
        ("1.1.1.1", 0.5, "error"),
        ("2.2.2.2", 10, "success"),
        ("3.3.3.3", 10 - 1, "success"),
        ("4.4.4.4", 10 - 2, "success"),
        ("5.5.5.5", 0.5, "error"),
    ]
    winner, scenario = await run_scenario(
        80, attempts, happy_eyeballs_delay=1
    )

    contenders = ("2.2.2.2", "3.3.3.3", "4.4.4.4")
    sockets = scenario.sockets
    assert not sockets["1.1.1.1"].succeeded
    assert any(sockets[ip].succeeded for ip in contenders)
    assert not sockets["5.5.5.5"].succeeded
    assert winner.ip in contenders
    assert trio.current_time() == (0.5 + 10)
    expected_starts = {
        "1.1.1.1": 0,
        "2.2.2.2": 0.5,
        "3.3.3.3": 1.5,
        "4.4.4.4": 2.5,
        "5.5.5.5": 3.5,
    }
    assert scenario.connect_times == expected_starts
async def test_does_reorder(autojump_clock):
    """The RFC 6555 section 5.4 reordering is applied to resolver results."""
    attempts = [
        ("1.1.1.1", 10, "error"),
        # Tried second, this one would win...
        ("2.2.2.2", 1, "success"),
        # ...but section 5.4 moves this one ahead of it
        ("::3", 0.5, "success"),
    ]
    winner, scenario = await run_scenario(
        80, attempts, happy_eyeballs_delay=1
    )

    assert winner.ip == "::3"
    assert trio.current_time() == 1 + 0.5
    expected_starts = {
        "1.1.1.1": 0,
        "::3": 1,
    }
    assert scenario.connect_times == expected_starts
async def test_handles_no_ipv4(autojump_clock):
    """With AF_INET sockets unavailable, only the ipv6 addresses are tried."""
    # The ipv4 addresses fail at socket creation time, so their connect
    # configuration doesn't matter
    attempts = [
        ("::1", 10, "success"),
        ("2.2.2.2", 0, "success"),
        ("::3", 0.1, "success"),
        ("4.4.4.4", 0, "success"),
    ]
    winner, scenario = await run_scenario(
        80,
        attempts,
        happy_eyeballs_delay=1,
        ipv4_supported=False,
    )

    assert winner.ip == "::3"
    assert trio.current_time() == 1 + 0.1
    expected_starts = {
        "::1": 0,
        "::3": 1.0,
    }
    assert scenario.connect_times == expected_starts
async def test_handles_no_ipv6(autojump_clock):
    """With AF_INET6 sockets unavailable, only the ipv4 addresses are tried."""
    # The ipv6 addresses fail at socket creation time, so their connect
    # configuration doesn't matter
    attempts = [
        ("::1", 0, "success"),
        ("2.2.2.2", 10, "success"),
        ("::3", 0, "success"),
        ("4.4.4.4", 0.1, "success"),
    ]
    winner, scenario = await run_scenario(
        80,
        attempts,
        happy_eyeballs_delay=1,
        ipv6_supported=False,
    )

    assert winner.ip == "4.4.4.4"
    assert trio.current_time() == 1 + 0.1
    expected_starts = {
        "2.2.2.2": 0,
        "4.4.4.4": 1.0,
    }
    assert scenario.connect_times == expected_starts
async def test_no_hosts(autojump_clock):
    """An empty resolver result raises an OSError that says so."""
    error, _ = await run_scenario(80, [], expect_error=OSError)
    message = str(error)
    assert "no results found" in message
async def test_cancel(autojump_clock):
    """Cancelling mid-connect stops at the deadline and closes every socket."""
    attempts = [
        ("1.1.1.1", 10, "success"),
        ("2.2.2.2", 10, "success"),
        ("3.3.3.3", 10, "success"),
        ("4.4.4.4", 10, "success"),
    ]
    with trio.move_on_after(5) as cancel_scope:
        exc, scenario = await run_scenario(
            80, attempts, expect_error=BaseExceptionGroup
        )
        # What we got back should be one or more Cancelled errors, all
        # belonging to this cancel_scope; re-raising is the easiest way to
        # check that
        raise exc
    assert cancel_scope.cancelled_caught

    assert trio.current_time() == 5

    # run_scenario should already have done this, but its exception handling
    # is a bit complicated, and the main thing we care about here is that all
    # the sockets were cleaned up -- so check once more.
    scenario.check(succeeded=False)