You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

114 lines
3.9 KiB
Python

import pytest
from functools import partial
import attr
import trio
from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP
import trio.testing
from .test_ssl import client_ctx, SERVER_CTX
from .._highlevel_ssl_helpers import (
open_ssl_over_tcp_stream,
open_ssl_over_tcp_listeners,
serve_ssl_over_tcp,
)
async def echo_handler(stream):
    """Send every chunk received on *stream* straight back until EOF.

    The stream is closed on exit (via ``async with``). A peer that breaks the
    connection mid-echo (``trio.BrokenResourceError``) is tolerated silently.
    """
    async with stream:
        try:
            chunk = await stream.receive_some(10000)
            # An empty chunk signals end-of-stream.
            while chunk:
                await stream.send_all(chunk)
                chunk = await stream.receive_some(10000)
        except trio.BrokenResourceError:
            pass
# Resolver that always returns the given sockaddr, no matter what host/port
# you ask for.
@attr.s
class FakeHostnameResolver(trio.abc.HostnameResolver):
    """Hostname resolver stub that ignores the query it is given.

    Every ``getaddrinfo`` call answers with exactly one AF_INET / SOCK_STREAM
    / IPPROTO_TCP entry whose address is ``sockaddr``. Reverse lookups are not
    supported.
    """

    # Address returned for every forward lookup.
    sockaddr = attr.ib()

    async def getaddrinfo(self, *_query):
        # Layout: (family, type, proto, canonname, sockaddr).
        only_entry = (AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)
        return [only_entry]

    async def getnameinfo(self, *_query):  # pragma: no cover
        raise NotImplementedError
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# noqa is needed because flake8 doesn't understand how pytest fixtures work.
async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx):  # noqa: F811
    """End-to-end check of open_ssl_over_tcp_stream against a live echo server.

    The echo server is started with serve_ssl_over_tcp on an ephemeral
    loopback port. A FakeHostnameResolver then routes every lookup to that
    port, so the host/port arguments below do not change where we connect.
    The hostname still determines what TLS verification checks against.
    """

    async def expect_handshake_failure(hostname, **open_kwargs):
        # Connect, then require the TLS handshake to be rejected.
        failing = await open_ssl_over_tcp_stream(hostname, 80, **open_kwargs)
        async with failing:
            with pytest.raises(trio.BrokenResourceError):
                await failing.do_handshake()

    async with trio.open_nursery() as nursery:
        start_server = partial(
            serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1"
        )
        (listener,) = await nursery.start(start_server)
        async with listener:
            bound_addr = listener.transport_listener.socket.getsockname()
            trio.socket.set_custom_hostname_resolver(
                FakeHostnameResolver(bound_addr)
            )

            # Default context without the right trust set up: the handshake
            # must fail (checks that ssl_context=None is doing some validation).
            await expect_handshake_failure("trio-test-1.example.org")

            # Right trust, wrong hostname: must also fail
            # (checks custom ssl_context + hostname checking).
            await expect_handshake_failure(
                "xyzzy.example.org", ssl_context=client_ctx
            )

            # Right trust and right hostname: the echo round trip succeeds.
            good = await open_ssl_over_tcp_stream(
                "trio-test-1.example.org", 80, ssl_context=client_ctx
            )
            async with good:
                assert isinstance(good, trio.SSLStream)
                assert good.server_hostname == "trio-test-1.example.org"
                await good.send_all(b"x")
                assert await good.receive_some(1) == b"x"

            # https_compatible is off unless requested, and is passed through.
            assert not good._https_compatible
            compat = await open_ssl_over_tcp_stream(
                "trio-test-1.example.org",
                80,
                ssl_context=client_ctx,
                https_compatible=True,
                # also, smoke test happy_eyeballs_delay
                happy_eyeballs_delay=1,
            )
            async with compat:
                assert compat._https_compatible

            # Stop the echo server so the nursery can finish.
            nursery.cancel_scope.cancel()
async def test_open_ssl_over_tcp_listeners():
    """open_ssl_over_tcp_listeners returns one loopback SSLListener that wraps
    a SocketListener and forwards the https_compatible flag (off by default)."""
    (plain,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
    async with plain:
        assert isinstance(plain, trio.SSLListener)
        transport = plain.transport_listener
        assert isinstance(transport, trio.SocketListener)
        bound_host = transport.socket.getsockname()[0]
        assert bound_host == "127.0.0.1"
        assert not plain._https_compatible

    (compat,) = await open_ssl_over_tcp_listeners(
        0, SERVER_CTX, host="127.0.0.1", https_compatible=True
    )
    async with compat:
        assert compat._https_compatible