You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
303 lines
9.7 KiB
Python
303 lines
9.7 KiB
Python
import os
|
|
import platform
|
|
import socket
|
|
import ssl
|
|
import typing
|
|
|
|
import _ssl # type: ignore[import]
|
|
|
|
from ._ssl_constants import (
|
|
_original_SSLContext,
|
|
_original_super_SSLContext,
|
|
_truststore_SSLContext_dunder_class,
|
|
_truststore_SSLContext_super_class,
|
|
)
|
|
|
|
if platform.system() == "Windows":
|
|
from ._windows import _configure_context, _verify_peercerts_impl
|
|
elif platform.system() == "Darwin":
|
|
from ._macos import _configure_context, _verify_peercerts_impl
|
|
else:
|
|
from ._openssl import _configure_context, _verify_peercerts_impl
|
|
|
|
if typing.TYPE_CHECKING:
|
|
from pip._vendor.typing_extensions import Buffer
|
|
|
|
# From typeshed/stdlib/ssl.pyi
|
|
_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes]
|
|
_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes]
|
|
|
|
|
|
def inject_into_ssl() -> None:
    """Swap :class:`truststore.SSLContext` in for :class:`ssl.SSLContext`.

    Rebinds ``SSLContext`` on the ``ssl`` module and, when the vendored
    urllib3 can be imported, on ``pip._vendor.urllib3.util.ssl_`` as well.
    """
    setattr(ssl, "SSLContext", SSLContext)

    # urllib3 keeps its own reference to ssl.SSLContext, so rebinding the
    # name on the ssl module alone would not reach it.
    try:
        import pip._vendor.urllib3.util.ssl_ as vendored_ssl_util
    except ImportError:
        # Vendored urllib3 not importable: nothing further to patch.
        pass
    else:
        setattr(vendored_ssl_util, "SSLContext", SSLContext)
|
|
|
|
|
|
def extract_from_ssl() -> None:
    """Put the original :class:`ssl.SSLContext` back in place.

    Reverses :func:`inject_into_ssl` by rebinding ``SSLContext`` to
    ``_original_SSLContext`` on the ``ssl`` module and, when importable,
    on the vendored ``pip._vendor.urllib3.util.ssl_`` module.
    """
    setattr(ssl, "SSLContext", _original_SSLContext)

    try:
        import pip._vendor.urllib3.util.ssl_ as vendored_ssl_util
    except ImportError:
        # Vendored urllib3 not importable: nothing further to restore.
        pass
    else:
        setattr(vendored_ssl_util, "SSLContext", _original_SSLContext)
