You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
331 lines
11 KiB
Python
331 lines
11 KiB
Python
10 years ago
|
#!/usr/bin/env python
|
||
|
#
|
||
|
# Copyright 2012 Facebook
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||
|
# not use this file except in compliance with the License. You may obtain
|
||
|
# a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
# License for the specific language governing permissions and limitations
|
||
|
# under the License.
|
||
|
from __future__ import absolute_import, division, print_function, with_statement
|
||
|
|
||
|
import logging
|
||
|
import re
|
||
|
import socket
|
||
|
import sys
|
||
|
import traceback
|
||
|
|
||
|
from tornado.concurrent import Future, return_future, ReturnValueIgnoredError
|
||
|
from tornado.escape import utf8, to_unicode
|
||
|
from tornado import gen
|
||
|
from tornado.iostream import IOStream
|
||
|
from tornado import stack_context
|
||
|
from tornado.tcpserver import TCPServer
|
||
|
from tornado.testing import AsyncTestCase, LogTrapTestCase, bind_unused_port, gen_test
|
||
|
|
||
|
|
||
|
class ReturnFutureTest(AsyncTestCase):
|
||
|
@return_future
|
||
|
def sync_future(self, callback):
|
||
|
callback(42)
|
||
|
|
||
|
@return_future
|
||
|
def async_future(self, callback):
|
||
|
self.io_loop.add_callback(callback, 42)
|
||
|
|
||
|
@return_future
|
||
|
def immediate_failure(self, callback):
|
||
|
1 / 0
|
||
|
|
||
|
@return_future
|
||
|
def delayed_failure(self, callback):
|
||
|
self.io_loop.add_callback(lambda: 1 / 0)
|
||
|
|
||
|
@return_future
|
||
|
def return_value(self, callback):
|
||
|
# Note that the result of both running the callback and returning
|
||
|
# a value (or raising an exception) is unspecified; with current
|
||
|
# implementations the last event prior to callback resolution wins.
|
||
|
return 42
|
||
|
|
||
|
@return_future
|
||
|
def no_result_future(self, callback):
|
||
|
callback()
|
||
|
|
||
|
def test_immediate_failure(self):
|
||
|
with self.assertRaises(ZeroDivisionError):
|
||
|
# The caller sees the error just like a normal function.
|
||
|
self.immediate_failure(callback=self.stop)
|
||
|
# The callback is not run because the function failed synchronously.
|
||
|
self.io_loop.add_timeout(self.io_loop.time() + 0.05, self.stop)
|
||
|
result = self.wait()
|
||
|
self.assertIs(result, None)
|
||
|
|
||
|
def test_return_value(self):
|
||
|
with self.assertRaises(ReturnValueIgnoredError):
|
||
|
self.return_value(callback=self.stop)
|
||
|
|
||
|
def test_callback_kw(self):
|
||
|
future = self.sync_future(callback=self.stop)
|
||
|
result = self.wait()
|
||
|
self.assertEqual(result, 42)
|
||
|
self.assertEqual(future.result(), 42)
|
||
|
|
||
|
def test_callback_positional(self):
|
||
|
# When the callback is passed in positionally, future_wrap shouldn't
|
||
|
# add another callback in the kwargs.
|
||
|
future = self.sync_future(self.stop)
|
||
|
result = self.wait()
|
||
|
self.assertEqual(result, 42)
|
||
|
self.assertEqual(future.result(), 42)
|
||
|
|
||
|
def test_no_callback(self):
|
||
|
future = self.sync_future()
|
||
|
self.assertEqual(future.result(), 42)
|
||
|
|
||
|
def test_none_callback_kw(self):
|
||
|
# explicitly pass None as callback
|
||
|
future = self.sync_future(callback=None)
|
||
|
self.assertEqual(future.result(), 42)
|
||
|
|
||
|
def test_none_callback_pos(self):
|
||
|
future = self.sync_future(None)
|
||
|
self.assertEqual(future.result(), 42)
|
||
|
|
||
|
def test_async_future(self):
|
||
|
future = self.async_future()
|
||
|
self.assertFalse(future.done())
|
||
|
self.io_loop.add_future(future, self.stop)
|
||
|
future2 = self.wait()
|
||
|
self.assertIs(future, future2)
|
||
|
self.assertEqual(future.result(), 42)
|
||
|
|
||
|
@gen_test
|
||
|
def test_async_future_gen(self):
|
||
|
result = yield self.async_future()
|
||
|
self.assertEqual(result, 42)
|
||
|
|
||
|
def test_delayed_failure(self):
|
||
|
future = self.delayed_failure()
|
||
|
self.io_loop.add_future(future, self.stop)
|
||
|
future2 = self.wait()
|
||
|
self.assertIs(future, future2)
|
||
|
with self.assertRaises(ZeroDivisionError):
|
||
|
future.result()
|
||
|
|
||
|
def test_kw_only_callback(self):
|
||
|
@return_future
|
||
|
def f(**kwargs):
|
||
|
kwargs['callback'](42)
|
||
|
future = f()
|
||
|
self.assertEqual(future.result(), 42)
|
||
|
|
||
|
def test_error_in_callback(self):
|
||
|
self.sync_future(callback=lambda future: 1 / 0)
|
||
|
# The exception gets caught by our StackContext and will be re-raised
|
||
|
# when we wait.
|
||
|
self.assertRaises(ZeroDivisionError, self.wait)
|
||
|
|
||
|
def test_no_result_future(self):
|
||
|
future = self.no_result_future(self.stop)
|
||
|
result = self.wait()
|
||
|
self.assertIs(result, None)
|
||
|
# result of this future is undefined, but not an error
|
||
|
future.result()
|
||
|
|
||
|
def test_no_result_future_callback(self):
|
||
|
future = self.no_result_future(callback=lambda: self.stop())
|
||
|
result = self.wait()
|
||
|
self.assertIs(result, None)
|
||
|
future.result()
|
||
|
|
||
|
@gen_test
|
||
|
def test_future_traceback(self):
|
||
|
@return_future
|
||
|
@gen.engine
|
||
|
def f(callback):
|
||
|
yield gen.Task(self.io_loop.add_callback)
|
||
|
try:
|
||
|
1 / 0
|
||
|
except ZeroDivisionError:
|
||
|
self.expected_frame = traceback.extract_tb(
|
||
|
sys.exc_info()[2], limit=1)[0]
|
||
|
raise
|
||
|
try:
|
||
|
yield f()
|
||
|
self.fail("didn't get expected exception")
|
||
|
except ZeroDivisionError:
|
||
|
tb = traceback.extract_tb(sys.exc_info()[2])
|
||
|
self.assertIn(self.expected_frame, tb)
|
||
|
|
||
|
# The following series of classes demonstrate and test various styles
|
||
|
# of use, with and without generators and futures.
|
||
|
|
||
|
|
||
|
class CapServer(TCPServer):
|
||
|
def handle_stream(self, stream, address):
|
||
|
logging.info("handle_stream")
|
||
|
self.stream = stream
|
||
|
self.stream.read_until(b"\n", self.handle_read)
|
||
|
|
||
|
def handle_read(self, data):
|
||
|
logging.info("handle_read")
|
||
|
data = to_unicode(data)
|
||
|
if data == data.upper():
|
||
|
self.stream.write(b"error\talready capitalized\n")
|
||
|
else:
|
||
|
# data already has \n
|
||
|
self.stream.write(utf8("ok\t%s" % data.upper()))
|
||
|
self.stream.close()
|
||
|
|
||
|
|
||
|
class CapError(Exception):
|
||
|
pass
|
||
|
|
||
|
|
||
|
class BaseCapClient(object):
|
||
|
def __init__(self, port, io_loop):
|
||
|
self.port = port
|
||
|
self.io_loop = io_loop
|
||
|
|
||
|
def process_response(self, data):
|
||
|
status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
|
||
|
if status == 'ok':
|
||
|
return message
|
||
|
else:
|
||
|
raise CapError(message)
|
||
|
|
||
|
|
||
|
class ManualCapClient(BaseCapClient):
|
||
|
def capitalize(self, request_data, callback=None):
|
||
|
logging.info("capitalize")
|
||
|
self.request_data = request_data
|
||
|
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||
|
self.stream.connect(('127.0.0.1', self.port),
|
||
|
callback=self.handle_connect)
|
||
|
self.future = Future()
|
||
|
if callback is not None:
|
||
|
self.future.add_done_callback(
|
||
|
stack_context.wrap(lambda future: callback(future.result())))
|
||
|
return self.future
|
||
|
|
||
|
def handle_connect(self):
|
||
|
logging.info("handle_connect")
|
||
|
self.stream.write(utf8(self.request_data + "\n"))
|
||
|
self.stream.read_until(b'\n', callback=self.handle_read)
|
||
|
|
||
|
def handle_read(self, data):
|
||
|
logging.info("handle_read")
|
||
|
self.stream.close()
|
||
|
try:
|
||
|
self.future.set_result(self.process_response(data))
|
||
|
except CapError as e:
|
||
|
self.future.set_exception(e)
|
||
|
|
||
|
|
||
|
class DecoratorCapClient(BaseCapClient):
|
||
|
@return_future
|
||
|
def capitalize(self, request_data, callback):
|
||
|
logging.info("capitalize")
|
||
|
self.request_data = request_data
|
||
|
self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||
|
self.stream.connect(('127.0.0.1', self.port),
|
||
|
callback=self.handle_connect)
|
||
|
self.callback = callback
|
||
|
|
||
|
def handle_connect(self):
|
||
|
logging.info("handle_connect")
|
||
|
self.stream.write(utf8(self.request_data + "\n"))
|
||
|
self.stream.read_until(b'\n', callback=self.handle_read)
|
||
|
|
||
|
def handle_read(self, data):
|
||
|
logging.info("handle_read")
|
||
|
self.stream.close()
|
||
|
self.callback(self.process_response(data))
|
||
|
|
||
|
|
||
|
class GeneratorCapClient(BaseCapClient):
|
||
|
@return_future
|
||
|
@gen.engine
|
||
|
def capitalize(self, request_data, callback):
|
||
|
logging.info('capitalize')
|
||
|
stream = IOStream(socket.socket(), io_loop=self.io_loop)
|
||
|
logging.info('connecting')
|
||
|
yield gen.Task(stream.connect, ('127.0.0.1', self.port))
|
||
|
stream.write(utf8(request_data + '\n'))
|
||
|
logging.info('reading')
|
||
|
data = yield gen.Task(stream.read_until, b'\n')
|
||
|
logging.info('returning')
|
||
|
stream.close()
|
||
|
callback(self.process_response(data))
|
||
|
|
||
|
|
||
|
class ClientTestMixin(object):
|
||
|
def setUp(self):
|
||
|
super(ClientTestMixin, self).setUp()
|
||
|
self.server = CapServer(io_loop=self.io_loop)
|
||
|
sock, port = bind_unused_port()
|
||
|
self.server.add_sockets([sock])
|
||
|
self.client = self.client_class(io_loop=self.io_loop, port=port)
|
||
|
|
||
|
def tearDown(self):
|
||
|
self.server.stop()
|
||
|
super(ClientTestMixin, self).tearDown()
|
||
|
|
||
|
def test_callback(self):
|
||
|
self.client.capitalize("hello", callback=self.stop)
|
||
|
result = self.wait()
|
||
|
self.assertEqual(result, "HELLO")
|
||
|
|
||
|
def test_callback_error(self):
|
||
|
self.client.capitalize("HELLO", callback=self.stop)
|
||
|
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
|
||
|
|
||
|
def test_future(self):
|
||
|
future = self.client.capitalize("hello")
|
||
|
self.io_loop.add_future(future, self.stop)
|
||
|
self.wait()
|
||
|
self.assertEqual(future.result(), "HELLO")
|
||
|
|
||
|
def test_future_error(self):
|
||
|
future = self.client.capitalize("HELLO")
|
||
|
self.io_loop.add_future(future, self.stop)
|
||
|
self.wait()
|
||
|
self.assertRaisesRegexp(CapError, "already capitalized", future.result)
|
||
|
|
||
|
def test_generator(self):
|
||
|
@gen.engine
|
||
|
def f():
|
||
|
result = yield self.client.capitalize("hello")
|
||
|
self.assertEqual(result, "HELLO")
|
||
|
self.stop()
|
||
|
f()
|
||
|
self.wait()
|
||
|
|
||
|
def test_generator_error(self):
|
||
|
@gen.engine
|
||
|
def f():
|
||
|
with self.assertRaisesRegexp(CapError, "already capitalized"):
|
||
|
yield self.client.capitalize("HELLO")
|
||
|
self.stop()
|
||
|
f()
|
||
|
self.wait()
|
||
|
|
||
|
|
||
|
class ManualClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
|
||
|
client_class = ManualCapClient
|
||
|
|
||
|
|
||
|
class DecoratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
|
||
|
client_class = DecoratorCapClient
|
||
|
|
||
|
|
||
|
class GeneratorClientTest(ClientTestMixin, AsyncTestCase, LogTrapTestCase):
|
||
|
client_class = GeneratorCapClient
|