87 lines
2.9 KiB
Python
87 lines
2.9 KiB
Python
from tornado.concurrent import Future
|
|
from tornado import gen
|
|
from tornado.httpclient import HTTPError
|
|
from tornado.log import gen_log
|
|
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
|
|
from tornado.web import Application, RequestHandler
|
|
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
|
|
|
|
|
|
class EchoHandler(WebSocketHandler):
    """WebSocket handler that sends every incoming message straight back.

    ``close_future`` is supplied through the application's URL spec and is
    resolved with ``None`` from ``on_close``. Tests can yield it to wait
    until the server side has observed the connection closing.
    """

    def initialize(self, close_future):
        # Keep a handle on the future shared with the test case.
        self.close_future = close_future

    def on_message(self, message):
        # Reply in the frame type that was received: bytes go back as a
        # binary frame, anything else as text.
        binary = isinstance(message, bytes)
        self.write_message(message, binary)

    def on_close(self):
        # Signal the waiting test that the server saw the disconnect.
        self.close_future.set_result(None)
|
|
|
|
|
|
class NonWebSocketHandler(RequestHandler):
    """Plain HTTP endpoint that never performs a WebSocket upgrade.

    Used to check that websocket_connect raises WebSocketError when the
    server answers with an ordinary response.
    """

    def get(self):
        # An ordinary (non-upgrade) response body.
        payload = 'ok'
        self.write(payload)
|
|
|
|
|
|
class WebSocketTest(AsyncHTTPTestCase):
    """End-to-end tests for websocket_connect against a local server.

    Every client connection is bound explicitly to the test's IOLoop with
    ``io_loop=self.io_loop``.
    """

    def get_app(self):
        # Fresh future per test; EchoHandler resolves it from on_close.
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
        ])

    @gen_test
    def test_websocket_gen(self):
        """A text message sent over a coroutine-driven connection is echoed."""
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')

    def test_websocket_callbacks(self):
        """Same echo round trip as test_websocket_gen, driven by callbacks."""
        websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop, callback=self.stop)
        # stop() is handed a future here; .result() yields the connection
        # (or re-raises a connect error).
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')

    @gen_test
    def test_websocket_http_fail(self):
        """Connecting to an unrouted path raises HTTPError with code 404."""
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(
                'ws://localhost:%d/notfound' % self.get_http_port(),
                io_loop=self.io_loop)
        # Must sit after the with-block: inside it this line would never run
        # once the expected exception is raised.
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        """A plain HTTP response without an upgrade raises WebSocketError."""
        with self.assertRaises(WebSocketError):
            yield websocket_connect(
                'ws://localhost:%d/non_ws' % self.get_http_port(),
                io_loop=self.io_loop)

    @gen_test
    def test_websocket_network_fail(self):
        """Connecting to a closed port raises HTTPError with code 599."""
        # Reserve a free port, then release it so that (presumably) nothing
        # is listening there when we connect.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(HTTPError) as cm:
            # The failed connect is expected to log to gen_log; accept any
            # message so it does not pollute the test output.
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://localhost:%d/' % port,
                    io_loop=self.io_loop,
                    connect_timeout=0.01)
        self.assertEqual(cm.exception.code, 599)

    @gen_test
    def test_websocket_close_buffered_data(self):
        """Closing the client stream while writes are pending still lets the
        server-side handler's on_close run."""
        # io_loop passed explicitly, consistent with the other tests.
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop)
        ws.write_message('hello')
        ws.write_message('world')
        # Close the raw stream directly instead of calling ws.close(), so no
        # clean WebSocket close handshake is sent.
        ws.stream.close()
        # Resolved by EchoHandler.on_close once the server sees the drop.
        yield self.close_future
|