You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

95 lines
2.6 KiB
Python

import pytest
import attr
from ..abc import SendStream, ReceiveStream
from .._highlevel_generic import StapledStream
@attr.s
class RecordSendStream(SendStream):
    """Fake send stream that logs each call made on it.

    Every method only appends an entry to ``record``; nothing is actually
    transmitted. Tests inspect ``record`` to see which calls were made and
    in what order.
    """

    # Chronological log of the calls made on this stream. Built with a
    # factory so every instance starts with its own empty list.
    record = attr.ib(factory=list)

    def _note(self, entry):
        # Single place where a call is written to the log.
        self.record.append(entry)

    async def send_all(self, data):
        """Log a ``("send_all", data)`` tuple."""
        self._note(("send_all", data))

    async def wait_send_all_might_not_block(self):
        """Log the string ``"wait_send_all_might_not_block"``."""
        self._note("wait_send_all_might_not_block")

    async def aclose(self):
        """Log the string ``"aclose"``."""
        self._note("aclose")
@attr.s
class RecordReceiveStream(ReceiveStream):
    """Fake receive stream that logs each call made on it.

    Every method only appends an entry to ``record``; no data is ever
    produced (``receive_some`` implicitly returns ``None``).
    """

    # Chronological log of the calls made on this stream. Built with a
    # factory so every instance starts with its own empty list.
    record = attr.ib(factory=list)

    def _note(self, entry):
        # Single place where a call is written to the log.
        self.record.append(entry)

    async def receive_some(self, max_bytes=None):
        """Log a ``("receive_some", max_bytes)`` tuple."""
        self._note(("receive_some", max_bytes))

    async def aclose(self):
        """Log the string ``"aclose"``."""
        self._note("aclose")
async def test_StapledStream():
    """Check that StapledStream forwards each call to the right half.

    The two halves are recording fakes, so every expectation is expressed as
    the exact list of calls each half has logged so far.
    """
    sender = RecordSendStream()
    receiver = RecordReceiveStream()
    pair = StapledStream(sender, receiver)

    # The halves are exposed unchanged as attributes.
    assert pair.send_stream is sender
    assert pair.receive_stream is receiver

    # Send-side calls go to the send half, in order.
    await pair.send_all(b"foo")
    await pair.wait_send_all_might_not_block()
    expected_send_calls = [
        ("send_all", b"foo"),
        "wait_send_all_might_not_block",
    ]
    assert sender.record == expected_send_calls
    sender.record.clear()

    # The fake has no send_eof of its own here; the test expects the send
    # half to be closed instead.
    await pair.send_eof()
    assert sender.record == ["aclose"]
    sender.record.clear()

    # Once the send half gains a send_eof, that is what must be called.
    async def fake_send_eof():
        sender.record.append("send_eof")

    sender.send_eof = fake_send_eof
    await pair.send_eof()
    assert sender.record == ["send_eof"]
    sender.record.clear()

    # Nothing above may have touched the receive half.
    assert receiver.record == []

    # Receive-side calls go to the receive half only.
    await pair.receive_some(1234)
    assert receiver.record == [("receive_some", 1234)]
    assert sender.record == []
    receiver.record.clear()

    # Closing the pair closes both halves.
    await pair.aclose()
    closed_once = ["aclose"]
    assert receiver.record == closed_once
    assert sender.record == closed_once
async def test_StapledStream_with_erroring_close():
    """Check that both halves are closed even when each aclose raises."""

    # Each broken half still records its aclose call before blowing up, so
    # the logs show whether the call was attempted at all.
    class BrokenReceiveStream(RecordReceiveStream):
        async def aclose(self):
            await super().aclose()
            raise ValueError

    class BrokenSendStream(RecordSendStream):
        async def aclose(self):
            await super().aclose()
            raise ValueError

    pair = StapledStream(BrokenSendStream(), BrokenReceiveStream())

    with pytest.raises(ValueError) as caught:
        await pair.aclose()

    # The error that propagates must carry the other ValueError as its
    # implicit context, i.e. neither failure was silently dropped.
    chained = caught.value.__context__
    assert isinstance(chained, ValueError)

    # A failure in one aclose must not stop the other from being called.
    closed_once = ["aclose"]
    assert pair.send_stream.record == closed_once
    assert pair.receive_stream.record == closed_once