# -*- coding: utf-8 -*-
"""
    flask.testsuite.regression
    ~~~~~~~~~~~~~~~~~~~~~~~~~~

    Tests regressions.

    :copyright: (c) 2011 by Armin Ronacher.
    :license: BSD, see LICENSE for more details.
"""

import os
import gc
import sys
import flask
import threading
import unittest
from werkzeug.exceptions import NotFound
from flask.testsuite import FlaskTestCase


# Serializes leak checks: the check toggles the process-wide garbage
# collector and counts every tracked object, so two checks must not overlap.
_gc_lock = threading.Lock()


class _NoLeakAsserter(object):
    """Context manager that fails a test case if the block leaks objects.

    On entry the garbage collector is disabled and the number of objects
    reported by ``gc.get_objects()`` is recorded.  On exit the number is
    sampled again and ``testcase.fail()`` is called if it grew.

    :param testcase: the test case whose ``fail()`` method reports a leak.
    """

    def __init__(self, testcase):
        self.testcase = testcase

    def __enter__(self):
        gc.disable()
        _gc_lock.acquire()
        try:
            loc = flask._request_ctx_stack._local

            # Force Python to track this dictionary at all times.
            # This is necessary since Python only starts tracking
            # dicts if they contain mutable objects.  It's a horrible,
            # horrible hack but makes this kinda testable.
            loc.__storage__['FOOO'] = [1, 2, 3]

            gc.collect()
            self.old_objects = len(gc.get_objects())
        except BaseException:
            # ``__exit__`` is not invoked when ``__enter__`` raises, so the
            # lock and the collector state have to be restored right here.
            _gc_lock.release()
            gc.enable()
            raise

    def __exit__(self, exc_type, exc_value, tb):
        try:
            # Only interpreters lacking ``sys.getrefcount`` get an explicit
            # collection here (presumably those without reference counting,
            # where garbage is not freed immediately -- TODO confirm).
            if not hasattr(sys, 'getrefcount'):
                gc.collect()
            new_objects = len(gc.get_objects())
            if new_objects > self.old_objects:
                self.testcase.fail('Example code leaked')
        finally:
            # Always undo ``__enter__``; without the ``finally`` a reported
            # leak would leave the lock held and the collector disabled.
            _gc_lock.release()
            gc.enable()


class MemoryTestCase(FlaskTestCase):
    """Memory related regression tests."""

    def assert_no_leak(self):
        """Return a context manager that fails this test on leaked objects."""
        return _NoLeakAsserter(self)

    def test_memory_consumption(self):
        """Repeated requests must not increase the tracked object count."""
        app = flask.Flask(__name__)

        @app.route('/')
        def index():
            return flask.render_template('simple_template.html', whiskey=42)

        def fire():
            with app.test_client() as c:
                rv = c.get('/')
                self.assert_equal(rv.status_code, 200)
                # NOTE(review): in the reviewed source this literal was
                # split over several physical lines, which is not valid
                # inside a single-quoted literal.  The same bytes are
                # spelled with explicit escapes here; markup around ``42``
                # may have been lost -- confirm against
                # simple_template.html.
                self.assert_equal(rv.data, b'\n\n42\n\n')

        # Trigger caches
        fire()

        # The leak check is only attempted on Python 2.7 or newer and is
        # skipped when the interpreter exposes ``sys.pypy_translation_info``.
        if sys.version_info >= (2, 7) and \
                not hasattr(sys, 'pypy_translation_info'):
            with self.assert_no_leak():
                for _ in range(10):
                    fire()

    def test_safe_join_toplevel_pardir(self):
        """``safe_join('/foo', '..')`` has to raise ``NotFound``."""
        from flask.helpers import safe_join
        with self.assert_raises(NotFound):
            safe_join('/foo', '..')


class ExceptionTestCase(FlaskTestCase):
    """Exception handling regression tests."""

    def test_aborting(self):
        """Aborting with a redirect and a custom error handler both work."""
        class Foo(Exception):
            whatever = 42

        app = flask.Flask(__name__)
        app.testing = True

        @app.errorhandler(Foo)
        def handle_foo(e):
            return str(e.whatever)

        @app.route('/')
        def index():
            raise flask.abort(flask.redirect(flask.url_for('test')))

        @app.route('/test')
        def test():
            raise Foo()

        with app.test_client() as c:
            rv = c.get('/')
            self.assertEqual(rv.headers['Location'], 'http://localhost/test')
            rv = c.get('/test')
            self.assertEqual(rv.data, b'42')


def suite():
    """Build the regression test suite.

    ``MemoryTestCase`` is only included when the environment variable
    ``RUN_FLASK_MEMORY_TESTS`` is set to ``'1'``.

    :return: a :class:`unittest.TestSuite` with the selected test cases.
    """
    tests = unittest.TestSuite()
    if os.environ.get('RUN_FLASK_MEMORY_TESTS') == '1':
        tests.addTest(unittest.makeSuite(MemoryTestCase))
    tests.addTest(unittest.makeSuite(ExceptionTestCase))
    return tests