use the standard socket library to validate the ip address argument

pull/961/head
Daniel Pavel 5 years ago
parent a334ef28e7
commit 99c6247baf

@ -1,7 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# This file is part of the Calibre-Web (https://github.com/janeczku/calibre-web) # This file is part of the Calibre-Web (https://github.com/janeczku/calibre-web)
# Copyright (C) 2018 OzzieIsaacs # Copyright (C) 2018 OzzieIsaacs
# #
@ -22,50 +20,17 @@ from __future__ import division, print_function, unicode_literals
import sys import sys
import os import os
import argparse import argparse
import socket
from .constants import CONFIG_DIR as _CONFIG_DIR from .constants import CONFIG_DIR as _CONFIG_DIR
from .constants import STABLE_VERSION as _STABLE_VERSION from .constants import STABLE_VERSION as _STABLE_VERSION
from .constants import NIGHTLY_VERSION as _NIGHTLY_VERSION from .constants import NIGHTLY_VERSION as _NIGHTLY_VERSION
VALID_CHARACTERS = 'ABCDEFabcdef:0123456789'
ipv6 = False
def version_info(): def version_info():
if _NIGHTLY_VERSION[1].startswith('$Format'): if _NIGHTLY_VERSION[1].startswith('$Format'):
return "Calibre-Web version: %s - unkown git-clone" % _STABLE_VERSION['version'] return "Calibre-Web version: %s - unkown git-clone" % _STABLE_VERSION['version']
else: return "Calibre-Web version: %s -%s" % (_STABLE_VERSION['version'], _NIGHTLY_VERSION[1])
return "Calibre-Web version: %s -%s" % (_STABLE_VERSION['version'],_NIGHTLY_VERSION[1])
def validate_ip4(address):
address_list = address.split('.')
if len(address_list) != 4:
return False
for val in address_list:
if not val.isdigit():
return False
i = int(val)
if i < 0 or i > 255:
return False
return True
def validate_ip6(address):
address_list = address.split(':')
return (
len(address_list) == 8
and all(len(current) <= 4 for current in address_list)
and all(current in VALID_CHARACTERS for current in address)
)
def validate_ip(address):
if validate_ip4(address) or ipv6:
return address
print("IP address is invalid. Exiting")
sys.exit(1)
parser = argparse.ArgumentParser(description='Calibre Web is a web app' parser = argparse.ArgumentParser(description='Calibre Web is a web app'
@ -95,8 +60,8 @@ if sys.version_info < (3, 0):
args.s = args.s.decode('utf-8') args.s = args.s.decode('utf-8')
settingspath = args.p or os.path.join(_CONFIG_DIR, "app.db") settingspath = args.p or os.path.join(_CONFIG_DIR, "app.db")
gdpath = args.g or os.path.join(_CONFIG_DIR, "gdrive.db") gdpath = args.g or os.path.join(_CONFIG_DIR, "gdrive.db")
# handle and check parameter for ssl encryption # handle and check parameter for ssl encryption
certfilepath = None certfilepath = None
@ -108,7 +73,7 @@ if args.c:
print("Certfilepath is invalid. Exiting...") print("Certfilepath is invalid. Exiting...")
sys.exit(1) sys.exit(1)
if args.c is "": if args.c == "":
certfilepath = "" certfilepath = ""
if args.k: if args.k:
@ -122,15 +87,26 @@ if (args.k and not args.c) or (not args.k and args.c):
print("Certfile and Keyfile have to be used together. Exiting...") print("Certfile and Keyfile have to be used together. Exiting...")
sys.exit(1) sys.exit(1)
if args.k is "": if args.k == "":
keyfilepath = "" keyfilepath = ""
# handle and check ipadress argument # handle and check ipadress argument
if args.i: ipadress = args.i or None
ipv6 = validate_ip6(args.i) if ipadress:
ipadress = validate_ip(args.i) try:
else: # try to parse the given ip address with socket
ipadress = None if hasattr(socket, 'inet_pton'):
if ':' in ipadress:
socket.inet_pton(socket.AF_INET6, ipadress)
else:
socket.inet_pton(socket.AF_INET, ipadress)
else:
# on windows python < 3.4, inet_pton is not available
# inet_atom only handles IPv4 addresses
socket.inet_aton(ipadress)
except socket.error as err:
print(ipadress, ':', err)
sys.exit(1)
# handle and check user password argument # handle and check user password argument
user_password = args.s or None user_password = args.s or None

@ -136,9 +136,6 @@ class _ConfigSQL(object):
def get_config_ipaddress(self): def get_config_ipaddress(self):
return cli.ipadress or "" return cli.ipadress or ""
def get_ipaddress_type(self):
return cli.ipv6
def _has_role(self, role_flag): def _has_role(self, role_flag):
return constants.has_flag(self.config_default_role, role_flag) return constants.has_flag(self.config_default_role, role_flag)

@ -128,4 +128,3 @@ NIGHTLY_VERSION[1] = '$Format:%cI$'
# clean-up the module namespace # clean-up the module namespace
del sys, os, namedtuple del sys, os, namedtuple

@ -43,7 +43,14 @@ from . import logger
log = logger.create() log = logger.create()
class WebServer:
def _readable_listen_address(address, port):
if ':' in address:
address = "[" + address + "]"
return '%s:%s' % (address, port)
class WebServer(object):
def __init__(self): def __init__(self):
signal.signal(signal.SIGINT, self._killServer) signal.signal(signal.SIGINT, self._killServer)
@ -55,14 +62,12 @@ class WebServer:
self.app = None self.app = None
self.listen_address = None self.listen_address = None
self.listen_port = None self.listen_port = None
self.IPV6 = False
self.unix_socket_file = None self.unix_socket_file = None
self.ssl_args = None self.ssl_args = None
def init_app(self, application, config): def init_app(self, application, config):
self.app = application self.app = application
self.listen_address = config.get_config_ipaddress() self.listen_address = config.get_config_ipaddress()
self.IPV6 = config.get_ipaddress_type()
self.listen_port = config.config_port self.listen_port = config.config_port
if config.config_access_log: if config.config_access_log:
@ -77,8 +82,7 @@ class WebServer:
keyfile_path = config.get_config_keyfile() keyfile_path = config.get_config_keyfile()
if certfile_path and keyfile_path: if certfile_path and keyfile_path:
if os.path.isfile(certfile_path) and os.path.isfile(keyfile_path): if os.path.isfile(certfile_path) and os.path.isfile(keyfile_path):
self.ssl_args = {"certfile": certfile_path, self.ssl_args = dict(certfile=certfile_path, keyfile=keyfile_path)
"keyfile": keyfile_path}
else: else:
log.warning('The specified paths for the ssl certificate file and/or key file seem to be broken. Ignoring ssl.') log.warning('The specified paths for the ssl certificate file and/or key file seem to be broken. Ignoring ssl.')
log.warning('Cert path: %s', certfile_path) log.warning('Cert path: %s', certfile_path)
@ -106,32 +110,33 @@ class WebServer:
if os.name != 'nt': if os.name != 'nt':
unix_socket_file = os.environ.get("CALIBRE_UNIX_SOCKET") unix_socket_file = os.environ.get("CALIBRE_UNIX_SOCKET")
if unix_socket_file: if unix_socket_file:
output = "socket:" + unix_socket_file + ":" + str(self.listen_port) return self._make_gevent_unix_socket(unix_socket_file), "unix:" + unix_socket_file
return self._make_gevent_unix_socket(unix_socket_file), output
if self.listen_address: if self.listen_address:
return (self.listen_address, self.listen_port), self._get_readable_listen_address() return (self.listen_address, self.listen_port), None
if os.name == 'nt': if os.name == 'nt':
self.listen_address = '0.0.0.0' self.listen_address = '0.0.0.0'
return (self.listen_address, self.listen_port), self._get_readable_listen_address() return (self.listen_address, self.listen_port), None
address = ('', self.listen_port)
try: try:
address = ('::', self.listen_port)
sock = WSGIServer.get_listener(address, family=socket.AF_INET6) sock = WSGIServer.get_listener(address, family=socket.AF_INET6)
output = self._get_readable_listen_address(True)
except socket.error as ex: except socket.error as ex:
log.error('%s', ex) log.error('%s', ex)
log.warning('Unable to listen on "", trying on IPv4 only...') log.warning('Unable to listen on "", trying on IPv4 only...')
output = self._get_readable_listen_address(False) address = ('', self.listen_port)
sock = WSGIServer.get_listener(address, family=socket.AF_INET) sock = WSGIServer.get_listener(address, family=socket.AF_INET)
return sock, output
return sock, _readable_listen_address(*address)
def _start_gevent(self): def _start_gevent(self):
ssl_args = self.ssl_args or {} ssl_args = self.ssl_args or {}
try: try:
sock, output = self._make_gevent_socket() sock, output = self._make_gevent_socket()
if output is None:
output = _readable_listen_address(self.listen_address, self.listen_port)
log.info('Starting Gevent server on %s', output) log.info('Starting Gevent server on %s', output)
self.wsgiserver = WSGIServer(sock, self.app, log=self.access_logger, spawn=Pool(), **ssl_args) self.wsgiserver = WSGIServer(sock, self.app, log=self.access_logger, spawn=Pool(), **ssl_args)
self.wsgiserver.serve_forever() self.wsgiserver.serve_forever()
@ -141,30 +146,18 @@ class WebServer:
self.unix_socket_file = None self.unix_socket_file = None
def _start_tornado(self): def _start_tornado(self):
log.info('Starting Tornado server on %s', self._get_readable_listen_address()) log.info('Starting Tornado server on %s', _readable_listen_address(self.listen_address, self.listen_port))
# Max Buffersize set to 200MB ) # Max Buffersize set to 200MB )
http_server = HTTPServer(WSGIContainer(self.app), http_server = HTTPServer(WSGIContainer(self.app),
max_buffer_size = 209700000, max_buffer_size=209700000,
ssl_options=self.ssl_args) ssl_options=self.ssl_args)
http_server.listen(self.listen_port, self.listen_address) http_server.listen(self.listen_port, self.listen_address)
self.wsgiserver=IOLoop.instance() self.wsgiserver = IOLoop.instance()
self.wsgiserver.start() self.wsgiserver.start()
# wait for stop signal # wait for stop signal
self.wsgiserver.close(True) self.wsgiserver.close(True)
def _get_readable_listen_address(self, ipV6=False):
if self.listen_address == "":
listen_string = '""'
else:
ipV6 = self.IPV6
listen_string = self.listen_address
if ipV6:
adress = "[" + listen_string + "]"
else:
adress = listen_string
return adress + ":" + str(self.listen_port)
def start(self): def start(self):
try: try:
if _GEVENT: if _GEVENT:
@ -191,7 +184,7 @@ class WebServer:
os.execv(sys.executable, arguments) os.execv(sys.executable, arguments)
return True return True
def _killServer(self, signum, frame): def _killServer(self, ignored_signum, ignored_frame):
self.stop() self.stop()
def stop(self, restart=False): def stop(self, restart=False):

@ -29,10 +29,7 @@ _ldap = LDAP()
def init_app(app, config): def init_app(app, config):
global _ldap
if config.config_login_type != constants.LOGIN_LDAP: if config.config_login_type != constants.LOGIN_LDAP:
_ldap = None
return return
app.config['LDAP_HOST'] = config.config_ldap_provider_url app.config['LDAP_HOST'] = config.config_ldap_provider_url

Loading…
Cancel
Save