You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
444 lines
16 KiB
Python
444 lines
16 KiB
Python
"""Base implementation of 0MQ authentication."""
|
|
|
|
# Copyright (C) PyZMQ Developers
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
import logging
|
|
import os
|
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
|
|
import zmq
|
|
from zmq.error import _check_version
|
|
from zmq.utils import z85
|
|
|
|
from .certs import load_certificates
|
|
|
|
# Sentinel `location` value for Authenticator.configure_curve: when the given
# location equals this value, every CURVE client key is accepted without lookup.
CURVE_ALLOW_ANY = '*'
|
|
# ZAP version frame: incoming requests must carry exactly this value as their
# first frame, and it is sent back as the first frame of every reply.
VERSION = b'1.0'
|
|
|
|
|
|
class Authenticator:
    """Implementation of ZAP authentication for zmq connections.

    This authenticator class does not register with an event loop. As a result,
    you will need to manually call `handle_zap_message`::

        auth = zmq.Authenticator()
        auth.allow("127.0.0.1")
        auth.start()
        while True:
            auth.handle_zap_message(auth.zap_socket.recv_multipart())

    Alternatively, you can register `auth.zap_socket` with a poller.

    Since many users will want to run ZAP in a way that does not block the
    main thread, other authentication classes (such as :mod:`zmq.auth.thread`)
    are provided.

    Note:

    - libzmq provides four levels of security: default NULL (which the Authenticator does
      not see), and authenticated NULL, PLAIN, CURVE, and GSSAPI, which the Authenticator can see.
    - until you add policies, all incoming NULL connections are allowed
      (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.
    - GSSAPI requires no configuration.
    """

    # Context used by start() to create the ZAP socket.
    context: "zmq.Context"
    # Text encoding used to decode request frames and encode the reply user-id.
    encoding: str
    # When True, every CURVE client key is accepted (set by configure_curve).
    allow_any: bool
    # Per-domain objects exposing `callback(domain, key)` for CURVE checks.
    credentials_providers: Dict[str, Any]
    # REP socket bound by start(); None before start() and after stop().
    zap_socket: "zmq.Socket"
    # Address allow/deny sets; allow() and deny() keep them mutually exclusive.
    whitelist: Set[str]
    blacklist: Set[str]
    # domain -> {username: password}
    passwords: Dict[str, Dict[str, str]]
    # domain -> {z85 public key: ...} as returned by load_certificates()
    certs: Dict[str, Dict[bytes, Any]]
    log: Any

    def __init__(
        self,
        context: Optional["zmq.Context"] = None,
        encoding: str = 'utf-8',
        log: Any = None,
    ):
        """Set up an authenticator with no policies configured.

        Parameters
        ----------
        context:
            Context used to create the ZAP socket; falls back to
            ``zmq.Context.instance()`` when not given.
        encoding:
            Encoding applied to text frames (domain, address, PLAIN credentials)
            and to the user-id sent in replies.
        log:
            Logger-like object; falls back to ``logging.getLogger('zmq.auth')``.
        """
        _check_version((4, 0), "security")
        self.context = context or zmq.Context.instance()
        self.encoding = encoding
        self.allow_any = False
        self.credentials_providers = {}
        self.zap_socket = None  # type: ignore
        self.whitelist = set()
        self.blacklist = set()
        # passwords is a dict keyed by domain and contains values
        # of dicts with username:password pairs.
        self.passwords = {}
        # certs is dict keyed by domain and contains values
        # of dicts keyed by the public keys from the specified location.
        self.certs = {}
        self.log = log or logging.getLogger('zmq.auth')

    def start(self) -> None:
        """Create and bind the ZAP socket"""
        self.zap_socket = self.context.socket(zmq.REP)
        self.zap_socket.linger = 1
        self.zap_socket.bind("inproc://zeromq.zap.01")
        self.log.debug("Starting")

    def stop(self) -> None:
        """Close the ZAP socket (no-op if it is not open)"""
        if self.zap_socket:
            self.zap_socket.close()
        self.zap_socket = None  # type: ignore

    def allow(self, *addresses: str) -> None:
        """Allow (whitelist) IP address(es).

        Connections from addresses not in the whitelist will be rejected.

        - For NULL, all clients from this address will be accepted.
        - For real auth setups, they will be allowed to continue with authentication.

        whitelist is mutually exclusive with blacklist.

        Raises
        ------
        ValueError
            If a blacklist has already been configured.
        """
        if self.blacklist:
            raise ValueError("Only use a whitelist or a blacklist, not both")
        self.log.debug("Allowing %s", ','.join(addresses))
        self.whitelist.update(addresses)

    def deny(self, *addresses: str) -> None:
        """Deny (blacklist) IP address(es).

        Addresses not in the blacklist will be allowed to continue with authentication.

        Blacklist is mutually exclusive with whitelist.

        Raises
        ------
        ValueError
            If a whitelist has already been configured.
        """
        if self.whitelist:
            raise ValueError("Only use a whitelist or a blacklist, not both")
        self.log.debug("Denying %s", ','.join(addresses))
        self.blacklist.update(addresses)

    def configure_plain(
        self, domain: str = '*', passwords: Optional[Dict[str, str]] = None
    ) -> None:
        """Configure PLAIN authentication for a given domain.

        PLAIN authentication checks the client's username and password
        against the given ``passwords`` mapping (username -> password).
        To cover all domains, use "*".

        The mapping is stored by reference for the domain; call this method
        again to replace it. A missing or empty mapping leaves any existing
        configuration for the domain untouched.
        """
        if passwords:
            self.passwords[domain] = passwords
        self.log.debug("Configure plain: %s", domain)

    def configure_curve(
        self, domain: str = '*', location: Union[str, os.PathLike] = "."
    ) -> None:
        """Configure CURVE authentication for a given domain.

        CURVE authentication uses a directory that holds all public client certificates,
        i.e. their public keys.

        To cover all domains, use "*".

        You can add and remove certificates in that directory at any time. configure_curve must be called
        every time certificates are added or removed, in order to update the Authenticator's state

        To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location.

        Note that the allow-any switch is a single flag on the authenticator,
        not a per-domain setting: each call overwrites it.
        """
        # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise
        # treat location as a directory that holds the certificates.
        self.log.debug("Configure curve: %s[%s]", domain, location)
        if location == CURVE_ALLOW_ANY:
            self.allow_any = True
        else:
            self.allow_any = False
            try:
                self.certs[domain] = load_certificates(location)
            except Exception as e:
                # Deliberately best-effort: a failed load is logged and any
                # certificates previously stored for this domain are kept.
                self.log.error("Failed to load CURVE certs from %s: %s", location, e)

    def configure_curve_callback(
        self, domain: str = '*', credentials_provider: Any = None
    ) -> None:
        """Configure CURVE authentication for a given domain.

        CURVE authentication using a callback function validating
        the client public key according to a custom mechanism, e.g. checking the
        key against records in a db. credentials_provider is an object of a class which
        implements a callback method accepting two parameters (domain and key), e.g.::

            class CredentialsProvider(object):

                def __init__(self):
                    ...e.g. db connection

                def callback(self, domain, key):
                    valid = ...lookup key and/or domain in db
                    if valid:
                        logging.info('Authorizing: {0}, {1}'.format(domain, key))
                        return True
                    else:
                        logging.warning('NOT Authorizing: {0}, {1}'.format(domain, key))
                        return False

        The key handed to the callback is the z85-encoded client public key.

        To cover all domains, use "*".

        Calling this method always switches off the allow-any mode that
        ``configure_curve(location=CURVE_ALLOW_ANY)`` enables. Passing
        ``None`` as the provider only logs an error.
        """
        self.allow_any = False

        if credentials_provider is not None:
            self.credentials_providers[domain] = credentials_provider
        else:
            self.log.error("None credentials_provider provided for domain:%s", domain)

    def curve_user_id(self, client_public_key: bytes) -> str:
        """Return the User-Id corresponding to a CURVE client's public key

        Default implementation uses the z85-encoding of the public key.

        Override to define a custom mapping of public key : user-id

        This is only called on successful authentication.

        Parameters
        ----------
        client_public_key: bytes
            The client public key used for the given message

        Returns
        -------
        user_id: unicode
            The user ID as text
        """
        return z85.encode(client_public_key).decode('ascii')

    def configure_gssapi(
        self, domain: str = '*', location: Optional[str] = None
    ) -> None:
        """Configure GSSAPI authentication

        Currently this is a no-op because there is nothing to configure with GSSAPI.
        """

    def handle_zap_message(self, msg: List[bytes]) -> None:
        """Perform ZAP authentication

        ``msg`` is the multipart ZAP request. Its first six frames are
        version, request id, domain, address, identity and mechanism; any
        remaining frames are the mechanism-specific credentials.

        Exactly one reply is sent on ``zap_socket`` for every request that
        contains at least two frames (so that a request id is available).
        """
        if len(msg) < 6:
            self.log.error("Invalid ZAP message, not enough frames: %r", msg)
            if len(msg) < 2:
                self.log.error("Not enough information to reply")
            else:
                self._send_zap_reply(msg[1], b"400", b"Not enough frames")
            return

        version, request_id, domain, address, identity, mechanism = msg[:6]
        credentials = msg[6:]

        domain = domain.decode(self.encoding, 'replace')
        address = address.decode(self.encoding, 'replace')

        if version != VERSION:
            self.log.error("Invalid ZAP version: %r", msg)
            self._send_zap_reply(request_id, b"400", b"Invalid version")
            return

        self.log.debug(
            "version: %r, request_id: %r, domain: %r,"
            " address: %r, identity: %r, mechanism: %r",
            version,
            request_id,
            domain,
            address,
            identity,
            mechanism,
        )

        # Is the address explicitly whitelisted or blacklisted?
        allowed = False
        denied = False
        reason = b"NO ACCESS"

        if self.whitelist:
            if address in self.whitelist:
                allowed = True
                self.log.debug("PASSED (whitelist) address=%s", address)
            else:
                denied = True
                reason = b"Address not in whitelist"
                self.log.debug("DENIED (not in whitelist) address=%s", address)

        elif self.blacklist:
            if address in self.blacklist:
                denied = True
                reason = b"Address is blacklisted"
                self.log.debug("DENIED (blacklist) address=%s", address)
            else:
                allowed = True
                self.log.debug("PASSED (not in blacklist) address=%s", address)

        # Perform authentication mechanism-specific checks if necessary
        username = "anonymous"
        if not denied:

            if mechanism == b'NULL' and not allowed:
                # For NULL, we allow if the address wasn't blacklisted
                self.log.debug("ALLOWED (NULL)")
                allowed = True

            elif mechanism == b'PLAIN':
                # For PLAIN, even a whitelisted address must authenticate
                if len(credentials) != 2:
                    self.log.error("Invalid PLAIN credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                username, password = (
                    c.decode(self.encoding, 'replace') for c in credentials
                )
                allowed, reason = self._authenticate_plain(domain, username, password)

            elif mechanism == b'CURVE':
                # For CURVE, even a whitelisted address must authenticate
                if len(credentials) != 1:
                    self.log.error("Invalid CURVE credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                key = credentials[0]
                allowed, reason = self._authenticate_curve(domain, key)
                if allowed:
                    username = self.curve_user_id(key)

            elif mechanism == b'GSSAPI':
                if len(credentials) != 1:
                    self.log.error("Invalid GSSAPI credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                # use principal as user-id for now
                principal = credentials[0]
                # Decode leniently: a strict decode would raise on malformed
                # bytes before any reply is sent, leaving the request unanswered.
                username = principal.decode("utf8", 'replace')
                allowed, reason = self._authenticate_gssapi(domain, principal)

        if allowed:
            self._send_zap_reply(request_id, b"200", b"OK", username)
        else:
            self._send_zap_reply(request_id, b"400", reason)

    def _authenticate_plain(
        self, domain: str, username: str, password: str
    ) -> Tuple[bool, bytes]:
        """PLAIN ZAP authentication

        Returns ``(allowed, reason)``. ``reason`` is empty on success and
        otherwise names which lookup failed (domain, username or password),
        or states that no passwords are configured at all.
        """
        allowed = False
        reason = b""
        if self.passwords:
            # If no domain is specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.passwords:
                if username in self.passwords[domain]:
                    if password == self.passwords[domain][username]:
                        allowed = True
                    else:
                        reason = b"Invalid password"
                else:
                    reason = b"Invalid username"
            else:
                reason = b"Invalid domain"

            if allowed:
                # Never write the password itself to the log.
                self.log.debug(
                    "ALLOWED (PLAIN) domain=%s username=%s",
                    domain,
                    username,
                )
            else:
                self.log.debug("DENIED %s", reason)

        else:
            reason = b"No passwords defined"
            self.log.debug("DENIED (PLAIN) %s", reason)

        return allowed, reason

    def _authenticate_curve(self, domain: str, client_key: bytes) -> Tuple[bool, bytes]:
        """CURVE ZAP authentication

        Checks, in order: the allow-any flag, then the credentials-provider
        callbacks (if any provider is configured, certificates are not
        consulted), then the loaded certificates. Returns ``(allowed, reason)``.
        """
        allowed = False
        reason = b""
        if self.allow_any:
            allowed = True
            reason = b"OK"
            self.log.debug("ALLOWED (CURVE allow any client)")
        elif self.credentials_providers:
            # If no explicit domain is specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.credentials_providers:
                z85_client_key = z85.encode(client_key)
                # Callback to check if key is Allowed
                if self.credentials_providers[domain].callback(domain, z85_client_key):
                    allowed = True
                    reason = b"OK"
                else:
                    reason = b"Unknown key"

                status = "ALLOWED" if allowed else "DENIED"
                self.log.debug(
                    "%s (CURVE auth_callback) domain=%s client_key=%s",
                    status,
                    domain,
                    z85_client_key,
                )
            else:
                reason = b"Unknown domain"
        else:
            # If no explicit domain is specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.certs:
                # The certs dict stores keys in z85 format, convert binary key to z85 bytes
                z85_client_key = z85.encode(client_key)
                if self.certs[domain].get(z85_client_key):
                    allowed = True
                    reason = b"OK"
                else:
                    reason = b"Unknown key"

                status = "ALLOWED" if allowed else "DENIED"
                self.log.debug(
                    "%s (CURVE) domain=%s client_key=%s",
                    status,
                    domain,
                    z85_client_key,
                )
            else:
                reason = b"Unknown domain"

        return allowed, reason

    def _authenticate_gssapi(self, domain: str, principal: bytes) -> Tuple[bool, bytes]:
        """Nothing to do for GSSAPI, which has already been handled by an external service."""
        self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal)
        return True, b'OK'

    def _send_zap_reply(
        self,
        request_id: bytes,
        status_code: bytes,
        status_text: bytes,
        user_id: str = 'anonymous',
    ) -> None:
        """Send a ZAP reply to finish the authentication.

        The user-id frame is only populated for a ``b'200'`` status; any
        other status sends an empty user-id regardless of ``user_id``.
        """
        user_id_frame = user_id if status_code == b'200' else b''
        if isinstance(user_id_frame, str):
            user_id_frame = user_id_frame.encode(self.encoding, 'replace')
        metadata = b''  # not currently used
        self.log.debug("ZAP reply code=%s text=%s", status_code, status_text)
        reply = [VERSION, request_id, status_code, status_text, user_id_frame, metadata]
        self.zap_socket.send_multipart(reply)
|
|
|
|
|
|
# Public API of this module: the base authenticator and the CURVE allow-any sentinel.
__all__ = ['Authenticator', 'CURVE_ALLOW_ANY']
|