#!/usr/bin/env python from __future__ import absolute_import, division, print_function, with_statement import base64 import binascii from contextlib import closing import functools import sys import threading from tornado.escape import utf8 from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop from tornado.iostream import IOStream from tornado.log import gen_log from tornado import netutil from tornado.stack_context import ExceptionStackContext, NullContext from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog from tornado.test.util import unittest from tornado.util import u, bytes_type from tornado.web import Application, RequestHandler, url try: from io import BytesIO # python 3 except ImportError: from cStringIO import StringIO as BytesIO class HelloWorldHandler(RequestHandler): def get(self): name = self.get_argument("name", "world") self.set_header("Content-Type", "text/plain") self.finish("Hello %s!" 
% name) class PostHandler(RequestHandler): def post(self): self.finish("Post arg1: %s, arg2: %s" % ( self.get_argument("arg1"), self.get_argument("arg2"))) class ChunkHandler(RequestHandler): def get(self): self.write("asdf") self.flush() self.write("qwer") class AuthHandler(RequestHandler): def get(self): self.finish(self.request.headers["Authorization"]) class CountdownHandler(RequestHandler): def get(self, count): count = int(count) if count > 0: self.redirect(self.reverse_url("countdown", count - 1)) else: self.write("Zero") class EchoPostHandler(RequestHandler): def post(self): self.write(self.request.body) class UserAgentHandler(RequestHandler): def get(self): self.write(self.request.headers.get('User-Agent', 'User agent not set')) class ContentLength304Handler(RequestHandler): def get(self): self.set_status(304) self.set_header('Content-Length', 42) def _clear_headers_for_304(self): # Tornado strips content-length from 304 responses, but here we # want to simulate servers that include the headers anyway. pass class AllMethodsHandler(RequestHandler): SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',) def method(self): self.write(self.request.method) get = post = put = delete = options = patch = other = method # These tests end up getting run redundantly: once here with the default # HTTPClient implementation, and then again in each implementation's own # test suite. 
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
    """Client behaviour checks run against a small in-process application."""

    def get_app(self):
        # One route per helper handler; gzip is enabled for the whole app.
        return Application([
            url("/hello", HelloWorldHandler),
            url("/post", PostHandler),
            url("/chunk", ChunkHandler),
            url("/auth", AuthHandler),
            url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
            url("/echopost", EchoPostHandler),
            url("/user_agent", UserAgentHandler),
            url("/304_with_content_length", ContentLength304Handler),
            url("/all_methods", AllMethodsHandler),
        ], gzip=True)

    def test_hello_world(self):
        response = self.fetch("/hello")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.headers["Content-Type"], "text/plain")
        self.assertEqual(response.body, b"Hello world!")
        # request_time truncated to an int must be 0, i.e. under a second.
        self.assertEqual(int(response.request_time), 0)

        response = self.fetch("/hello?name=Ben")
        self.assertEqual(response.body, b"Hello Ben!")

    def test_streaming_callback(self):
        # streaming_callback is also tested in test_chunked
        chunks = []
        response = self.fetch("/hello",
                              streaming_callback=chunks.append)
        # with streaming_callback, data goes to the callback and not response.body
        self.assertEqual(chunks, [b"Hello world!"])
        self.assertFalse(response.body)

    def test_post(self):
        response = self.fetch("/post", method="POST",
                              body="arg1=foo&arg2=bar")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")

    def test_chunked(self):
        response = self.fetch("/chunk")
        self.assertEqual(response.body, b"asdfqwer")

        chunks = []
        response = self.fetch("/chunk",
                              streaming_callback=chunks.append)
        self.assertEqual(chunks, [b"asdf", b"qwer"])
        self.assertFalse(response.body)

    def test_chunked_close(self):
        # test case in which chunks spread read-callback processing
        # over several ioloop iterations, but the connection is already closed.
        sock, port = bind_unused_port()
        with closing(sock):
            def write_response(stream, request_data):
                # Two one-byte chunks ("1" and "2") followed by the
                # terminating zero-length chunk; the stream is closed as
                # soon as the write completes.
                stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked

1
1
1
2
0

