from __future__ import print_function, division

from contextlib import contextmanager
import unittest
import errno
import os

import gevent.testing as greentest
from gevent.testing import PY3
from gevent.testing import DEFAULT_SOCKET_TIMEOUT as _DEFAULT_SOCKET_TIMEOUT
from gevent.testing.sockets import tcp_listener
from gevent import socket
import gevent
from gevent.server import StreamServer


class SimpleStreamServer(StreamServer):
    """StreamServer subclass whose handler speaks a tiny HTTP-like protocol."""

    def handle(self, client_socket, _address): # pylint:disable=method-hidden
        """
        Read one request line from *client_socket* and answer based on its path.

        * ``/ping`` is answered with a 200 response whose body is ``PONG``.
        * ``/long`` and ``/short`` are answered with ``hello``; the handler
          then keeps reading until the peer closes the connection.
        * Any other path is answered with a 404 response.

        An empty request line (nothing was read) is silently ignored. A line
        that cannot be split into three fields is printed and the parsing
        error is re-raised.
        """
        fd = client_socket.makefile()
        try:
            request_line = fd.readline()
            if not request_line:
                return
            try:
                # maxsplit=2 produces at most the three fields unpacked
                # here; anything after the second space stays in _rest.
                _method, path, _rest = request_line.split(' ', 2)
            except Exception:
                print('Failed to parse request line: %r' % (request_line, ))
                raise
            if path == '/ping':
                client_socket.sendall(b'HTTP/1.0 200 OK\r\n\r\nPONG')
            elif path in ['/long', '/short']:
                client_socket.sendall(b'hello')
                # Hold the connection open until the client goes away.
                while True:
                    data = client_socket.recv(1)
                    if not data:
                        break
            else:
                client_socket.sendall(b'HTTP/1.0 404 WTF?\r\n\r\n')
        finally:
            fd.close()


class Settings(object):
    """
    Server classes and expectation helpers that :class:`TestCase` delegates to.

    Each static helper receives the running test case as *inst*.
    """
    # Server class used when the test supplies its own handler.
    ServerClass = StreamServer
    # Server class whose ``handle`` method is the handler.
    ServerSubClass = SimpleStreamServer
    # Consulted by the start/stop test before exercising
    # stop_accepting()/start_accepting().
    restartable = True
    close_socket_detected = True

    @staticmethod
    def assertAcceptedConnectionError(inst):
        """Connect, read until EOF and assert that nothing was received."""
        with inst.makefile() as conn:
            result = conn.read()
        inst.assertFalse(result)

    # By default a handler failure is checked the same way.
    assert500 = assertAcceptedConnectionError

    @staticmethod
    def assert503(inst):
        """Run the ``assert500`` check, then tolerate an aborted connection on send."""
        # regular reads timeout
        inst.assert500()
        # attempt to send anything reset the connection
        try:
            inst.send_request()
        except socket.error as ex:
            if ex.args[0] not in greentest.CONN_ABORTED_ERRORS:
                raise

    @staticmethod
    def assertPoolFull(inst):
        """Assert that a request with a very short timeout raises ``socket.timeout``."""
        with inst.assertRaises(socket.timeout):
            inst.assertRequestSucceeded(timeout=0.01)

    @staticmethod
    def fill_default_server_args(inst, kwargs):
        """Default the ``spawn`` entry of *kwargs* to ``inst.get_spawn()`` and return *kwargs*."""
        kwargs.setdefault('spawn', inst.get_spawn())
        return kwargs


