You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

729 lines
28 KiB
Python

# Wrapper module for _ssl. Written by Bill Janssen.
# Ported to gevent by Denis Bilenko.
"""SSL wrapper for socket objects on Python 2.7.9 and above.
For the documentation, refer to :mod:`ssl` module manual.
This module implements cooperative SSL socket wrappers.
"""
from __future__ import absolute_import
# Our import magic sadly makes this warning useless
# pylint: disable=undefined-variable
# pylint: disable=too-many-instance-attributes,too-many-locals,too-many-statements,too-many-branches
# pylint: disable=arguments-differ,too-many-public-methods
import ssl as __ssl__
_ssl = __ssl__._ssl # pylint:disable=no-member
import errno
from gevent._socket2 import socket
from gevent.socket import timeout_default
from gevent.socket import create_connection
from gevent.socket import error as socket_error
from gevent.socket import timeout as _socket_timeout
from gevent._compat import PYPY
from gevent._util import copy_globals
# Names that this module defines itself instead of re-exporting them from
# the standard library's ``ssl`` module.
__implements__ = [
    'SSLContext',
    'SSLSocket',
    'wrap_socket',
    'get_server_certificate',
    'create_default_context',
    '_create_unverified_context',
    '_create_default_https_context',
    '_create_stdlib_context',
    '_fileobject',
]

# Import all symbols from Python's ssl.py, except those that we are implementing
# and "private" symbols.
# NOTE(review): names used later in this file without a local definition
# (e.g. ``Purpose``, ``CERT_NONE``, ``SSLError``, ``closing``, ``sys``) are
# presumably brought in by this call -- confirm against ``copy_globals``.
__imports__ = copy_globals(
    __ssl__,
    globals(),
    # SSLSocket *must* subclass gevent.socket.socket; see issue 597 and 801
    names_to_ignore=__implements__ + ['socket', 'create_connection'],
    dunder_names_to_keep=(),
)

try:
    _delegate_methods
except NameError:  # PyPy doesn't expose this detail
    _delegate_methods = (
        'recv', 'recvfrom', 'recv_into', 'recvfrom_into',
        'send', 'sendto',
    )

# Public surface: everything implemented here plus everything copied over,
# minus ``namedtuple`` when the copy happened to pick it up.
__all__ = __implements__ + __imports__
try:
    __all__.remove('namedtuple')
except ValueError:
    # 'namedtuple' was not copied; nothing to strip.
    pass
# See notes in _socket2.py. Python 3 returns much nicer
# `io` object wrapped around a SocketIO class.
if hasattr(__ssl__, '_fileobject'):
    # The subclass below only makes sense while the stdlib file object is
    # not already a context manager.
    assert not hasattr(__ssl__._fileobject, '__enter__') # pylint:disable=no-member


class _fileobject(getattr(__ssl__, '_fileobject', object)): # pylint:disable=no-member
    """
    File object returned by :meth:`SSLSocket.makefile`, extended with the
    context-manager protocol.

    The base class is ``ssl._fileobject`` when the standard library
    provides it and plain ``object`` otherwise.
    """

    def __enter__(self):
        """Return the file object itself as the ``with`` target."""
        return self

    def __exit__(self, *args):
        """Close the file object on leaving the ``with`` block, unless it
        is already closed. Exceptions are never suppressed."""
        # pylint:disable=no-member
        if self.closed:
            return
        self.close()
# Keep a handle on the standard library class; the name ``SSLContext`` is
# rebound to the subclass defined below.
orig_SSLContext = __ssl__.SSLContext # pylint: disable=no-member


class SSLContext(orig_SSLContext):
    """
    ``ssl.SSLContext`` subclass whose :meth:`wrap_socket` produces this
    module's :class:`SSLSocket` instead of the standard library's.
    """

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None):
        """
        Wrap *sock* in an :class:`SSLSocket` bound to this context.

        All arguments are forwarded unchanged to the :class:`SSLSocket`
        constructor, with this context passed as ``_context``.
        """
        socket_options = dict(
            sock=sock,
            server_side=server_side,
            do_handshake_on_connect=do_handshake_on_connect,
            suppress_ragged_eofs=suppress_ragged_eofs,
            server_hostname=server_hostname,
            _context=self,
        )
        return SSLSocket(**socket_options)
