You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
222 lines
6.7 KiB
Python
222 lines
6.7 KiB
Python
2 years ago
|
import io
|
||
|
import socket
|
||
|
import ssl
|
||
|
|
||
|
from ..exceptions import ProxySchemeUnsupported
|
||
|
from ..packages import six
|
||
|
|
||
|
# Maximum number of bytes requested from the wrapped socket in a single
# recv() call while SSLTransport feeds TLS records into its incoming BIO.
SSL_BLOCKSIZE = 16 * 1024
|
||
|
|
||
|
|
||
|
class SSLTransport:
    """
    The SSLTransport wraps an existing socket and establishes an SSL connection.

    Contrary to Python's implementation of SSLSocket, it allows you to chain
    multiple TLS connections together. It's particularly useful if you need to
    implement TLS within TLS.

    The class supports most of the socket API operations.

    Encrypted records are exchanged with the wrapped socket through two
    in-memory BIOs: bytes received from the socket are written into
    ``incoming``, and bytes produced by the SSL object are drained from
    ``outgoing`` and sent with ``socket.sendall``.
    """

    @staticmethod
    def _validate_ssl_context_for_tls_in_tls(ssl_context):
        """
        Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
        for TLS in TLS.

        The only requirement is that the ssl_context provides the 'wrap_bio'
        method.

        :param ssl_context: the context that will be used to build the inner
            TLS connection.
        :raises ProxySchemeUnsupported: if ``ssl_context`` has no ``wrap_bio``
            attribute. The message depends on whether this is Python 2.
        """
        if not hasattr(ssl_context, "wrap_bio"):
            if six.PY2:
                raise ProxySchemeUnsupported(
                    "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                    "supported on Python 2"
                )
            else:
                raise ProxySchemeUnsupported(
                    "TLS in TLS requires SSLContext.wrap_bio() which isn't "
                    "available on non-native SSLContext"
                )

    def __init__(
        self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
    ):
        """
        Create an SSLTransport around socket using the provided ssl_context.

        The TLS handshake is performed immediately, using blocking
        ``sendall``/``recv`` calls on ``socket``.

        :param socket: the underlying transport. It may itself be an
            SSLTransport; it only needs the socket-like methods this class
            delegates to (``sendall``, ``recv``, ``close``, ``fileno``, ...).
        :param ssl_context: context whose ``wrap_bio`` creates the SSL object.
        :param server_hostname: forwarded to ``ssl_context.wrap_bio``.
        :param suppress_ragged_eofs: if true, an ``SSL_ERROR_EOF`` raised while
            reading is reported as end-of-stream instead of being re-raised.
        """
        # NOTE: the ``socket`` parameter shadows the ``socket`` module inside
        # this method only; the module is not used here.
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()

        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket = socket

        self.sslobj = ssl_context.wrap_bio(
            self.incoming, self.outgoing, server_hostname=server_hostname
        )

        # Perform initial handshake.
        self._ssl_io_loop(self.sslobj.do_handshake)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()

    def fileno(self):
        """Return the file descriptor of the wrapped socket."""
        return self.socket.fileno()

    def read(self, len=1024, buffer=None):
        """
        Read up to ``len`` bytes of decrypted application data.

        :returns: the bytes read or, when ``buffer`` is given, the number of
            bytes written into it. A suppressed ragged EOF yields ``b""``
            (or ``0`` when ``buffer`` is given).
        """
        return self._wrap_ssl_read(len, buffer)

    def recv(self, len=1024, flags=0):
        """
        Socket-style ``recv``. Returns up to ``len`` decrypted bytes.

        :raises ValueError: if ``flags`` is non-zero (flags are unsupported).
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv")
        return self._wrap_ssl_read(len)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """
        Socket-style ``recv_into``. Reads decrypted data into ``buffer``.

        :param nbytes: maximum number of bytes to read. If omitted, it defaults
            to ``len(buffer)``, or to 1024 when ``buffer`` is empty/falsy.
        :returns: the number of bytes written into ``buffer``.
        :raises ValueError: if ``flags`` is non-zero.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to recv_into")
        if nbytes is None:
            nbytes = len(buffer) if buffer else 1024
        return self.read(nbytes, buffer)

    def sendall(self, data, flags=0):
        """
        Encrypt and send all of ``data``, calling ``send`` until every byte
        has been consumed by the SSL object.

        :raises ValueError: if ``flags`` is non-zero.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to sendall")
        count = 0
        # Cast to a flat byte view so counting and slicing are done in bytes,
        # whatever the item size of the original buffer.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v

    def send(self, data, flags=0):
        """
        Encrypt ``data`` and flush the resulting records to the socket.

        :returns: the value returned by ``sslobj.write`` (bytes consumed).
        :raises ValueError: if ``flags`` is non-zero.
        """
        if flags != 0:
            raise ValueError("non-zero flags not allowed in calls to send")
        return self._ssl_io_loop(self.sslobj.write, data)

    def makefile(
        self, mode="r", buffering=None, encoding=None, errors=None, newline=None
    ):
        """
        Python's httpclient uses makefile and buffered io when reading HTTP
        messages and we need to support it.

        This is unfortunately a copy and paste of socket.py makefile with small
        changes to point to the socket directly.
        """
        if not set(mode) <= {"r", "w", "b"}:
            raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))

        writing = "w" in mode
        reading = "r" in mode or not writing
        assert reading or writing
        binary = "b" in mode
        rawmode = ""
        if reading:
            rawmode += "r"
        if writing:
            rawmode += "w"
        # ``socket`` here is the stdlib module. SocketIO performs its reads
        # and writes through this transport, so the data is decrypted/encrypted.
        raw = socket.SocketIO(self, rawmode)
        # The stdlib SocketIO.close() calls self._decref_socketios(), which
        # forwards to the wrapped socket, so the reference is counted there.
        self.socket._io_refs += 1
        if buffering is None:
            buffering = -1
        if buffering < 0:
            buffering = io.DEFAULT_BUFFER_SIZE
        if buffering == 0:
            if not binary:
                raise ValueError("unbuffered streams must be binary")
            return raw
        if reading and writing:
            buffer = io.BufferedRWPair(raw, raw, buffering)
        elif reading:
            buffer = io.BufferedReader(raw, buffering)
        else:
            assert writing
            buffer = io.BufferedWriter(raw, buffering)
        if binary:
            return buffer
        text = io.TextIOWrapper(buffer, encoding, errors, newline)
        text.mode = mode
        return text

    def unwrap(self):
        """Run the TLS shutdown exchange over the wrapped socket."""
        self._ssl_io_loop(self.sslobj.unwrap)

    def close(self):
        """Close the wrapped socket. No TLS shutdown is performed here."""
        self.socket.close()

    def getpeercert(self, binary_form=False):
        return self.sslobj.getpeercert(binary_form)

    def version(self):
        return self.sslobj.version()

    def cipher(self):
        return self.sslobj.cipher()

    def selected_alpn_protocol(self):
        return self.sslobj.selected_alpn_protocol()

    def selected_npn_protocol(self):
        return self.sslobj.selected_npn_protocol()

    def shared_ciphers(self):
        return self.sslobj.shared_ciphers()

    def compression(self):
        return self.sslobj.compression()

    def settimeout(self, value):
        self.socket.settimeout(value)

    def gettimeout(self):
        return self.socket.gettimeout()

    def _decref_socketios(self):
        self.socket._decref_socketios()

    def _wrap_ssl_read(self, len, buffer=None):
        """
        Read through the SSL object, translating a ragged EOF into
        end-of-stream when ``suppress_ragged_eofs`` is set.

        :returns: bytes (no ``buffer``) or a byte count (with ``buffer``).
        """
        try:
            return self._ssl_io_loop(self.sslobj.read, len, buffer)
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
                # Report EOF with the same type a successful read would have:
                # a byte count for readinto-style calls, otherwise empty bytes
                # (this matches the stdlib SSLSocket.read()).
                return 0 if buffer is not None else b""
            raise

    def _ssl_io_loop(self, func, *args):
        """
        Performs an I/O loop between incoming/outgoing and the socket.

        Calls ``func(*args)`` repeatedly. After each call, any bytes it
        produced are flushed to the socket. On ``SSL_ERROR_WANT_READ``, up to
        ``SSL_BLOCKSIZE`` bytes are read from the socket into ``incoming``,
        and EOF from the socket is signalled with ``incoming.write_eof()``.
        Any other ``SSLError`` propagates.

        :returns: the result of the first call to ``func`` that did not raise.
        """
        should_loop = True
        ret = None

        while should_loop:
            errno = None
            try:
                ret = func(*args)
            except ssl.SSLError as e:
                if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
                    # WANT_READ, and WANT_WRITE are expected, others are not.
                    raise
                errno = e.errno

            # Flush whatever the SSL object produced (possibly nothing).
            buf = self.outgoing.read()
            self.socket.sendall(buf)

            if errno is None:
                should_loop = False
            elif errno == ssl.SSL_ERROR_WANT_READ:
                buf = self.socket.recv(SSL_BLOCKSIZE)
                if buf:
                    self.incoming.write(buf)
                else:
                    self.incoming.write_eof()
        return ret
|