class TestCase(greentest.TestCase):
    """Base class providing server management, client and assertion helpers."""
    # pylint: disable=too-many-public-methods
    __timeout__ = greentest.LARGE_TIMEOUT
    Settings = Settings
    # The server under test; assigned by the individual tests.
    server = None

    def cleanup(self):
        """Stop the server, if there is one, and forget it."""
        if getattr(self, 'server', None) is not None:
            self.server.stop()
            self.server = None

    def get_listener(self):
        """Return a listening TCP socket registered to be closed on teardown."""
        return self._close_on_teardown(tcp_listener(backlog=5))

    def get_server_host_port_family(self):
        """
        Return ``(host, port, family)`` describing how to connect to the server.

        An empty server host is replaced by the default local IPv4 address
        and ``'::'`` by the default local IPv6 address.
        """
        server_host = self.server.server_host
        if not server_host:
            server_host = greentest.DEFAULT_LOCAL_HOST_ADDR
        elif server_host == '::':
            server_host = greentest.DEFAULT_LOCAL_HOST_ADDR6

        try:
            family = self.server.socket.family
        except AttributeError:
            # server deletes socket when closed
            family = socket.AF_INET

        return server_host, self.server.server_port, family

    @contextmanager
    def makefile(self, timeout=_DEFAULT_SOCKET_TIMEOUT, bufsize=1, include_raw_socket=False):
        """
        Connect to the server and yield a file object for the connection.

        :param timeout: Socket timeout, applied after the connection is made.
        :param bufsize: Buffer size handed to ``socket.makefile``.
        :param include_raw_socket: If true, yield ``(fileobj, socket)``
            instead of just the file object.

        Both the file object and the socket are closed when the context exits.
        """
        server_host, server_port, family = self.get_server_host_port_family()
        # The buffering argument is spelled differently on Python 2 and 3.
        bufarg = 'buffering' if PY3 else 'bufsize'
        makefile_kwargs = {bufarg: bufsize}
        if PY3:
            # Under Python3, you can't read and write to the same
            # makefile() opened in r, and r+ is not allowed
            makefile_kwargs['mode'] = 'rwb'
        with socket.socket(family=family) as sock:
            # We want the socket to be accessible from the fileobject
            # we return. On Python 2, natively this is available as
            # _sock, but Python 3 doesn't have that.
            sock.connect((server_host, server_port))
            sock.settimeout(timeout)
            with sock.makefile(**makefile_kwargs) as rconn:
                result = rconn if not include_raw_socket else (rconn, sock)
                yield result

    def send_request(self, url='/', timeout=_DEFAULT_SOCKET_TIMEOUT, bufsize=1):
        """Open a fresh connection and send a GET request for *url*."""
        with self.makefile(timeout=timeout, bufsize=bufsize) as conn:
            self.send_request_to_fd(conn, url)

    def send_request_to_fd(self, fd, url='/'):
        """Write a GET request for *url* to the file object *fd* and flush it."""
        fd.write(('GET %s HTTP/1.0\r\n\r\n' % url).encode('latin-1'))
        fd.flush()

    def assertConnectionRefused(self):
        """Assert that connecting fails with one of the expected socket errors."""
        with self.assertRaises(socket.error) as exc:
            with self.makefile() as conn:
                conn.close()

        ex = exc.exception
        self.assertIn(ex.args[0],
                      (errno.ECONNREFUSED, errno.EADDRNOTAVAIL,
                       errno.ECONNRESET, errno.ECONNABORTED),
                      (ex, ex.args))

    def assert500(self):
        """Delegate to ``Settings.assert500``."""
        self.Settings.assert500(self)

    def assert503(self):
        """Delegate to ``Settings.assert503``."""
        self.Settings.assert503(self)

    def assertAcceptedConnectionError(self):
        """Delegate to ``Settings.assertAcceptedConnectionError``."""
        self.Settings.assertAcceptedConnectionError(self)

    def assertPoolFull(self):
        """Delegate to ``Settings.assertPoolFull``."""
        self.Settings.assertPoolFull(self)

    def assertNotAccepted(self):
        """
        Send a request and assert that the server does not serve it.

        Either reading times out without a single byte having arrived, or
        whatever was received must start with a 500 status line.
        """
        with self.makefile(include_raw_socket=True) as (conn, sock):
            conn.write(b'GET / HTTP/1.0\r\n\r\n')
            conn.flush()
            result = b''
            try:
                # Read from the raw socket one byte at a time until EOF.
                while True:
                    data = sock.recv(1)
                    if not data:
                        break
                    result += data
            except socket.timeout:
                self.assertFalse(result)
                return

        self.assertTrue(result.startswith(b'HTTP/1.0 500 Internal Server Error'), repr(result))

    def assertRequestSucceeded(self, timeout=_DEFAULT_SOCKET_TIMEOUT):
        """Request ``/ping`` and assert that the response ends with the PONG body."""
        with self.makefile(timeout=timeout) as conn:
            conn.write(b'GET /ping HTTP/1.0\r\n\r\n')
            result = conn.read()

        self.assertTrue(result.endswith(b'\r\n\r\nPONG'), repr(result))

    def start_server(self):
        """Start the server and check that two consecutive requests succeed."""
        self.server.start()
        self.assertRequestSucceeded()
        self.assertRequestSucceeded()

    def stop_server(self):
        """Stop the server and check that new connections are refused."""
        self.server.stop()
        self.assertConnectionRefused()

    def report_netstat(self, _msg):
        """Do nothing; kept as a hook for connection diagnostics."""
        # At one point this would call 'sudo netstat -anp | grep PID'
        # with os.system. We can probably do better with psutil.
        return

    def _create_server(self):
        """Create a ``ServerSubClass`` server bound to an ephemeral port."""
        return self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0))

    def init_server(self):
        """Create the server, start it and call ``gevent.sleep()`` once."""
        self.server = self._create_server()
        self.server.start()
        gevent.sleep()

    @property
    def socket(self):
        """The socket of the server under test."""
        return self.server.socket

    def _test_invalid_callback(self):
        """Check the behaviour of a server whose handler accepts no arguments."""
        try:
            self.server = self.ServerClass((greentest.DEFAULT_BIND_ADDR, 0), lambda: None)
            self.server.start()

            self.expect_one_error()
            self.assert500()
            self.assert_error(TypeError)
        finally:
            # If constructing the server failed there is nothing to stop;
            # don't let an AttributeError obscure the original error.
            if self.server is not None:
                self.server.stop()
            # XXX: There's something unreachable (with a traceback?)
            # We need to clear it to make the leak checks work on Travis;
            # so far I can't reproduce it locally on OS X.
            import gc
            gc.collect()

    def fill_default_server_args(self, kwargs):
        """Delegate to ``Settings.fill_default_server_args``."""
        return self.Settings.fill_default_server_args(self, kwargs)

    def ServerClass(self, *args, **kwargs):
        """Instantiate ``Settings.ServerClass`` with the default arguments filled in."""
        return self.Settings.ServerClass(*args, **self.fill_default_server_args(kwargs))

    def ServerSubClass(self, *args, **kwargs):
        """Instantiate ``Settings.ServerSubClass`` with the default arguments filled in."""
        return self.Settings.ServerSubClass(*args, **self.fill_default_server_args(kwargs))

    def get_spawn(self):
        """Return the default ``spawn`` argument for new servers (``None`` here)."""
        return None