""".replace(b"\n", b"\r\n"), callback=stream.close)

            def accept_callback(conn, address):
                # fake an HTTP server using chunked encoding where the final chunks
                # and connection close all happen at once
                stream = IOStream(conn, io_loop=self.io_loop)
                stream.read_until(b"\r\n\r\n",
                                  functools.partial(write_response, stream))
            netutil.add_accept_handler(sock, accept_callback, self.io_loop)
            self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
            resp = self.wait()
            resp.rethrow()
            self.assertEqual(resp.body, b"12")
            # Unregister the listener while the socket is still open.
            self.io_loop.remove_handler(sock.fileno())

    def test_streaming_stack_context(self):
        chunks = []
        exc_info = []

        def error_handler(typ, value, tb):
            exc_info.append((typ, value, tb))
            return True

        def streaming_cb(chunk):
            chunks.append(chunk)
            if chunk == b'qwer':
                # Deliberate ZeroDivisionError on the last chunk.
                1 / 0

        with ExceptionStackContext(error_handler):
            self.fetch('/chunk', streaming_callback=streaming_cb)

        self.assertEqual(chunks, [b'asdf', b'qwer'])
        self.assertEqual(1, len(exc_info))
        self.assertIs(exc_info[0][0], ZeroDivisionError)

    def test_basic_auth(self):
        self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
                                    auth_password="open sesame").body,
                         b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")

    def test_basic_auth_explicit_mode(self):
        self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
                                    auth_password="open sesame",
                                    auth_mode="basic").body,
                         b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")

    def test_unsupported_auth_mode(self):
        # curl and simple clients handle errors a bit differently; the
        # important thing is that they don't fall back to basic auth
        # on an unknown mode.
        with ExpectLog(gen_log, "uncaught exception", required=False):
            with self.assertRaises((ValueError, HTTPError)):
                response = self.fetch("/auth", auth_username="Aladdin",
                                      auth_password="open sesame",
                                      auth_mode="asdf")
                response.rethrow()

    def test_follow_redirect(self):
        response = self.fetch("/countdown/2", follow_redirects=False)
        self.assertEqual(302, response.code)
        self.assertTrue(response.headers["Location"].endswith("/countdown/1"))

        response = self.fetch("/countdown/2")
        self.assertEqual(200, response.code)
        self.assertTrue(response.effective_url.endswith("/countdown/0"))
        self.assertEqual(b"Zero", response.body)

    def test_credentials_in_url(self):
        url = self.get_url("/auth").replace("http://", "http://me:secret@")
        self.http_client.fetch(url, self.stop)
        response = self.wait()
        self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
                         response.body)

    def test_body_encoding(self):
        unicode_body = u("\xe9")
        byte_body = binascii.a2b_hex(b"e9")

        # unicode string in body gets converted to utf8
        response = self.fetch("/echopost", method="POST", body=unicode_body,
                              headers={"Content-Type": "application/blah"})
        self.assertEqual(response.headers["Content-Length"], "2")
        self.assertEqual(response.body, utf8(unicode_body))

        # byte strings pass through directly
        response = self.fetch("/echopost", method="POST",
                              body=byte_body,
                              headers={"Content-Type": "application/blah"})
        self.assertEqual(response.headers["Content-Length"], "1")
        self.assertEqual(response.body, byte_body)

        # Mixing unicode in headers and byte string bodies shouldn't
        # break anything
        response = self.fetch("/echopost", method="POST", body=byte_body,
                              headers={"Content-Type": "application/blah"},
                              user_agent=u("foo"))
        self.assertEqual(response.headers["Content-Length"], "1")
        self.assertEqual(response.body, byte_body)

    def test_types(self):
        # Exact type comparisons are intentional here: the test pins the
        # concrete types exposed on the response object.
        response = self.fetch("/hello")
        self.assertEqual(type(response.body), bytes_type)
        self.assertEqual(type(response.headers["Content-Type"]), str)
        self.assertEqual(type(response.code), int)
        self.assertEqual(type(response.effective_url), str)

    def test_header_callback(self):
        first_line = []
        headers = {}
        chunks = []

        def header_callback(header_line):
            if header_line.startswith('HTTP/'):
                first_line.append(header_line)
            elif header_line != '\r\n':
                k, v = header_line.split(':', 1)
                headers[k] = v.strip()

        def streaming_callback(chunk):
            # All header callbacks are run before any streaming callbacks,
            # so the header data is available to process the data as it
            # comes in.
            self.assertEqual(headers['Content-Type'], 'text/html; charset=UTF-8')
            chunks.append(chunk)

        self.fetch('/chunk', header_callback=header_callback,
                   streaming_callback=streaming_callback)
        self.assertEqual(len(first_line), 1)
        # The dot is escaped so that only a literal "1.0"/"1.1" version
        # matches, not an arbitrary character in that position.
        self.assertRegexpMatches(first_line[0], 'HTTP/1\\.[01] 200 OK\r\n')
        self.assertEqual(chunks, [b'asdf', b'qwer'])

    def test_header_callback_stack_context(self):
        exc_info = []

        def error_handler(typ, value, tb):
            exc_info.append((typ, value, tb))
            return True

        def header_callback(header_line):
            if header_line.startswith('Content-Type:'):
                # Deliberate ZeroDivisionError on the Content-Type line.
                1 / 0

        with ExceptionStackContext(error_handler):
            self.fetch('/chunk', header_callback=header_callback)
        self.assertEqual(len(exc_info), 1)
        self.assertIs(exc_info[0][0], ZeroDivisionError)

    def test_configure_defaults(self):
        defaults = dict(user_agent='TestDefaultUserAgent')
        # Construct a new instance of the configured client class
        client = self.http_client.__class__(self.io_loop, force_instance=True,
                                            defaults=defaults)
        # Close the extra client even when an assertion below fails.
        try:
            client.fetch(self.get_url('/user_agent'), callback=self.stop)
            response = self.wait()
            self.assertEqual(response.body, b'TestDefaultUserAgent')
        finally:
            client.close()

    def test_304_with_content_length(self):
        # According to the spec 304 responses SHOULD NOT include
        # Content-Length or other entity headers, but some servers do it
        # anyway.
        # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
        response = self.fetch('/304_with_content_length')
        self.assertEqual(response.code, 304)
        self.assertEqual(response.headers['Content-Length'], '42')

    def test_final_callback_stack_context(self):
        # The final callback should be run outside of the httpclient's
        # stack_context.  We want to ensure that there is not stack_context
        # between the user's callback and the IOLoop, so monkey-patch
        # IOLoop.handle_callback_exception and disable the test harness's
        # context with a NullContext.
        # Note that this does not apply to secondary callbacks (header
        # and streaming_callback), as errors there must be seen as errors
        # by the http client so it can clean up the connection.
        exc_info = []

        def handle_callback_exception(callback):
            exc_info.append(sys.exc_info())
            self.stop()
        self.io_loop.handle_callback_exception = handle_callback_exception
        with NullContext():
            self.http_client.fetch(self.get_url('/hello'),
                                   lambda response: 1 / 0)
        self.wait()
        self.assertEqual(exc_info[0][0], ZeroDivisionError)

    @gen_test
    def test_future_interface(self):
        response = yield self.http_client.fetch(self.get_url('/hello'))
        self.assertEqual(response.body, b'Hello world!')

    @gen_test
    def test_future_http_error(self):
        # assertRaises (rather than a bare try/except) makes the test fail
        # when the fetch unexpectedly succeeds instead of passing silently.
        with self.assertRaises(HTTPError) as context:
            yield self.http_client.fetch(self.get_url('/notfound'))
        self.assertEqual(context.exception.code, 404)
        self.assertEqual(context.exception.response.code, 404)

    @gen_test
    def test_reuse_request_from_response(self):
        # The response.request attribute should be an HTTPRequest, not
        # a _RequestProxy.
        # This test uses self.http_client.fetch because self.fetch calls
        # self.get_url on the input unconditionally.
        url = self.get_url('/hello')
        response = yield self.http_client.fetch(url)
        self.assertEqual(response.request.url, url)
        self.assertTrue(isinstance(response.request, HTTPRequest))
        response2 = yield self.http_client.fetch(response.request)
        self.assertEqual(response2.body, b'Hello world!')

    def test_all_methods(self):
        # Methods sent without a body.
        for method in ['GET', 'DELETE', 'OPTIONS']:
            response = self.fetch('/all_methods', method=method)
            self.assertEqual(response.body, utf8(method))
        # Methods sent with an (empty) body.
        for method in ['POST', 'PUT', 'PATCH']:
            response = self.fetch('/all_methods', method=method, body=b'')
            self.assertEqual(response.body, utf8(method))
        response = self.fetch('/all_methods', method='HEAD')
        self.assertEqual(response.body, b'')
        response = self.fetch('/all_methods', method='OTHER',
                              allow_nonstandard_methods=True)
        self.assertEqual(response.body, b'OTHER')


class RequestProxyTest(unittest.TestCase):
    """Attribute lookup on _RequestProxy: request value, then defaults."""

    def test_request_set(self):
        proxy = _RequestProxy(HTTPRequest('http://example.com/',
                                          user_agent='foo'),
                              dict())
        self.assertEqual(proxy.user_agent, 'foo')

    def test_default_set(self):
        proxy = _RequestProxy(HTTPRequest('http://example.com/'),
                              dict(network_interface='foo'))
        self.assertEqual(proxy.network_interface, 'foo')

    def test_both_set(self):
        # A value set on the request wins over the default.
        proxy = _RequestProxy(HTTPRequest('http://example.com/',
                                          proxy_host='foo'),
                              dict(proxy_host='bar'))
        self.assertEqual(proxy.proxy_host, 'foo')

    def test_neither_set(self):
        proxy = _RequestProxy(HTTPRequest('http://example.com/'),
                              dict())
        self.assertIs(proxy.auth_username, None)

    def test_bad_attribute(self):
        proxy = _RequestProxy(HTTPRequest('http://example.com/'),
                              dict())
        with self.assertRaises(AttributeError):
            proxy.foo

    def test_defaults_none(self):
        proxy = _RequestProxy(HTTPRequest('http://example.com/'), None)
        self.assertIs(proxy.auth_username, None)


class HTTPResponseTestCase(unittest.TestCase):
    """Checks the string form of an HTTPResponse."""

    def test_str(self):
        response = HTTPResponse(HTTPRequest('http://example.com'),
                                200, headers={}, buffer=BytesIO())
        s = str(response)
        self.assertTrue(s.startswith('HTTPResponse('))
        self.assertIn('code=200', s)


class SyncHTTPClientTest(unittest.TestCase):
    """Runs the blocking HTTPClient against a server on a second thread."""

    def setUp(self):
        if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
            # TwistedIOLoop only supports the global reactor, so we can't have
            # separate IOLoops for client and server threads.
            raise unittest.SkipTest(
                'Sync HTTPClient not compatible with TwistedIOLoop')
        self.server_ioloop = IOLoop()

        sock, self.port = bind_unused_port()
        app = Application([('/', HelloWorldHandler)])
        server = HTTPServer(app, io_loop=self.server_ioloop)
        server.add_socket(sock)

        # The server's IOLoop runs on its own thread for the test's duration.
        self.server_thread = threading.Thread(target=self.server_ioloop.start)
        self.server_thread.start()

        self.http_client = HTTPClient()

    def tearDown(self):
        # Ask the server loop to stop from inside its own thread, then wait
        # for that thread before closing the client and the loop.
        self.server_ioloop.add_callback(self.server_ioloop.stop)
        self.server_thread.join()
        self.http_client.close()
        self.server_ioloop.close(all_fds=True)

    def get_url(self, path):
        return 'http://localhost:%d%s' % (self.port, path)

    def test_sync_client(self):
        response = self.http_client.fetch(self.get_url('/'))
        self.assertEqual(b'Hello world!', response.body)

    def test_sync_client_error(self):
        # Synchronous HTTPClient raises errors directly; no need for
        # response.rethrow()
        with self.assertRaises(HTTPError) as assertion:
            self.http_client.fetch(self.get_url('/notfound'))
        self.assertEqual(assertion.exception.code, 404)