|
|
|
|
|
|
class SSLContext(_truststore_SSLContext_super_class):  # type: ignore[misc]
    """SSLContext API that uses system certificates on all platforms.

    Every instance owns a private ``_original_SSLContext`` stored in
    ``self._ctx``. The methods and properties defined in this class forward
    to that inner context. Sockets and SSL objects are created by the inner
    context inside ``_configure_context(self._ctx)``, and the peer's
    certificate chain is then checked by ``_verify_peercerts`` (directly in
    ``wrap_socket``, or after the handshake via the ``TruststoreSSLObject``
    created in ``__init__``).

    Anything *not* defined here is inherited from
    ``_truststore_SSLContext_super_class`` and acts on this wrapper object.
    ``wrap_socket()``/``wrap_bio()`` never use the wrapper's own settings,
    so every setting that must affect real connections needs an explicit
    forwarder in this class.
    """

    @property  # type: ignore[misc]
    def __class__(self) -> type:
        # Dirty hack to get around isinstance() checks
        # for ssl.SSLContext instances in aiohttp/trustme
        # when using non-CPython implementations.
        return _truststore_SSLContext_dunder_class or SSLContext

    def __init__(self, protocol: int = None) -> None:  # type: ignore[assignment]
        # All TLS work is done by this inner context; this object only
        # forwards to it and adds OS-trust-store verification.
        self._ctx = _original_SSLContext(protocol)

        class TruststoreSSLObject(ssl.SSLObject):
            # This object exists because wrap_bio() doesn't
            # immediately do the handshake so we need to do
            # certificate verifications after SSLObject.do_handshake()

            def do_handshake(self) -> None:
                ret = super().do_handshake()
                _verify_peercerts(self, server_hostname=self.server_hostname)
                return ret

        # sslobject_class is the type returned by SSLContext.wrap_bio(), so
        # every SSLObject produced by self._ctx verifies after its handshake.
        self._ctx.sslobject_class = TruststoreSSLObject

    def wrap_socket(
        self,
        sock: socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: str | None = None,
        session: ssl.SSLSession | None = None,
    ) -> ssl.SSLSocket:
        """Wrap *sock* with the inner context, then verify the peer chain.

        The returned socket is closed before re-raising if verification
        fails.
        """
        # Use a context manager here because the
        # inner SSLContext holds on to our state
        # but also does the actual handshake.
        with _configure_context(self._ctx):
            ssl_sock = self._ctx.wrap_socket(
                sock,
                server_side=server_side,
                server_hostname=server_hostname,
                do_handshake_on_connect=do_handshake_on_connect,
                suppress_ragged_eofs=suppress_ragged_eofs,
                session=session,
            )
        # NOTE(review): verification deliberately runs after the `with`
        # block has exited (matching upstream truststore); the platform
        # _configure_context implementations presumably adjust self._ctx
        # only for the duration of the block -- confirm in ._macos/._windows.
        try:
            _verify_peercerts(ssl_sock, server_hostname=server_hostname)
        except Exception:
            # Don't hand back (or leak) a socket that failed verification.
            ssl_sock.close()
            raise
        return ssl_sock

    def wrap_bio(
        self,
        incoming: ssl.MemoryBIO,
        outgoing: ssl.MemoryBIO,
        server_side: bool = False,
        server_hostname: str | None = None,
        session: ssl.SSLSession | None = None,
    ) -> ssl.SSLObject:
        """Create a BIO-backed SSL object from the inner context.

        No verification happens here because the handshake has not run
        yet; TruststoreSSLObject.do_handshake() performs it instead.
        """
        with _configure_context(self._ctx):
            ssl_obj = self._ctx.wrap_bio(
                incoming,
                outgoing,
                server_hostname=server_hostname,
                server_side=server_side,
                session=session,
            )
        return ssl_obj

    # -- Certificate / key loading, forwarded to the inner context --------

    def load_verify_locations(
        self,
        cafile: _StrOrBytesPath | None = None,
        capath: _StrOrBytesPath | None = None,
        cadata: typing.Union[str, "Buffer", None] = None,
    ) -> None:
        return self._ctx.load_verify_locations(
            cafile=cafile, capath=capath, cadata=cadata
        )

    def load_cert_chain(
        self,
        certfile: _StrOrBytesPath,
        keyfile: _StrOrBytesPath | None = None,
        password: _PasswordType | None = None,
    ) -> None:
        return self._ctx.load_cert_chain(
            certfile=certfile, keyfile=keyfile, password=password
        )

    def load_default_certs(
        self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH
    ) -> None:
        return self._ctx.load_default_certs(purpose)

    def set_default_verify_paths(self) -> None:
        """Forward to the inner context.

        Without this override, the inherited method (when the base class
        provides one) would configure this wrapper object, whose own
        settings are never used by wrap_socket()/wrap_bio().
        """
        return self._ctx.set_default_verify_paths()

    def load_dh_params(self, __dhfile: _StrOrBytesPath) -> None:
        """Forward to the inner context (see set_default_verify_paths)."""
        return self._ctx.load_dh_params(__dhfile)

    def set_ecdh_curve(self, __curve_name: str) -> None:
        """Forward to the inner context (see set_default_verify_paths)."""
        return self._ctx.set_ecdh_curve(__curve_name)

    # -- Protocol negotiation and ciphers ---------------------------------

    def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None:
        return self._ctx.set_alpn_protocols(alpn_protocols)

    def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None:
        return self._ctx.set_npn_protocols(npn_protocols)

    def set_ciphers(self, __cipherlist: str) -> None:
        return self._ctx.set_ciphers(__cipherlist)

    def get_ciphers(self) -> typing.Any:
        return self._ctx.get_ciphers()

    def session_stats(self) -> dict[str, int]:
        return self._ctx.session_stats()

    # -- Intentionally unsupported ----------------------------------------
    # NOTE(review): these are not forwarded to self._ctx, presumably because
    # trust decisions come from the OS store rather than the inner context's
    # certificate store -- confirm intent before changing.

    def cert_store_stats(self) -> dict[str, int]:
        raise NotImplementedError()

    @typing.overload
    def get_ca_certs(
        self, binary_form: typing.Literal[False] = ...
    ) -> list[typing.Any]:
        ...

    @typing.overload
    def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]:
        ...

    @typing.overload
    def get_ca_certs(self, binary_form: bool = ...) -> typing.Any:
        ...

    def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]:
        raise NotImplementedError()

    # -- Properties forwarded to the inner context ------------------------
    # The setters for maximum_version, minimum_version, options,
    # verify_flags and verify_mode call the descriptor on
    # _original_super_SSLContext directly instead of assigning
    # `self._ctx.<attr> = value`. NOTE(review): presumably this bypasses the
    # Python-level setters of the stdlib ssl.SSLContext, which look up the
    # module-global name `SSLContext` and would reach this class after
    # inject_into_ssl() -- confirm against ._ssl_constants.

    @property
    def check_hostname(self) -> bool:
        return self._ctx.check_hostname

    @check_hostname.setter
    def check_hostname(self, value: bool) -> None:
        self._ctx.check_hostname = value

    @property
    def hostname_checks_common_name(self) -> bool:
        return self._ctx.hostname_checks_common_name

    @hostname_checks_common_name.setter
    def hostname_checks_common_name(self, value: bool) -> None:
        self._ctx.hostname_checks_common_name = value

    @property
    def keylog_filename(self) -> str:
        return self._ctx.keylog_filename

    @keylog_filename.setter
    def keylog_filename(self, value: str) -> None:
        self._ctx.keylog_filename = value

    @property
    def maximum_version(self) -> ssl.TLSVersion:
        return self._ctx.maximum_version

    @maximum_version.setter
    def maximum_version(self, value: ssl.TLSVersion) -> None:
        _original_super_SSLContext.maximum_version.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )

    @property
    def minimum_version(self) -> ssl.TLSVersion:
        return self._ctx.minimum_version

    @minimum_version.setter
    def minimum_version(self, value: ssl.TLSVersion) -> None:
        _original_super_SSLContext.minimum_version.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )

    @property
    def options(self) -> ssl.Options:
        return self._ctx.options

    @options.setter
    def options(self, value: ssl.Options) -> None:
        _original_super_SSLContext.options.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )

    @property
    def post_handshake_auth(self) -> bool:
        return self._ctx.post_handshake_auth

    @post_handshake_auth.setter
    def post_handshake_auth(self, value: bool) -> None:
        self._ctx.post_handshake_auth = value

    @property
    def protocol(self) -> ssl._SSLMethod:
        return self._ctx.protocol

    @property
    def security_level(self) -> int:
        return self._ctx.security_level

    @property
    def verify_flags(self) -> ssl.VerifyFlags:
        return self._ctx.verify_flags

    @verify_flags.setter
    def verify_flags(self, value: ssl.VerifyFlags) -> None:
        _original_super_SSLContext.verify_flags.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )

    @property
    def verify_mode(self) -> ssl.VerifyMode:
        return self._ctx.verify_mode

    @verify_mode.setter
    def verify_mode(self, value: ssl.VerifyMode) -> None:
        _original_super_SSLContext.verify_mode.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )
|
|
|
|
|
|
def _verify_peercerts(
|
|
sock_or_sslobj: ssl.SSLSocket | ssl.SSLObject, server_hostname: str | None
|
|
) -> None:
|
|
"""
|
|
Verifies the peer certificates from an SSLSocket or SSLObject
|
|
against the certificates in the OS trust store.
|
|
"""
|
|
sslobj: ssl.SSLObject = sock_or_sslobj # type: ignore[assignment]
|
|
try:
|
|
while not hasattr(sslobj, "get_unverified_chain"):
|
|
sslobj = sslobj._sslobj # type: ignore[attr-defined]
|
|
except AttributeError:
|
|
pass
|
|
|
|
# SSLObject.get_unverified_chain() returns 'None'
|
|
# if the peer sends no certificates. This is common
|
|
# for the server-side scenario.
|
|
unverified_chain: typing.Sequence[_ssl.Certificate] = (
|
|
sslobj.get_unverified_chain() or () # type: ignore[attr-defined]
|
|
)
|
|
cert_bytes = [cert.public_bytes(_ssl.ENCODING_DER) for cert in unverified_chain]
|
|
_verify_peercerts_impl(
|
|
sock_or_sslobj.context, cert_bytes, server_hostname=server_hostname
|
|
)
|