class TestDefaultSpawn(TestCase):
    """Server tests run with ``gevent.spawn`` as the default ``spawn`` argument."""

    def get_spawn(self):
        """Return ``gevent.spawn``."""
        return gevent.spawn

    def _test_server_start_stop(self, restartable):
        """
        Start the server, optionally pause and resume accepting, then stop it.

        The pause/resume part only runs when both *restartable* and
        ``Settings.restartable`` are true.
        """
        self.report_netstat('before start')
        self.start_server()
        self.report_netstat('after start')
        if restartable and self.Settings.restartable:
            self.server.stop_accepting()
            self.report_netstat('after stop_accepting')
            self.assertNotAccepted()
            self.server.start_accepting()
            self.report_netstat('after start_accepting')
            self.assertRequestSucceeded()
        self.stop_server()
        self.report_netstat('after stop')

    # Passing ``backlog`` together with an existing socket is a TypeError.
    def test_backlog_is_not_accepted_for_socket(self):
        self.switch_expected = False
        with self.assertRaises(TypeError):
            self.ServerClass(self.get_listener(), backlog=25, handle=False)

    # ``backlog`` is fine when the server is given an address to bind.
    def test_backlog_is_accepted_for_address(self):
        self.server = self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0), backlog=25)
        self.assertConnectionRefused()
        self._test_server_start_stop(restartable=False)

    # A server that was created but never started must not serve requests.
    def test_subclass_just_create(self):
        self.server = self.ServerSubClass(self.get_listener())
        self.assertNotAccepted()

    def test_subclass_with_socket(self):
        self.server = self.ServerSubClass(self.get_listener())
        # the connection won't be refused, because there exists a
        # listening socket, but it won't be handled also
        self.assertNotAccepted()
        self._test_server_start_stop(restartable=True)

    # Given only an address, nothing listens until the server is started.
    def test_subclass_with_address(self):
        self.server = self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0))
        self.assertConnectionRefused()
        self._test_server_start_stop(restartable=True)

    def test_invalid_callback(self):
        self._test_invalid_callback()

    @greentest.reraises_flaky_timeout(socket.timeout)
    def _test_serve_forever(self):
        """Run ``serve_forever`` in a greenlet, make a request, then stop the server."""
        g = gevent.spawn(self.server.serve_forever)
        try:
            gevent.sleep(0.01)
            self.assertRequestSucceeded()
            self.server.stop()
            self.assertFalse(self.server.started)
            self.assertConnectionRefused()
        finally:
            g.kill()
            g.get()
            self.server.stop()

    # serve_forever() on a server that has not been started yet.
    def test_serve_forever(self):
        self.server = self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0))
        self.assertFalse(self.server.started)
        self.assertConnectionRefused()
        self._test_serve_forever()

    # serve_forever() on a server that start() was already called on.
    def test_serve_forever_after_start(self):
        self.server = self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0))
        self.assertConnectionRefused()
        self.assertFalse(self.server.started)
        self.server.start()
        self.assertTrue(self.server.started)
        self._test_serve_forever()

    # The handler used here accepts any arguments and returns at once.
    def test_server_closes_client_sockets(self):
        self.server = self.ServerClass((greentest.DEFAULT_BIND_ADDR, 0), lambda *args: [])
        self.server.start()
        with self.makefile() as conn:
            self.send_request_to_fd(conn)
            # use assert500 below?
with gevent.Timeout._start_new_or_dummy(1): try: result = conn.read() if result: assert result.startswith('HTTP/1.0 500 Internal Server Error'), repr(result) except socket.error as ex: if ex.args[0] == 10053: pass # "established connection was aborted by the software in your host machine" elif ex.args[0] == errno.ECONNRESET: pass else: raise self.stop_server() @property def socket(self): return self.server.socket def test_error_in_spawn(self): self.init_server() self.assertTrue(self.server.started) error = ExpectedError('test_error_in_spawn') self.server._spawn = lambda *args: gevent.getcurrent().throw(error) self.expect_one_error() self.assertAcceptedConnectionError() self.assert_error(ExpectedError, error) def test_server_repr_when_handle_is_instancemethod(self): # PR 501 self.init_server() assert self.server.started self.assertIn('Server', repr(self.server)) self.server.set_handle(self.server.handle) self.assertIn('handle=', repr(self.server)) self.server.set_handle(self.test_server_repr_when_handle_is_instancemethod) self.assertIn('test_server_repr_when_handle_is_instancemethod', repr(self.server)) def handle(): pass self.server.set_handle(handle) self.assertIn('handle=