import pytest import sys import socket import attr import trio from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP from trio._highlevel_open_tcp_stream import ( reorder_for_rfc_6555_section_5_4, close_all, open_tcp_stream, format_host_port, ) if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup def test_close_all(): class CloseMe: closed = False def close(self): self.closed = True class CloseKiller: def close(self): raise OSError c = CloseMe() with close_all() as to_close: to_close.add(c) assert c.closed c = CloseMe() with pytest.raises(RuntimeError): with close_all() as to_close: to_close.add(c) raise RuntimeError assert c.closed c = CloseMe() with pytest.raises(OSError): with close_all() as to_close: to_close.add(CloseKiller()) to_close.add(c) assert c.closed def test_reorder_for_rfc_6555_section_5_4(): def fake4(i): return ( AF_INET, SOCK_STREAM, IPPROTO_TCP, "", ("10.0.0.{}".format(i), 80), ) def fake6(i): return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", ("::{}".format(i), 80)) for fake in fake4, fake6: # No effect on homogeneous lists targets = [fake(0), fake(1), fake(2)] reorder_for_rfc_6555_section_5_4(targets) assert targets == [fake(0), fake(1), fake(2)] # Single item lists also OK targets = [fake(0)] reorder_for_rfc_6555_section_5_4(targets) assert targets == [fake(0)] # If the list starts out with different families in positions 0 and 1, # then it's left alone orig = [fake4(0), fake6(0), fake4(1), fake6(1)] targets = list(orig) reorder_for_rfc_6555_section_5_4(targets) assert targets == orig # If not, it's reordered targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)] reorder_for_rfc_6555_section_5_4(targets) assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)] def test_format_host_port(): assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port("example.com", 443) == "example.com:443" assert format_host_port(b"example.com", 443) == "example.com:443" assert format_host_port("::1", "http") == "[::1]:http" assert format_host_port(b"::1", "http") == "[::1]:http" # Make sure we can connect to localhost using real kernel sockets async def test_open_tcp_stream_real_socket_smoketest(): listen_sock = trio.socket.socket() await listen_sock.bind(("127.0.0.1", 0)) _, listen_port = listen_sock.getsockname() listen_sock.listen(1) client_stream = await open_tcp_stream("127.0.0.1", listen_port) server_sock, _ = await listen_sock.accept() await client_stream.send_all(b"x") assert await server_sock.recv(1) == b"x" await client_stream.aclose() server_sock.close() listen_sock.close() async def test_open_tcp_stream_input_validation(): with pytest.raises(ValueError): await open_tcp_stream(None, 80) with pytest.raises(TypeError): await open_tcp_stream("127.0.0.1", b"80") def can_bind_127_0_0_2(): with socket.socket() as s: try: s.bind(("127.0.0.2", 0)) except OSError: return False return s.getsockname()[0] == "127.0.0.2" async def test_local_address_real(): with trio.socket.socket() as listener: await listener.bind(("127.0.0.1", 0)) listener.listen() # It's hard to test local_address properly, because you need multiple # local addresses that you can bind to. Fortunately, on most Linux # systems, you can bind to any 127.*.*.* address, and they all go # through the loopback interface. So we can use a non-standard # loopback address. On other systems, the only address we know for # certain we have is 127.0.0.1, so we can't really test local_address= # properly -- passing local_address=127.0.0.1 is indistinguishable # from not passing local_address= at all. But, we can still do a smoke # test to make sure the local_address= code doesn't crash. if can_bind_127_0_0_2(): local_address = "127.0.0.2" else: local_address = "127.0.0.1" async with await open_tcp_stream( *listener.getsockname(), local_address=local_address ) as client_stream: assert client_stream.socket.getsockname()[0] == local_address if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"): assert client_stream.socket.getsockopt( trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT ) server_sock, remote_addr = await listener.accept() await client_stream.aclose() server_sock.close() assert remote_addr[0] == local_address # Trying to connect to an ipv4 address with the ipv6 wildcard # local_address should fail with pytest.raises(OSError): await open_tcp_stream(*listener.getsockname(), local_address="::") # But the ipv4 wildcard address should work async with await open_tcp_stream( *listener.getsockname(), local_address="0.0.0.0" ) as client_stream: server_sock, remote_addr = await listener.accept() server_sock.close() assert remote_addr == client_stream.socket.getsockname() # Now, thorough tests using fake sockets @attr.s(eq=False) class FakeSocket(trio.socket.SocketType): scenario = attr.ib() family = attr.ib() type = attr.ib() proto = attr.ib() ip = attr.ib(default=None) port = attr.ib(default=None) succeeded = attr.ib(default=False) closed = attr.ib(default=False) failing = attr.ib(default=False) async def connect(self, sockaddr): self.ip = sockaddr[0] self.port = sockaddr[1] assert self.ip not in self.scenario.sockets self.scenario.sockets[self.ip] = self self.scenario.connect_times[self.ip] = trio.current_time() delay, result = self.scenario.ip_dict[self.ip] await trio.sleep(delay) if result == "error": raise OSError("sorry") if result == "postconnect_fail": self.failing = True self.succeeded = True def close(self): self.closed = True # called when SocketStream is constructed def setsockopt(self, *args, **kwargs): if self.failing: # raise something that isn't OSError as SocketStream # ignores those raise KeyboardInterrupt class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver): def __init__(self, port, ip_list, supported_families): # ip_list have to be unique ip_order = [ip for (ip, _, _) in ip_list] assert len(set(ip_order)) == len(ip_list) ip_dict = {} for ip, delay, result in ip_list: assert 0 <= delay assert result in ["error", "success", "postconnect_fail"] ip_dict[ip] = (delay, result) self.port = port self.ip_order = ip_order self.ip_dict = ip_dict self.supported_families = supported_families self.socket_count = 0 self.sockets = {} self.connect_times = {} def socket(self, family, type, proto): if family not in self.supported_families: raise OSError("pretending not to support this family") self.socket_count += 1 return FakeSocket(self, family, type, proto) def _ip_to_gai_entry(self, ip): if ":" in ip: family = trio.socket.AF_INET6 sockaddr = (ip, self.port, 0, 0) else: family = trio.socket.AF_INET sockaddr = (ip, self.port) return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr) async def getaddrinfo(self, host, port, family, type, proto, flags): assert host == b"test.example.com" assert port == self.port assert family == trio.socket.AF_UNSPEC assert type == trio.socket.SOCK_STREAM assert proto == 0 assert flags == 0 return [self._ip_to_gai_entry(ip) for ip in self.ip_order] async def getnameinfo(self, sockaddr, flags): # pragma: no cover raise NotImplementedError def check(self, succeeded): # sockets only go into self.sockets when connect is called; make sure # all the sockets that were created did in fact go in there. assert self.socket_count == len(self.sockets) for ip, socket in self.sockets.items(): assert ip in self.ip_dict if socket is not succeeded: assert socket.closed assert socket.port == self.port async def run_scenario( # The port to connect to port, # A list of # (ip, delay, result) # tuples, where delay is in seconds and result is "success" or "error" # The ip's will be returned from getaddrinfo in this order, and then # connect() calls to them will have the given result. ip_list, *, # If False, AF_INET4/6 sockets error out on creation, before connect is # even called. ipv4_supported=True, ipv6_supported=True, # Normally, we return (winning_sock, scenario object) # If this is True, we require there to be an exception, and return # (exception, scenario object) expect_error=(), **kwargs, ): supported_families = set() if ipv4_supported: supported_families.add(trio.socket.AF_INET) if ipv6_supported: supported_families.add(trio.socket.AF_INET6) scenario = Scenario(port, ip_list, supported_families) trio.socket.set_custom_hostname_resolver(scenario) trio.socket.set_custom_socket_factory(scenario) try: stream = await open_tcp_stream("test.example.com", port, **kwargs) assert expect_error == () scenario.check(stream.socket) return (stream.socket, scenario) except AssertionError: # pragma: no cover raise except expect_error as exc: scenario.check(None) return (exc, scenario) async def test_one_host_quick_success(autojump_clock): sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) assert sock.ip == "1.2.3.4" assert trio.current_time() == 0.123 async def test_one_host_slow_success(autojump_clock): sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")]) assert sock.ip == "1.2.3.4" assert trio.current_time() == 100 async def test_one_host_quick_fail(autojump_clock): exc, scenario = await run_scenario( 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError ) assert isinstance(exc, OSError) assert trio.current_time() == 0.123 async def test_one_host_slow_fail(autojump_clock): exc, scenario = await run_scenario( 83, [("1.2.3.4", 100, "error")], expect_error=OSError ) assert isinstance(exc, OSError) assert trio.current_time() == 100 async def test_one_host_failed_after_connect(autojump_clock): exc, scenario = await run_scenario( 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt ) assert isinstance(exc, KeyboardInterrupt) # With the default 0.250 second delay, the third attempt will win async def test_basic_fallthrough(autojump_clock): sock, scenario = await run_scenario( 80, [ ("1.1.1.1", 1, "success"), ("2.2.2.2", 1, "success"), ("3.3.3.3", 0.2, "success"), ], ) assert sock.ip == "3.3.3.3" # current time is default time + default time + connection time assert trio.current_time() == (0.250 + 0.250 + 0.2) assert scenario.connect_times == { "1.1.1.1": 0, "2.2.2.2": 0.250, "3.3.3.3": 0.500, } async def test_early_success(autojump_clock): sock, scenario = await run_scenario( 80, [ ("1.1.1.1", 1, "success"), ("2.2.2.2", 0.1, "success"), ("3.3.3.3", 0.2, "success"), ], ) assert sock.ip == "2.2.2.2" assert trio.current_time() == (0.250 + 0.1) assert scenario.connect_times == { "1.1.1.1": 0, "2.2.2.2": 0.250, # 3.3.3.3 was never even started } # With a 0.450 second delay, the first attempt will win async def test_custom_delay(autojump_clock): sock, scenario = await run_scenario( 80, [ ("1.1.1.1", 1, "success"), ("2.2.2.2", 1, "success"), ("3.3.3.3", 0.2, "success"), ], happy_eyeballs_delay=0.450, ) assert sock.ip == "1.1.1.1" assert trio.current_time() == 1 assert scenario.connect_times == { "1.1.1.1": 0, "2.2.2.2": 0.450, "3.3.3.3": 0.900, } async def test_custom_errors_expedite(autojump_clock): sock, scenario = await run_scenario( 80, [ ("1.1.1.1", 0.1, "error"), ("2.2.2.2", 0.2, "error"), ("3.3.3.3", 10, "success"), # .25 is the default timeout ("4.4.4.4", 0.25, "success"), ], ) assert sock.ip == "4.4.4.4" assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25) assert scenario.connect_times == { "1.1.1.1": 0, "2.2.2.2": 0.1, "3.3.3.3": 0.1 + 0.2, "4.4.4.4": 0.1 + 0.2 + 0.25, } async def test_all_fail(autojump_clock): exc, scenario = await run_scenario( 80, [ ("1.1.1.1", 0.1, "error"), ("2.2.2.2", 0.2, "error"), ("3.3.3.3", 10, "error"), ("4.4.4.4", 0.250, "error"), ], expect_error=OSError, ) assert isinstance(exc, OSError) assert isinstance(exc.__cause__, BaseExceptionGroup) assert len(exc.__cause__.exceptions) == 4 assert trio.current_time() == (0.1 + 0.2 + 10) assert scenario.connect_times == { "1.1.1.1": 0, "2.2.2.2": 0.1, "3.3.3.3": 0.1 + 0.2, "4.4.4.4": 0.1 + 0.2 + 0.25, } async def test_multi_success(autojump_clock): sock, scenario = await run_scenario( 80, [ ("1.1.1.1", 0.5, "error"), ("2.2.2.2", 10, "success"), ("3.3.3.3", 10 - 1, "success"), ("4.4.4.4", 10 - 2, "success"), ("5.5.5.5", 0.5, "error"), ], happy_eyeballs_delay=1, ) assert not scenario.sockets["1.1.1.1"].succeeded assert ( scenario.sockets["2.2.2.2"].succeeded or scenario.sockets["3.3.3.3"].succeeded or scenario.sockets["4.4.4.4"].succeeded ) assert not scenario.sockets["5.5.5.5"].succeeded assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"] assert trio.current_time() == (0.5 + 10) assert scenario.connect_times == { "1.1.1.1": 0, "2.2.2.2": 0.5, "3.3.3.3": 1.5, "4.4.4.4": 2.5, "5.5.5.5": 3.5, } async def test_does_reorder(autojump_clock): sock, scenario = await run_scenario( 80, [ ("1.1.1.1", 10, "error"), # This would win if we tried it first... ("2.2.2.2", 1, "success"), # But in fact we try this first, because of section 5.4 ("::3", 0.5, "success"), ], happy_eyeballs_delay=1, ) assert sock.ip == "::3" assert trio.current_time() == 1 + 0.5 assert scenario.connect_times == { "1.1.1.1": 0, "::3": 1, } async def test_handles_no_ipv4(autojump_clock): sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect # configuration doesn't matter [ ("::1", 10, "success"), ("2.2.2.2", 0, "success"), ("::3", 0.1, "success"), ("4.4.4.4", 0, "success"), ], happy_eyeballs_delay=1, ipv4_supported=False, ) assert sock.ip == "::3" assert trio.current_time() == 1 + 0.1 assert scenario.connect_times == { "::1": 0, "::3": 1.0, } async def test_handles_no_ipv6(autojump_clock): sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect # configuration doesn't matter [ ("::1", 0, "success"), ("2.2.2.2", 10, "success"), ("::3", 0, "success"), ("4.4.4.4", 0.1, "success"), ], happy_eyeballs_delay=1, ipv6_supported=False, ) assert sock.ip == "4.4.4.4" assert trio.current_time() == 1 + 0.1 assert scenario.connect_times == { "2.2.2.2": 0, "4.4.4.4": 1.0, } async def test_no_hosts(autojump_clock): exc, scenario = await run_scenario(80, [], expect_error=OSError) assert "no results found" in str(exc) async def test_cancel(autojump_clock): with trio.move_on_after(5) as cancel_scope: exc, scenario = await run_scenario( 80, [ ("1.1.1.1", 10, "success"), ("2.2.2.2", 10, "success"), ("3.3.3.3", 10, "success"), ("4.4.4.4", 10, "success"), ], expect_error=BaseExceptionGroup, ) # What comes out should be 1 or more Cancelled errors that all belong # to this cancel_scope; this is the easiest way to check that raise exc assert cancel_scope.cancelled_caught assert trio.current_time() == 5 # This should have been called already, but just to make sure, since the # exception-handling logic in run_scenario is a bit complicated and the # main thing we care about here is that all the sockets were cleaned up. scenario.check(succeeded=False)