#!/usr/bin/env python # # Copyright 2012, Google Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above # copyright notice, this list of conditions and the following disclaimer # in the documentation and/or other materials provided with the # distribution. # * Neither the name of Google Inc. nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Tests for msgutil module.""" from __future__ import absolute_import from __future__ import print_function from __future__ import division import random import struct import unittest import zlib from six import iterbytes from six.moves import map from six.moves import range import six.moves.queue import set_sys_path # Update sys.path to locate pywebsocket3 module. from pywebsocket3 import common, msgutil, util from pywebsocket3.extensions import PerMessageDeflateExtensionProcessor from pywebsocket3.stream import ( InvalidUTF8Exception, Stream, StreamOptions, ) from test import mock # We use one fixed nonce for testing instead of cryptographically secure PRNG. _MASKING_NONCE = b'ABCD' def _mask_hybi(frame): if isinstance(frame, six.text_type): Exception('masking does not accept Texts') frame_key = list(iterbytes(_MASKING_NONCE)) frame_key_len = len(frame_key) result = bytearray(frame) count = 0 for i in range(len(result)): result[i] ^= frame_key[count] count = (count + 1) % frame_key_len return _MASKING_NONCE + bytes(result) def _install_extension_processor(processor, request, stream_options): response = processor.get_extension_response() if response is not None: processor.setup_stream_options(stream_options) request.ws_extension_processors.append(processor) def _create_request_from_rawdata(read_data, permessage_deflate_request=None): req = mock.MockRequest(connection=mock.MockConn(read_data)) req.ws_version = common.VERSION_HYBI_LATEST req.ws_extension_processors = [] processor = None if permessage_deflate_request is not None: processor = PerMessageDeflateExtensionProcessor( permessage_deflate_request) stream_options = StreamOptions() if processor is not None: _install_extension_processor(processor, req, stream_options) req.ws_stream = Stream(req, stream_options) return req def _create_request(*frames): """Creates MockRequest using data given as frames. frames will be returned on calling request.connection.read() where request is MockRequest returned by this function. """ read_data = [] for (header, body) in frames: read_data.append(header + _mask_hybi(body)) return _create_request_from_rawdata(b''.join(read_data)) def _create_blocking_request(): """Creates MockRequest. Data written to a MockRequest can be read out by calling request.connection.written_data(). """ req = mock.MockRequest(connection=mock.MockBlockingConn()) req.ws_version = common.VERSION_HYBI_LATEST stream_options = StreamOptions() req.ws_stream = Stream(req, stream_options) return req class BasicMessageTest(unittest.TestCase): """Basic tests for Stream.""" def test_send_message(self): request = _create_request() msgutil.send_message(request, 'Hello') self.assertEqual(b'\x81\x05Hello', request.connection.written_data()) payload = 'a' * 125 request = _create_request() msgutil.send_message(request, payload) self.assertEqual(b'\x81\x7d' + payload.encode('UTF-8'), request.connection.written_data()) def test_send_medium_message(self): payload = 'a' * 126 request = _create_request() msgutil.send_message(request, payload) self.assertEqual(b'\x81\x7e\x00\x7e' + payload.encode('UTF-8'), request.connection.written_data()) payload = 'a' * ((1 << 16) - 1) request = _create_request() msgutil.send_message(request, payload) self.assertEqual(b'\x81\x7e\xff\xff' + payload.encode('UTF-8'), request.connection.written_data()) def test_send_large_message(self): payload = 'a' * (1 << 16) request = _create_request() msgutil.send_message(request, payload) self.assertEqual( b'\x81\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + payload.encode('UTF-8'), request.connection.written_data()) def test_send_message_unicode(self): request = _create_request() msgutil.send_message(request, u'\u65e5') # U+65e5 is encoded as e6,97,a5 in UTF-8 self.assertEqual(b'\x81\x03\xe6\x97\xa5', request.connection.written_data()) def test_send_message_fragments(self): request = _create_request() msgutil.send_message(request, 'Hello', False) msgutil.send_message(request, ' ', False) msgutil.send_message(request, 'World', False) msgutil.send_message(request, '!', True) self.assertEqual(b'\x01\x05Hello\x00\x01 \x00\x05World\x80\x01!', request.connection.written_data()) def test_send_fragments_immediate_zero_termination(self): request = _create_request() msgutil.send_message(request, 'Hello World!', False) msgutil.send_message(request, '', True) self.assertEqual(b'\x01\x0cHello World!\x80\x00', request.connection.written_data()) def test_receive_message(self): request = _create_request((b'\x81\x85', b'Hello'), (b'\x81\x86', b'World!')) self.assertEqual('Hello', msgutil.receive_message(request)) self.assertEqual('World!', msgutil.receive_message(request)) payload = b'a' * 125 request = _create_request((b'\x81\xfd', payload)) self.assertEqual(payload.decode('UTF-8'), msgutil.receive_message(request)) def test_receive_medium_message(self): payload = b'a' * 126 request = _create_request((b'\x81\xfe\x00\x7e', payload)) self.assertEqual(payload.decode('UTF-8'), msgutil.receive_message(request)) payload = b'a' * ((1 << 16) - 1) request = _create_request((b'\x81\xfe\xff\xff', payload)) self.assertEqual(payload.decode('UTF-8'), msgutil.receive_message(request)) def test_receive_large_message(self): payload = b'a' * (1 << 16) request = _create_request( (b'\x81\xff\x00\x00\x00\x00\x00\x01\x00\x00', payload)) self.assertEqual(payload.decode('UTF-8'), msgutil.receive_message(request)) def test_receive_length_not_encoded_using_minimal_number_of_bytes(self): # Log warning on receiving bad payload length field that doesn't use # minimal number of bytes but continue processing. payload = b'a' # 1 byte can be represented without extended payload length field. request = _create_request( (b'\x81\xff\x00\x00\x00\x00\x00\x00\x00\x01', payload)) self.assertEqual(payload.decode('UTF-8'), msgutil.receive_message(request)) def test_receive_message_unicode(self): request = _create_request((b'\x81\x83', b'\xe6\x9c\xac')) # U+672c is encoded as e6,9c,ac in UTF-8 self.assertEqual(u'\u672c', msgutil.receive_message(request)) def test_receive_message_erroneous_unicode(self): # \x80 and \x81 are invalid as UTF-8. request = _create_request((b'\x81\x82', b'\x80\x81')) # Invalid characters should raise InvalidUTF8Exception self.assertRaises(InvalidUTF8Exception, msgutil.receive_message, request) def test_receive_fragments(self): request = _create_request((b'\x01\x85', b'Hello'), (b'\x00\x81', b' '), (b'\x00\x85', b'World'), (b'\x80\x81', b'!')) self.assertEqual('Hello World!', msgutil.receive_message(request)) def test_receive_fragments_unicode(self): # UTF-8 encodes U+6f22 into e6bca2 and U+5b57 into e5ad97. request = _create_request((b'\x01\x82', b'\xe6\xbc'), (b'\x00\x82', b'\xa2\xe5'), (b'\x80\x82', b'\xad\x97')) self.assertEqual(u'\u6f22\u5b57', msgutil.receive_message(request)) def test_receive_fragments_immediate_zero_termination(self): request = _create_request((b'\x01\x8c', b'Hello World!'), (b'\x80\x80', b'')) self.assertEqual('Hello World!', msgutil.receive_message(request)) def test_receive_fragments_duplicate_start(self): request = _create_request((b'\x01\x85', b'Hello'), (b'\x01\x85', b'World')) self.assertRaises(msgutil.InvalidFrameException, msgutil.receive_message, request) def test_receive_fragments_intermediate_but_not_started(self): request = _create_request((b'\x00\x85', b'Hello')) self.assertRaises(msgutil.InvalidFrameException, msgutil.receive_message, request) def test_receive_fragments_end_but_not_started(self): request = _create_request((b'\x80\x85', b'Hello')) self.assertRaises(msgutil.InvalidFrameException, msgutil.receive_message, request) def test_receive_message_discard(self): request = _create_request( (b'\x8f\x86', b'IGNORE'), (b'\x81\x85', b'Hello'), (b'\x8f\x89', b'DISREGARD'), (b'\x81\x86', b'World!')) self.assertRaises(msgutil.UnsupportedFrameException, msgutil.receive_message, request) self.assertEqual('Hello', msgutil.receive_message(request)) self.assertRaises(msgutil.UnsupportedFrameException, msgutil.receive_message, request) self.assertEqual('World!', msgutil.receive_message(request)) def test_receive_close(self): request = _create_request( (b'\x88\x8a', struct.pack('!H', 1000) + b'Good bye')) self.assertEqual(None, msgutil.receive_message(request)) self.assertEqual(1000, request.ws_close_code) self.assertEqual('Good bye', request.ws_close_reason) def test_send_longest_close(self): reason = 'a' * 123 request = _create_request( (b'\x88\xfd', struct.pack('!H', common.STATUS_NORMAL_CLOSURE) + reason.encode('UTF-8'))) request.ws_stream.close_connection(common.STATUS_NORMAL_CLOSURE, reason) self.assertEqual(request.ws_close_code, common.STATUS_NORMAL_CLOSURE) self.assertEqual(request.ws_close_reason, reason) def test_send_close_too_long(self): request = _create_request() self.assertRaises(msgutil.BadOperationException, Stream.close_connection, request.ws_stream, common.STATUS_NORMAL_CLOSURE, 'a' * 124) def test_send_close_inconsistent_code_and_reason(self): request = _create_request() # reason parameter must not be specified when code is None. self.assertRaises(msgutil.BadOperationException, Stream.close_connection, request.ws_stream, None, 'a') def test_send_ping(self): request = _create_request() msgutil.send_ping(request, 'Hello World!') self.assertEqual(b'\x89\x0cHello World!', request.connection.written_data()) def test_send_longest_ping(self): request = _create_request() msgutil.send_ping(request, 'a' * 125) self.assertEqual(b'\x89\x7d' + b'a' * 125, request.connection.written_data()) def test_send_ping_too_long(self): request = _create_request() self.assertRaises(msgutil.BadOperationException, msgutil.send_ping, request, 'a' * 126) def test_receive_ping(self): """Tests receiving a ping control frame.""" def handler(request, message): request.called = True # Stream automatically respond to ping with pong without any action # by application layer. request = _create_request((b'\x89\x85', b'Hello'), (b'\x81\x85', b'World')) self.assertEqual('World', msgutil.receive_message(request)) self.assertEqual(b'\x8a\x05Hello', request.connection.written_data()) request = _create_request((b'\x89\x85', b'Hello'), (b'\x81\x85', b'World')) request.on_ping_handler = handler self.assertEqual('World', msgutil.receive_message(request)) self.assertTrue(request.called) def test_receive_longest_ping(self): request = _create_request((b'\x89\xfd', b'a' * 125), (b'\x81\x85', b'World')) self.assertEqual('World', msgutil.receive_message(request)) self.assertEqual(b'\x8a\x7d' + b'a' * 125, request.connection.written_data()) def test_receive_ping_too_long(self): request = _create_request((b'\x89\xfe\x00\x7e', b'a' * 126)) self.assertRaises(msgutil.InvalidFrameException, msgutil.receive_message, request) def test_receive_pong(self): """Tests receiving a pong control frame.""" def handler(request, message): request.called = True request = _create_request((b'\x8a\x85', b'Hello'), (b'\x81\x85', b'World')) request.on_pong_handler = handler msgutil.send_ping(request, 'Hello') self.assertEqual(b'\x89\x05Hello', request.connection.written_data()) # Valid pong is received, but receive_message won't return for it. self.assertEqual('World', msgutil.receive_message(request)) # Check that nothing was written after receive_message call. self.assertEqual(b'\x89\x05Hello', request.connection.written_data()) self.assertTrue(request.called) def test_receive_unsolicited_pong(self): # Unsolicited pong is allowed from HyBi 07. request = _create_request((b'\x8a\x85', b'Hello'), (b'\x81\x85', b'World')) msgutil.receive_message(request) request = _create_request((b'\x8a\x85', b'Hello'), (b'\x81\x85', b'World')) msgutil.send_ping(request, 'Jumbo') # Body mismatch. msgutil.receive_message(request) def test_ping_cannot_be_fragmented(self): request = _create_request((b'\x09\x85', b'Hello')) self.assertRaises(msgutil.InvalidFrameException, msgutil.receive_message, request) def test_ping_with_too_long_payload(self): request = _create_request((b'\x89\xfe\x01\x00', b'a' * 256)) self.assertRaises(msgutil.InvalidFrameException, msgutil.receive_message, request) class PerMessageDeflateTest(unittest.TestCase): """Tests for permessage-deflate extension.""" def test_response_parameters(self): extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) extension.add_parameter('server_no_context_takeover', None) processor = PerMessageDeflateExtensionProcessor(extension) response = processor.get_extension_response() self.assertTrue(response.has_parameter('server_no_context_takeover')) self.assertEqual( None, response.get_parameter_value('server_no_context_takeover')) extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) extension.add_parameter('client_max_window_bits', None) processor = PerMessageDeflateExtensionProcessor(extension) processor.set_client_max_window_bits(8) processor.set_client_no_context_takeover(True) response = processor.get_extension_response() self.assertEqual( '8', response.get_parameter_value('client_max_window_bits')) self.assertTrue(response.has_parameter('client_no_context_takeover')) self.assertEqual( None, response.get_parameter_value('client_no_context_takeover')) def test_send_message(self): extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) msgutil.send_message(request, 'Hello') compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_hello = compress.compress(b'Hello') compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) compressed_hello = compressed_hello[:-4] expected = b'\xc1%c' % len(compressed_hello) expected += compressed_hello self.assertEqual(expected, request.connection.written_data()) def test_send_empty_message(self): """Test that an empty message is compressed correctly.""" extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) msgutil.send_message(request, '') # Payload in binary: 0b00000000 # From LSB, # - 1 bit of BFINAL (0) # - 2 bits of BTYPE (no compression) # - 5 bits of padding self.assertEqual(b'\xc1\x01\x00', request.connection.written_data()) def test_send_message_with_null_character(self): """Test that a simple payload (one null) is framed correctly.""" extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) msgutil.send_message(request, '\x00') # Payload in binary: 0b01100010 0b00000000 0b00000000 # From LSB, # - 1 bit of BFINAL (0) # - 2 bits of BTYPE (01 that means fixed Huffman) # - 8 bits of the first code (00110000 that is the code for the literal # alphabet 0x00) # - 7 bits of the second code (0000000 that is the code for the # end-of-block) # - 1 bit of BFINAL (0) # - 2 bits of BTYPE (no compression) # - 2 bits of padding self.assertEqual(b'\xc1\x03\x62\x00\x00', request.connection.written_data()) def test_send_two_messages(self): extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) msgutil.send_message(request, 'Hello') msgutil.send_message(request, 'World') compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) expected = b'' compressed_hello = compress.compress(b'Hello') compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) compressed_hello = compressed_hello[:-4] expected += b'\xc1%c' % len(compressed_hello) expected += compressed_hello compressed_world = compress.compress(b'World') compressed_world += compress.flush(zlib.Z_SYNC_FLUSH) compressed_world = compressed_world[:-4] expected += b'\xc1%c' % len(compressed_world) expected += compressed_world self.assertEqual(expected, request.connection.written_data()) def test_send_message_fragmented(self): extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) msgutil.send_message(request, 'Hello', end=False) msgutil.send_message(request, 'Goodbye', end=False) msgutil.send_message(request, 'World') compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_hello = compress.compress(b'Hello') compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) expected = b'\x41%c' % len(compressed_hello) expected += compressed_hello compressed_goodbye = compress.compress(b'Goodbye') compressed_goodbye += compress.flush(zlib.Z_SYNC_FLUSH) expected += b'\x00%c' % len(compressed_goodbye) expected += compressed_goodbye compressed_world = compress.compress(b'World') compressed_world += compress.flush(zlib.Z_SYNC_FLUSH) compressed_world = compressed_world[:-4] expected += b'\x80%c' % len(compressed_world) expected += compressed_world self.assertEqual(expected, request.connection.written_data()) def test_send_message_fragmented_empty_first_frame(self): extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) msgutil.send_message(request, '', end=False) msgutil.send_message(request, 'Hello') compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_hello = compress.compress(b'') compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) expected = b'\x41%c' % len(compressed_hello) expected += compressed_hello compressed_empty = compress.compress(b'Hello') compressed_empty += compress.flush(zlib.Z_SYNC_FLUSH) compressed_empty = compressed_empty[:-4] expected += b'\x80%c' % len(compressed_empty) expected += compressed_empty self.assertEqual(expected, request.connection.written_data()) def test_send_message_fragmented_empty_last_frame(self): extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) msgutil.send_message(request, 'Hello', end=False) msgutil.send_message(request, '') compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_hello = compress.compress(b'Hello') compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) expected = b'\x41%c' % len(compressed_hello) expected += compressed_hello compressed_empty = compress.compress(b'') compressed_empty += compress.flush(zlib.Z_SYNC_FLUSH) compressed_empty = compressed_empty[:-4] expected += b'\x80%c' % len(compressed_empty) expected += compressed_empty self.assertEqual(expected, request.connection.written_data()) def test_send_message_using_small_window(self): common_part = 'abcdefghijklmnopqrstuvwxyz' test_message = common_part + '-' * 30000 + common_part extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) extension.add_parameter('server_max_window_bits', '8') request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) msgutil.send_message(request, test_message) expected_websocket_header_size = 2 expected_websocket_payload_size = 91 actual_frame = request.connection.written_data() self.assertEqual( expected_websocket_header_size + expected_websocket_payload_size, len(actual_frame)) actual_header = actual_frame[0:expected_websocket_header_size] actual_payload = actual_frame[expected_websocket_header_size:] self.assertEqual(b'\xc1%c' % expected_websocket_payload_size, actual_header) decompress = zlib.decompressobj(-8) decompressed_message = decompress.decompress(actual_payload + b'\x00\x00\xff\xff') decompressed_message += decompress.flush() self.assertEqual(test_message, decompressed_message.decode('UTF-8')) self.assertEqual(0, len(decompress.unused_data)) self.assertEqual(0, len(decompress.unconsumed_tail)) def test_send_message_no_context_takeover_parameter(self): extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) extension.add_parameter('server_no_context_takeover', None) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) for i in range(3): msgutil.send_message(request, 'Hello', end=False) msgutil.send_message(request, 'Hello', end=True) compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) first_hello = compress.compress(b'Hello') first_hello += compress.flush(zlib.Z_SYNC_FLUSH) expected = b'\x41%c' % len(first_hello) expected += first_hello second_hello = compress.compress(b'Hello') second_hello += compress.flush(zlib.Z_SYNC_FLUSH) second_hello = second_hello[:-4] expected += b'\x80%c' % len(second_hello) expected += second_hello self.assertEqual(expected + expected + expected, request.connection.written_data()) def test_send_message_fragmented_bfinal(self): extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( b'', permessage_deflate_request=extension) self.assertEqual(1, len(request.ws_extension_processors)) request.ws_extension_processors[0].set_bfinal(True) msgutil.send_message(request, 'Hello', end=False) msgutil.send_message(request, 'World', end=True) expected = b'' compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_hello = compress.compress(b'Hello') compressed_hello += compress.flush(zlib.Z_FINISH) compressed_hello = compressed_hello + struct.pack('!B', 0) expected += b'\x41%c' % len(compressed_hello) expected += compressed_hello compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_world = compress.compress(b'World') compressed_world += compress.flush(zlib.Z_FINISH) compressed_world = compressed_world + struct.pack('!B', 0) expected += b'\x80%c' % len(compressed_world) expected += compressed_world self.assertEqual(expected, request.connection.written_data()) def test_receive_message_deflate(self): compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_hello = compress.compress(b'Hello') compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) compressed_hello = compressed_hello[:-4] data = b'\xc1%c' % (len(compressed_hello) | 0x80) data += _mask_hybi(compressed_hello) # Close frame data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye') extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( data, permessage_deflate_request=extension) self.assertEqual('Hello', msgutil.receive_message(request)) self.assertEqual(None, msgutil.receive_message(request)) def test_receive_message_random_section(self): """Test that a compressed message fragmented into lots of chunks is correctly received. """ random.seed(a=0) payload = b''.join( [struct.pack('!B', random.randint(0, 255)) for i in range(1000)]) compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_payload = compress.compress(payload) compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH) compressed_payload = compressed_payload[:-4] # Fragment the compressed payload into lots of frames. bytes_chunked = 0 data = b'' frame_count = 0 chunk_sizes = [] while bytes_chunked < len(compressed_payload): # Make sure that # - the length of chunks are equal or less than 125 so that we can # use 1 octet length header format for all frames. # - at least 10 chunks are created. chunk_size = random.randint( 1, min(125, len(compressed_payload) // 10, len(compressed_payload) - bytes_chunked)) chunk_sizes.append(chunk_size) chunk = compressed_payload[bytes_chunked:bytes_chunked + chunk_size] bytes_chunked += chunk_size first_octet = 0x00 if len(data) == 0: first_octet = first_octet | 0x42 if bytes_chunked == len(compressed_payload): first_octet = first_octet | 0x80 data += b'%c%c' % (first_octet, chunk_size | 0x80) data += _mask_hybi(chunk) frame_count += 1 self.assertTrue(len(chunk_sizes) > 10) # Close frame data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye') extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( data, permessage_deflate_request=extension) self.assertEqual(payload, msgutil.receive_message(request)) self.assertEqual(None, msgutil.receive_message(request)) def test_receive_two_messages(self): compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) data = b'' compressed_hello = compress.compress(b'HelloWebSocket') compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH) compressed_hello = compressed_hello[:-4] split_position = len(compressed_hello) // 2 data += b'\x41%c' % (split_position | 0x80) data += _mask_hybi(compressed_hello[:split_position]) data += b'\x80%c' % ((len(compressed_hello) - split_position) | 0x80) data += _mask_hybi(compressed_hello[split_position:]) compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) compressed_world = compress.compress(b'World') compressed_world += compress.flush(zlib.Z_SYNC_FLUSH) compressed_world = compressed_world[:-4] data += b'\xc1%c' % (len(compressed_world) | 0x80) data += _mask_hybi(compressed_world) # Close frame data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye') extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( data, permessage_deflate_request=extension) self.assertEqual('HelloWebSocket', msgutil.receive_message(request)) self.assertEqual('World', msgutil.receive_message(request)) self.assertEqual(None, msgutil.receive_message(request)) def test_receive_message_mixed_btype(self): """Test that a message compressed using lots of DEFLATE blocks with various flush mode is correctly received. """ random.seed(a=0) payload = b''.join( [struct.pack('!B', random.randint(0, 255)) for i in range(1000)]) compress = None # Fragment the compressed payload into lots of frames. bytes_chunked = 0 compressed_payload = b'' chunk_sizes = [] methods = [] sync_used = False finish_used = False while bytes_chunked < len(payload): # Make sure at least 10 chunks are created. chunk_size = random.randint(1, min(100, len(payload) - bytes_chunked)) chunk_sizes.append(chunk_size) chunk = payload[bytes_chunked:bytes_chunked + chunk_size] bytes_chunked += chunk_size if compress is None: compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS) if bytes_chunked == len(payload): compressed_payload += compress.compress(chunk) compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH) compressed_payload = compressed_payload[:-4] else: method = random.randint(0, 1) methods.append(method) if method == 0: compressed_payload += compress.compress(chunk) compressed_payload += compress.flush(zlib.Z_SYNC_FLUSH) sync_used = True else: compressed_payload += compress.compress(chunk) compressed_payload += compress.flush(zlib.Z_FINISH) compress = None finish_used = True self.assertTrue(len(chunk_sizes) > 10) self.assertTrue(sync_used) self.assertTrue(finish_used) self.assertTrue(125 < len(compressed_payload)) self.assertTrue(len(compressed_payload) < 65536) data = b'\xc2\xfe' + struct.pack('!H', len(compressed_payload)) data += _mask_hybi(compressed_payload) # Close frame data += b'\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + b'Good bye') extension = common.ExtensionParameter( common.PERMESSAGE_DEFLATE_EXTENSION) request = _create_request_from_rawdata( data, permessage_deflate_request=extension) self.assertEqual(payload, msgutil.receive_message(request)) self.assertEqual(None, msgutil.receive_message(request)) class MessageReceiverTest(unittest.TestCase): """Tests the Stream class using MessageReceiver.""" def test_queue(self): request = _create_blocking_request() receiver = msgutil.MessageReceiver(request) self.assertEqual(None, receiver.receive_nowait()) request.connection.put_bytes(b'\x81\x86' + _mask_hybi(b'Hello!')) self.assertEqual('Hello!', receiver.receive()) def test_onmessage(self): onmessage_queue = six.moves.queue.Queue() def onmessage_handler(message): onmessage_queue.put(message) request = _create_blocking_request() receiver = msgutil.MessageReceiver(request, onmessage_handler) request.connection.put_bytes(b'\x81\x86' + _mask_hybi(b'Hello!')) self.assertEqual('Hello!', onmessage_queue.get()) class MessageSenderTest(unittest.TestCase): """Tests the Stream class using MessageSender.""" def test_send(self): request = _create_blocking_request() sender = msgutil.MessageSender(request) sender.send('World') self.assertEqual(b'\x81\x05World', request.connection.written_data()) def test_send_nowait(self): # Use a queue to check the bytes written by MessageSender. # request.connection.written_data() cannot be used here because # MessageSender runs in a separate thread. send_queue = six.moves.queue.Queue() def write(bytes): send_queue.put(bytes) request = _create_blocking_request() request.connection.write = write sender = msgutil.MessageSender(request) sender.send_nowait('Hello') sender.send_nowait('World') self.assertEqual(b'\x81\x05Hello', send_queue.get()) self.assertEqual(b'\x81\x05World', send_queue.get()) if __name__ == '__main__': unittest.main() # vi:sts=4 sw=4 et