def create_default_context(purpose=Purpose.SERVER_AUTH, cafile=None,
                           capath=None, cadata=None):
    """Create a SSLContext object with default settings.

    NOTE: The protocol and settings may change anytime without prior
    deprecation. The values represent a fair balance between maximum
    compatibility and security.

    :param purpose: A ``Purpose`` member. ``Purpose.SERVER_AUTH`` configures
        a client-side context that requires a certificate and checks the
        host name; ``Purpose.CLIENT_AUTH`` configures a server-side context
        with restricted ciphers.
    :param cafile: Forwarded to ``load_verify_locations`` when any of
        *cafile*, *capath* or *cadata* is truthy.
    :param capath: See *cafile*.
    :param cadata: See *cafile*.
    :return: The configured :class:`SSLContext`.
    :raises TypeError: If *purpose* is not an ``_ASN1Object``.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    context = SSLContext(PROTOCOL_SSLv23)

    # Hardening applied for every purpose:
    #  * SSLv2 considered harmful.
    #  * SSLv3 has problematic security and is only required for really old
    #    clients such as IE6 on Windows XP
    #  * disable compression to prevent CRIME attacks (OpenSSL 1.0+); the
    #    flag falls back to 0 when ``_ssl`` does not define it.
    context.options |= (
        OP_NO_SSLv2
        | OP_NO_SSLv3
        | getattr(_ssl, "OP_NO_COMPRESSION", 0)
    ) # pylint:disable=no-member

    if purpose == Purpose.SERVER_AUTH:
        # verify certs and host name in client mode
        context.verify_mode = CERT_REQUIRED
        context.check_hostname = True # pylint: disable=attribute-defined-outside-init
    elif purpose == Purpose.CLIENT_AUTH:
        # Prefer the server's ciphers by default so that we get stronger
        # encryption, and use single use keys in order to improve forward
        # secrecy. Each flag is skipped (0) when ``_ssl`` lacks it.
        for flag_name in ("OP_CIPHER_SERVER_PREFERENCE",
                          "OP_SINGLE_DH_USE",
                          "OP_SINGLE_ECDH_USE"):
            context.options |= getattr(_ssl, flag_name, 0) # pylint:disable=no-member

        # disallow ciphers with known vulnerabilities
        context.set_ciphers(_RESTRICTED_SERVER_CIPHERS)

    explicit_ca = cafile or capath or cadata
    if explicit_ca:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # no explicit cafile, capath or cadata but the verify mode is
        # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
        # root CA certificates for the given purpose. This may fail silently.
        context.load_default_certs(purpose)
    return context
def _create_unverified_context(protocol=PROTOCOL_SSLv23, cert_reqs=None,
                               check_hostname=False, purpose=Purpose.SERVER_AUTH,
                               certfile=None, keyfile=None,
                               cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place. The configuration
    is less restrictive than create_default_context()'s to increase backward
    compatibility.

    :param protocol: Protocol constant handed to :class:`SSLContext`.
    :param cert_reqs: When not None, assigned to ``context.verify_mode``.
    :param check_hostname: Assigned to ``context.check_hostname``.
    :param purpose: ``Purpose`` member used when loading the default
        certificates.
    :param certfile: Certificate chain file; loaded together with *keyfile*.
    :param keyfile: Private key file; only valid alongside *certfile*.
    :param cafile: Forwarded to ``load_verify_locations`` when any of
        *cafile*, *capath* or *cadata* is truthy.
    :param capath: See *cafile*.
    :param cadata: See *cafile*.
    :return: The configured :class:`SSLContext`.
    :raises TypeError: If *purpose* is not an ``_ASN1Object``.
    :raises ValueError: If *keyfile* is given without *certfile*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    context = SSLContext(protocol)

    # SSLv2 considered harmful; SSLv3 has problematic security and is only
    # required for really old clients such as IE6 on Windows XP.
    context.options |= OP_NO_SSLv2 | OP_NO_SSLv3 # pylint:disable=no-member

    if cert_reqs is not None:
        context.verify_mode = cert_reqs
    context.check_hostname = check_hostname # pylint: disable=attribute-defined-outside-init

    if keyfile and not certfile:
        raise ValueError("certfile must be specified")
    # Past the guard above a truthy keyfile implies a truthy certfile, so
    # testing certfile alone covers "certfile or keyfile".
    if certfile:
        context.load_cert_chain(certfile, keyfile)

    # load CA root certs
    explicit_ca = cafile or capath or cadata
    if explicit_ca:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # no explicit cafile, capath or cadata but the verify mode is
        # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
        # root CA certificates for the given purpose. This may fail silently.
        context.load_default_certs(purpose)

    return context
# Used by http.client if no context is explicitly passed.
# This is the verifying factory (create_default_context).
_create_default_https_context = create_default_context
# Backwards compatibility alias, even though it's not a public name.
# Note that it points at the *unverified* factory; get_server_certificate()
# in this module builds its context through this alias.
_create_stdlib_context = _create_unverified_context
class SSLSocket(socket):
    """
    gevent `ssl.SSLSocket <https://docs.python.org/2/library/ssl.html#ssl-sockets>`_
    for Pythons >= 2.7.9 but less than 3.

    TLS operations that cannot complete immediately (reported as
    ``SSLWantReadError`` / ``SSLWantWriteError`` or the matching
    ``SSL_ERROR_WANT_*`` codes) call ``self._wait`` on the socket's read or
    write event and are then retried. When the socket's timeout is ``0.0``
    (non-blocking) the error is re-raised instead, except in :meth:`send`,
    which returns 0.

    ``self._sslobj`` holds the low-level SSL object and is ``None`` while
    the socket is not wrapped: before connecting, and after
    :meth:`shutdown`, :meth:`unwrap` or the final :meth:`close`.
    """

    def __init__(self, sock=None, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
                 suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
                 server_hostname=None,
                 _context=None):
        """
        Wrap *sock* in an SSL channel.

        When *_context* is given it is used as-is and *keyfile*,
        *certfile*, *cert_reqs*, *ssl_version*, *ca_certs*, *npn_protocols*
        and *ciphers* are not consulted. Otherwise a new
        :class:`SSLContext` is built from those arguments.

        *family*, *type*, *proto* and *fileno* are accepted for signature
        compatibility and ignored.

        If *sock* is already connected the SSL object is created at once
        and, when *do_handshake_on_connect* is true, the handshake is
        performed before returning.

        :raises ValueError: For inconsistent certificate/key arguments,
            *server_hostname* in server mode, a hostname-checking context
            without *server_hostname*, or *do_handshake_on_connect* on a
            connected non-blocking socket.
        :raises NotImplementedError: If *sock* is not a stream socket.
        """
        # fileno is ignored
        # pylint: disable=unused-argument
        if _context:
            self._context = _context
        else:
            if server_side and not certfile:
                raise ValueError("certfile must be specified for server-side "
                                 "operations")
            if keyfile and not certfile:
                raise ValueError("certfile must be specified")
            if certfile and not keyfile:
                # A single file is expected to hold both key and certificate.
                keyfile = certfile
            self._context = SSLContext(ssl_version)
            self._context.verify_mode = cert_reqs
            if ca_certs:
                self._context.load_verify_locations(ca_certs)
            if certfile:
                self._context.load_cert_chain(certfile, keyfile)
            if npn_protocols:
                self._context.set_npn_protocols(npn_protocols)
            if ciphers:
                self._context.set_ciphers(ciphers)
            self.keyfile = keyfile
            self.certfile = certfile
            self.cert_reqs = cert_reqs
            self.ssl_version = ssl_version
            self.ca_certs = ca_certs
            self.ciphers = ciphers

        # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
        # mixed in.
        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
            raise NotImplementedError("only stream sockets are supported")

        if PYPY:
            socket.__init__(self, _sock=sock)
            sock._drop()
        else:
            # CPython: XXX: Must pass the underlying socket, not our
            # potential wrapper; test___example_servers fails the SSL test
            # with a client-side EOF error. (Why?)
            socket.__init__(self, _sock=sock._sock)

        # The initializer for socket overrides the methods send(), recv(), etc.
        # in the instance, which we don't need -- but we want to provide the
        # methods defined in SSLSocket.
        for attr in _delegate_methods:
            try:
                delattr(self, attr)
            except AttributeError:
                pass

        if server_side and server_hostname:
            raise ValueError("server_hostname can only be specified "
                             "in client mode")
        if self._context.check_hostname and not server_hostname:
            raise ValueError("check_hostname requires server_hostname")

        self.server_side = server_side
        self.server_hostname = server_hostname
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.settimeout(sock.gettimeout())

        # See if we are connected
        try:
            self.getpeername()
        except socket_error as e:
            if e.errno != errno.ENOTCONN:
                raise
            connected = False
        else:
            connected = True

        self._makefile_refs = 0
        self._closed = False
        self._sslobj = None
        self._connected = connected
        if connected:
            # create the SSL object
            try:
                self._sslobj = self._context._wrap_socket(self._sock, server_side,
                                                          server_hostname, ssl_sock=self)
                if do_handshake_on_connect:
                    timeout = self.gettimeout()
                    if timeout == 0.0:
                        # non-blocking
                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                    self.do_handshake()
            except socket_error as x:
                self.close()
                raise x

    @property
    def context(self):
        """The :class:`SSLContext` this socket is bound to."""
        return self._context

    @context.setter
    def context(self, ctx):
        self._context = ctx
        # ``_sslobj`` is None while the socket is not wrapped (not yet
        # connected, or after shutdown/unwrap/close). There is nothing to
        # rebind in that case; ``_real_connect`` reads ``self._context``
        # when it creates the SSL object, so the new context is still used.
        if self._sslobj is not None:
            self._sslobj.context = ctx

    def dup(self):
        """Always raises :exc:`NotImplementedError`."""
        raise NotImplementedError("Can't dup() %s instances" %
                                  self.__class__.__name__)

    def _checkClosed(self, msg=None):
        """Hook called at the start of most operations; a no-op here."""
        # raise an exception here if you wish to check for spurious closes
        pass

    def _check_connected(self):
        """Raise (via ``getpeername()``) if the socket is not connected."""
        if not self._connected:
            # getpeername() will raise ENOTCONN if the socket is really
            # not connected; note that we can be connected even without
            # _connected being set, e.g. if connect() first returned
            # EAGAIN.
            self.getpeername()

    def read(self, len=1024, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF.

        When *buffer* is given, data is read into it and the value returned
        by the SSL object's ``read(len, buffer)`` is returned instead; on a
        suppressed ragged EOF the result is then ``0`` rather than ``b''``.

        :raises ValueError: If the socket is not wrapped, or *len* is
            negative and no *buffer* was given.
        """
        self._checkClosed()

        while True:
            # Re-checked on every pass: the SSL object can be dropped by
            # shutdown()/close() while this call is waiting below.
            if not self._sslobj:
                raise ValueError("Read on closed or unwrapped SSL socket.")
            if len == 0:
                return b'' if buffer is None else 0
            if len < 0 and buffer is None:
                # This is handled natively in python 2.7.12+
                raise ValueError("Negative read length")
            try:
                if buffer is not None:
                    return self._sslobj.read(len, buffer)
                return self._sslobj.read(len or 1024)
            except SSLWantReadError:
                if self.timeout == 0.0:
                    raise
                self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout)
            except SSLWantWriteError:
                if self.timeout == 0.0:
                    raise
                # note: using _SSLErrorReadTimeout rather than _SSLErrorWriteTimeout below is intentional
                self._wait(self._write_event, timeout_exc=_SSLErrorReadTimeout)
            except SSLError as ex:
                if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                    return b'' if buffer is None else 0
                raise

    def write(self, data):
        """Write DATA to the underlying SSL channel. Returns
        number of bytes of DATA actually transmitted.

        :raises ValueError: If the socket is not wrapped.
        """
        self._checkClosed()

        while True:
            if not self._sslobj:
                raise ValueError("Write on closed or unwrapped SSL socket.")
            try:
                return self._sslobj.write(data)
            except SSLError as ex:
                code = ex.args[0]
                if code == SSL_ERROR_WANT_READ:
                    if self.timeout == 0.0:
                        raise
                    self._wait(self._read_event, timeout_exc=_SSLErrorWriteTimeout)
                elif code == SSL_ERROR_WANT_WRITE:
                    if self.timeout == 0.0:
                        raise
                    self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout)
                else:
                    raise

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""
        self._checkClosed()
        self._check_connected()
        return self._sslobj.peer_certificate(binary_form)

    def selected_npn_protocol(self):
        """Return the SSL object's selected NPN protocol, or None when the
        socket is not wrapped or ``_ssl.HAS_NPN`` is false."""
        self._checkClosed()
        if self._sslobj and _ssl.HAS_NPN:
            return self._sslobj.selected_npn_protocol()
        return None

    if hasattr(_ssl, 'HAS_ALPN'):
        # 2.7.10+
        def selected_alpn_protocol(self):
            """Return the SSL object's selected ALPN protocol, or None when
            the socket is not wrapped or ``_ssl.HAS_ALPN`` is false."""
            self._checkClosed()
            if self._sslobj and _ssl.HAS_ALPN: # pylint:disable=no-member
                return self._sslobj.selected_alpn_protocol()
            return None

    def cipher(self):
        """Return the SSL object's ``cipher()`` result, or None when the
        socket is not wrapped."""
        self._checkClosed()
        if not self._sslobj:
            return None
        return self._sslobj.cipher()

    def compression(self):
        """Return the SSL object's ``compression()`` result, or None when
        the socket is not wrapped."""
        self._checkClosed()
        if not self._sslobj:
            return None
        return self._sslobj.compression()

    def __check_flags(self, meth, flags):
        """Raise :exc:`ValueError` naming *meth* unless *flags* is 0."""
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to %s on %s" %
                (meth, self.__class__))

    def send(self, data, flags=0, timeout=timeout_default):
        """
        Send *data* and return the number of bytes written.

        On a wrapped socket the data goes through the SSL object; if the
        write would block on a non-blocking socket (timeout ``0.0``), 0 is
        returned. On an unwrapped socket the call is delegated to the plain
        socket ``send``. *flags* must be 0.
        """
        self._checkClosed()
        self.__check_flags('send', flags)

        if timeout is timeout_default:
            timeout = self.timeout

        if not self._sslobj:
            # Only this plain-socket path uses the resolved ``timeout``.
            return socket.send(self, data, flags, timeout)

        while True:
            try:
                return self._sslobj.write(data)
            except SSLWantReadError:
                if self.timeout == 0.0:
                    return 0
                self._wait(self._read_event)
            except SSLWantWriteError:
                if self.timeout == 0.0:
                    return 0
                self._wait(self._write_event)

    def sendto(self, data, flags_or_addr, addr=None):
        """Delegate to the plain socket ``sendto``; not allowed on a
        wrapped socket (raises :exc:`ValueError`)."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        if addr is None:
            return socket.sendto(self, data, flags_or_addr)
        return socket.sendto(self, data, flags_or_addr, addr)

    def sendmsg(self, *args, **kwargs):
        """Always raises :exc:`NotImplementedError`."""
        # Ensure programs don't send data unencrypted if they try to
        # use this method.
        raise NotImplementedError("sendmsg not allowed on instances of %s" %
                                  self.__class__)

    def sendall(self, data, flags=0):
        """
        Send all of *data* through the base class ``sendall``.

        A ``socket.timeout`` from the base class is converted: to
        ``SSLWantWriteError`` on a non-blocking socket, otherwise to an
        ``SSLError`` carrying the same arguments. *flags* must be 0.
        """
        self._checkClosed()
        self.__check_flags('sendall', flags)
        try:
            socket.sendall(self, data)
        except _socket_timeout as ex:
            if self.timeout == 0.0:
                # Python 2 simply *hangs* in this case, which is bad, but
                # Python 3 raises SSLWantWriteError. We do the same.
                raise SSLWantWriteError("The operation did not complete (write)")
            # Convert the socket.timeout back to the sslerror
            raise SSLError(*ex.args)

    def recv(self, buflen=1024, flags=0):
        """Receive up to *buflen* bytes; on a wrapped socket *flags* must
        be 0 and the data is read through :meth:`read`."""
        self._checkClosed()
        if not self._sslobj:
            return socket.recv(self, buflen, flags)
        self.__check_flags('recv()', flags)
        if buflen == 0:
            return b''
        return self.read(buflen)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """Receive into *buffer*; on a wrapped socket *flags* must be 0 and
        the data is read through :meth:`read`.

        When *nbytes* is None it defaults to ``len(buffer)``, or to 1024 if
        *buffer* is None.
        """
        self._checkClosed()
        if nbytes is None:
            # Fix for python bug #23804: bool(bytearray()) is False,
            # but we should read 0 bytes.
            nbytes = len(buffer) if buffer is not None else 1024
        if not self._sslobj:
            return socket.recv_into(self, buffer, nbytes, flags)
        self.__check_flags('recv_into()', flags)
        return self.read(nbytes, buffer)

    def recvfrom(self, buflen=1024, flags=0):
        """Delegate to the plain socket ``recvfrom``; not allowed on a
        wrapped socket (raises :exc:`ValueError`)."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        return socket.recvfrom(self, buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):
        """Delegate to the plain socket ``recvfrom_into``; not allowed on a
        wrapped socket (raises :exc:`ValueError`)."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        return socket.recvfrom_into(self, buffer, nbytes, flags)

    def recvmsg(self, *args, **kwargs):
        """Always raises :exc:`NotImplementedError`."""
        raise NotImplementedError("recvmsg not allowed on instances of %s" %
                                  self.__class__)

    def recvmsg_into(self, *args, **kwargs):
        """Always raises :exc:`NotImplementedError`."""
        raise NotImplementedError("recvmsg_into not allowed on instances of "
                                  "%s" % self.__class__)

    def pending(self):
        """Return the SSL object's ``pending()`` value, or 0 when the
        socket is not wrapped."""
        self._checkClosed()
        if not self._sslobj:
            return 0
        return self._sslobj.pending()

    def shutdown(self, how):
        """Drop the SSL object, then shut down the underlying socket."""
        self._checkClosed()
        self._sslobj = None
        socket.shutdown(self, how)

    def close(self):
        """
        Close the socket, or release one ``makefile()`` reference.

        While references handed out by :meth:`makefile` are outstanding,
        each call only decrements the counter; the socket is really closed
        once the counter is below 1.
        """
        if self._makefile_refs >= 1:
            self._makefile_refs -= 1
            return
        self._sslobj = None
        socket.close(self)

    if PYPY:

        def _reuse(self):
            """Take one additional reference on the socket."""
            self._makefile_refs += 1

        def _drop(self):
            """Release one reference; close when none are outstanding."""
            if self._makefile_refs < 1:
                self.close()
            else:
                self._makefile_refs -= 1

    def _sslobj_shutdown(self):
        """
        Run the SSL object's ``shutdown()``, waiting and retrying while it
        reports WANT_READ / WANT_WRITE.

        Returns whatever ``shutdown()`` returns, or ``''`` on a suppressed
        ragged EOF.
        """
        # NOTE(review): ``sys`` is not imported explicitly in this module;
        # presumably it arrives through copy_globals() -- confirm.
        while True:
            try:
                return self._sslobj.shutdown()
            except SSLError as ex:
                code = ex.args[0]
                if code == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                    return ''
                if code == SSL_ERROR_WANT_READ:
                    if self.timeout == 0.0:
                        raise
                    sys.exc_clear()
                    self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout)
                elif code == SSL_ERROR_WANT_WRITE:
                    if self.timeout == 0.0:
                        raise
                    sys.exc_clear()
                    self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout)
                else:
                    raise

    def unwrap(self):
        """
        Shut down the SSL layer and return a plain socket.

        :raises ValueError: If the socket is not wrapped.
        """
        if not self._sslobj:
            raise ValueError("No SSL wrapper around " + str(self))

        # NOTE(review): on a suppressed ragged EOF ``_sslobj_shutdown``
        # returns '' and that value is what gets passed as ``_sock`` below.
        s = self._sslobj_shutdown()
        self._sslobj = None
        # match _ssl2; critical to drop/reuse here on PyPy
        # XXX: _ssl3 returns an SSLSocket. Is that what the standard lib does on
        # Python 2? Should we do that?
        return socket(_sock=s)

    def _real_close(self):
        """Drop the SSL object and run the base class ``_real_close``."""
        self._sslobj = None
        socket._real_close(self) # pylint: disable=no-member

    def do_handshake(self):
        """Perform a TLS/SSL handshake.

        After the handshake completes, the peer certificate is matched
        against ``server_hostname`` when the context has ``check_hostname``
        enabled.

        :raises ValueError: If hostname checking is enabled but no
            ``server_hostname`` was configured.
        """
        self._check_connected()
        while True:
            try:
                self._sslobj.do_handshake()
                break
            except SSLWantReadError:
                if self.timeout == 0.0:
                    raise
                self._wait(self._read_event, timeout_exc=_SSLErrorHandshakeTimeout)
            except SSLWantWriteError:
                if self.timeout == 0.0:
                    raise
                self._wait(self._write_event, timeout_exc=_SSLErrorHandshakeTimeout)

        if self._context.check_hostname:
            if not self.server_hostname:
                raise ValueError("check_hostname needs server_hostname "
                                 "argument")
            match_hostname(self.getpeercert(), self.server_hostname)

    def _real_connect(self, addr, connect_ex):
        """
        Shared implementation of :meth:`connect` and :meth:`connect_ex`.

        Creates the SSL object, connects the underlying socket (through
        ``connect_ex`` when *connect_ex* is true, ``connect`` otherwise)
        and, on success, performs the handshake if
        ``do_handshake_on_connect`` is set. Returns the ``connect_ex``
        result code, or None for a plain ``connect``.

        :raises ValueError: In server-side mode or when already connected.
        """
        if self.server_side:
            raise ValueError("can't connect in server-side mode")
        # Here we assume that the socket is client-side, and not
        # connected at the time of the call. We connect it, then wrap it.
        if self._connected:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        self._sslobj = self._context._wrap_socket(self._sock, False, self.server_hostname, ssl_sock=self)
        try:
            if connect_ex:
                rc = socket.connect_ex(self, addr)
            else:
                rc = None
                socket.connect(self, addr)
            if not rc:
                self._connected = True
                if self.do_handshake_on_connect:
                    self.do_handshake()
            return rc
        except socket_error:
            # Leave the socket unwrapped after a failed connect/handshake.
            self._sslobj = None
            raise

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        self._real_connect(addr, False)

    def connect_ex(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        return self._real_connect(addr, True)

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""
        newsock, addr = socket.accept(self)
        newsock._drop_events_and_close(closefd=False) # Why, again?
        newsock = self._context.wrap_socket(newsock,
                                            do_handshake_on_connect=self.do_handshake_on_connect,
                                            suppress_ragged_eofs=self.suppress_ragged_eofs,
                                            server_side=True)
        return newsock, addr

    def makefile(self, mode='r', bufsize=-1):
        """Make and return a file-like object that
        works with the SSL connection. Just use the code
        from the socket module."""
        if not PYPY:
            self._makefile_refs += 1
        # close=True so as to decrement the reference count when done with
        # the file-like object.
        return _fileobject(self, mode, bufsize, close=True)

    def get_channel_binding(self, cb_type="tls-unique"):
        """Get channel binding data for current connection. Raise ValueError
        if the requested `cb_type` is not supported. Return bytes of the data
        or None if the data is not available (e.g. before the handshake).
        """
        if cb_type not in CHANNEL_BINDING_TYPES:
            raise ValueError("Unsupported channel binding type")
        if cb_type != "tls-unique":
            raise NotImplementedError(
                "{0} channel binding type not implemented"
                .format(cb_type))
        if self._sslobj is None:
            return None
        return self._sslobj.tls_unique_cb()

    def version(self):
        """
        Return a string identifying the protocol version used by the
        current SSL channel, or None if there is no established channel.
        """
        if self._sslobj is None:
            return None
        return self._sslobj.version()
if PYPY or not hasattr(SSLSocket, 'timeout'):
    # PyPy (and certain versions of CPython) doesn't have a direct
    # 'timeout' property on raw sockets, because that's not part of
    # the documented specification. We may wind up wrapping a raw
    # socket (when ssl is used with PyWSGI) or a gevent socket, which
    # does have a read/write timeout property as an alias for
    # get/settimeout, so make sure that's always the case because
    # pywsgi can depend on that.
    SSLSocket.timeout = property(
        fget=lambda self: self.gettimeout(),
        fset=lambda self, value: self.settimeout(value),
    )
# Module-level SSLError instances that SSLSocket passes as ``timeout_exc``
# to ``self._wait`` (read/shutdown, write/shutdown and handshake paths
# respectively). They are defined after the class body but are only looked
# up when those methods run.
_SSLErrorReadTimeout = SSLError('The read operation timed out')
_SSLErrorWriteTimeout = SSLError('The write operation timed out')
_SSLErrorHandshakeTimeout = SSLError('The handshake operation timed out')
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """
    Wrap *sock* in an :class:`SSLSocket`.

    Every argument is forwarded unchanged, by keyword, to the
    :class:`SSLSocket` constructor, which builds its own context from them.
    """
    socket_options = dict(
        sock=sock,
        keyfile=keyfile,
        certfile=certfile,
        server_side=server_side,
        cert_reqs=cert_reqs,
        ssl_version=ssl_version,
        ca_certs=ca_certs,
        do_handshake_on_connect=do_handshake_on_connect,
        suppress_ragged_eofs=suppress_ragged_eofs,
        ciphers=ciphers,
    )
    return SSLSocket(**socket_options)
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None):
    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.

    *addr* must unpack into exactly two items; it is passed as-is to
    ``create_connection``.
    """
    # Unpacking is done purely to reject an *addr* that is not a 2-item
    # sequence before any connection is attempted.
    _, _ = addr

    # Only demand a verified certificate when there is something to
    # verify it against.
    cert_reqs = CERT_REQUIRED if ca_certs is not None else CERT_NONE

    context = _create_stdlib_context(ssl_version,
                                     cert_reqs=cert_reqs,
                                     cafile=ca_certs)
    with closing(create_connection(addr)) as plain_sock:
        with closing(context.wrap_socket(plain_sock)) as tls_sock:
            der_bytes = tls_sock.getpeercert(True)
    return DER_cert_to_PEM_cert(der_bytes)