Ticket #4173: 4173-3.patch

File 4173-3.patch, 18.7 KB (added by MostAwesomeDude, 2 years ago)

#4173, with tests, hybrid approach

  • twisted/web/websockets.py

     
     1# Copyright (c) 2011-2012 Oregon State University Open Source Lab 
     2# 
     3# Permission is hereby granted, free of charge, to any person obtaining a copy 
     4# of this software and associated documentation files (the "Software"), to 
     5# deal in the Software without restriction, including without limitation the 
     6# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
     7# sell copies of the Software, and to permit persons to whom the Software is 
     8# furnished to do so, subject to the following conditions: 
     9# 
     10#    The above copyright notice and this permission notice shall be included 
     11#    in all copies or substantial portions of the Software. 
     12# 
     13#    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
     14#    OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 
     15#    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN 
     16#    NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 
     17#    DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 
     18#    OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE 
     19#    USE OR OTHER DEALINGS IN THE SOFTWARE. 
     20 
     21""" 
     22The WebSockets protocol (RFC 6455), provided as a resource which wraps a 
     23protocol. 
     24""" 
     25 
     26from base64 import b64encode, b64decode 
     27from hashlib import sha1 
     28from struct import pack, unpack 
     29 
     30from twisted.protocols.policies import ProtocolWrapper, WrappingFactory 
     31from twisted.python import log 
     32from twisted.web.error import NoResource 
     33from twisted.web.resource import IResource 
     34from twisted.web.server import NOT_DONE_YET 
     35from zope.interface import implements 
     36 
     37class WSException(Exception): 
     38    """ 
     39    Something stupid happened here. 
     40 
     41    If this class escapes txWS, then something stupid happened in multiple 
     42    places. 
     43    """ 
     44 
     45# Control frame specifiers. Some versions of WS have control signals sent 
     46# in-band. Adorable, right? 
     47 
     48NORMAL, CLOSE, PING, PONG = range(4) 
     49 
     50opcode_types = { 
     51    0x0: NORMAL, 
     52    0x1: NORMAL, 
     53    0x2: NORMAL, 
     54    0x8: CLOSE, 
     55    0x9: PING, 
     56    0xa: PONG, 
     57} 
     58 
     59encoders = { 
     60    "base64": b64encode, 
     61} 
     62 
     63decoders = { 
     64    "base64": b64decode, 
     65} 
     66 
     67# Authentication for WS. 
     68 
     69def make_accept(key): 
     70    """ 
     71    Create an "accept" response for a given key. 
     72 
     73    This dance is expected to somehow magically make WebSockets secure. 
     74    """ 
     75 
     76    guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 
     77 
     78    return sha1("%s%s" % (key, guid)).digest().encode("base64").strip() 
     79 
     80# Frame helpers. 
     81# Separated out to make unit testing a lot easier. 
     82# Frames are bonghits in newer WS versions, so helpers are appreciated. 
     83 
     84def mask(buf, key): 
     85    """ 
     86    Mask or unmask a buffer of bytes with a masking key. 
     87 
     88    The key must be exactly four bytes long. 
     89    """ 
     90 
     91    # This is super-secure, I promise~ 
     92    key = [ord(i) for i in key] 
     93    buf = list(buf) 
     94    for i, char in enumerate(buf): 
     95        buf[i] = chr(ord(char) ^ key[i % 4]) 
     96    return "".join(buf) 
     97 
     98def make_hybi07_frame(buf, opcode=0x1): 
     99    """ 
     100    Make a HyBi-07 frame. 
     101 
     102    This function always creates unmasked frames, and attempts to use the 
     103    smallest possible lengths. 
     104    """ 
     105 
     106    if len(buf) > 0xffff: 
     107        length = "\x7f%s" % pack(">Q", len(buf)) 
     108    elif len(buf) > 0x7d: 
     109        length = "\x7e%s" % pack(">H", len(buf)) 
     110    else: 
     111        length = chr(len(buf)) 
     112 
     113    # Always make a normal packet. 
     114    header = chr(0x80 | opcode) 
     115    frame = "%s%s%s" % (header, length, buf) 
     116    return frame 
     117 
     118def parse_hybi07_frames(buf): 
     119    """ 
     120    Parse HyBi-07 frames in a highly compliant manner. 
     121    """ 
     122 
     123    start = 0 
     124    frames = [] 
     125 
     126    while True: 
     127        # If there's not at least two bytes in the buffer, bail. 
     128        if len(buf) - start < 2: 
     129            break 
     130 
     131        # Grab the header. This single byte holds some flags nobody cares 
     132        # about, and an opcode which nobody cares about. 
     133        header = ord(buf[start]) 
     134        if header & 0x70: 
     135            # At least one of the reserved flags is set. Pork chop sandwiches! 
     136            raise WSException("Reserved flag in HyBi-07 frame (%d)" % header) 
     137            frames.append(("", CLOSE)) 
     138            return frames, buf 
     139 
     140        # Get the opcode, and translate it to a local enum which we actually 
     141        # care about. 
     142        opcode = header & 0xf 
     143        try: 
     144            opcode = opcode_types[opcode] 
     145        except KeyError: 
     146            raise WSException("Unknown opcode %d in HyBi-07 frame" % opcode) 
     147 
     148        # Get the payload length and determine whether we need to look for an 
     149        # extra length. 
     150        length = ord(buf[start + 1]) 
     151        masked = length & 0x80 
     152        length &= 0x7f 
     153 
     154        # The offset we're gonna be using to walk through the frame. We use 
     155        # this because the offset is variable depending on the length and 
     156        # mask. 
     157        offset = 2 
     158 
     159        # Extra length fields. 
     160        if length == 0x7e: 
     161            if len(buf) - start < 4: 
     162                break 
     163 
     164            length = buf[start + 2:start + 4] 
     165            length = unpack(">H", length)[0] 
     166            offset += 2 
     167        elif length == 0x7f: 
     168            if len(buf) - start < 10: 
     169                break 
     170 
     171            # Protocol bug: The top bit of this long long *must* be cleared; 
     172            # that is, it is expected to be interpreted as signed. That's 
     173            # fucking stupid, if you don't mind me saying so, and so we're 
     174            # interpreting it as unsigned anyway. If you wanna send exabytes 
     175            # of data down the wire, then go ahead! 
     176            length = buf[start + 2:start + 10] 
     177            length = unpack(">Q", length)[0] 
     178            offset += 8 
     179 
     180        if masked: 
     181            if len(buf) - (start + offset) < 4: 
     182                break 
     183 
     184            key = buf[start + offset:start + offset + 4] 
     185            offset += 4 
     186 
     187        if len(buf) - (start + offset) < length: 
     188            break 
     189 
     190        data = buf[start + offset:start + offset + length] 
     191 
     192        if masked: 
     193            data = mask(data, key) 
     194 
     195        if opcode == CLOSE: 
     196            if len(data) >= 2: 
     197                # Gotta unpack the opcode and return usable data here. 
     198                data = unpack(">H", data[:2])[0], data[2:] 
     199            else: 
     200                # No reason given; use generic data. 
     201                data = 1000, "No reason given" 
     202 
     203        frames.append((opcode, data)) 
     204        start += offset + length 
     205 
     206    return frames, buf[start:] 
     207 
     208class WebSocketsProtocol(ProtocolWrapper): 
     209    """ 
     210    Protocol which wraps another protocol to provide a WebSockets transport 
     211    layer. 
     212    """ 
     213 
     214    buf = "" 
     215    codec = None 
     216 
     217    def __init__(self, *args, **kwargs): 
     218        ProtocolWrapper.__init__(self, *args, **kwargs) 
     219        self.pending_frames = [] 
     220 
     221    def parseFrames(self): 
     222        """ 
     223        Find frames in incoming data and pass them to the underlying protocol. 
     224        """ 
     225 
     226        try: 
     227            frames, self.buf = parse_hybi07_frames(self.buf) 
     228        except WSException, wse: 
     229            # Couldn't parse all the frames, something went wrong, let's bail. 
     230            self.close(wse.args[0]) 
     231            return 
     232 
     233        for frame in frames: 
     234            opcode, data = frame 
     235            if opcode == NORMAL: 
     236                # Business as usual. Decode the frame, if we have a decoder. 
     237                if self.codec: 
     238                    data = decoders[self.codec](data) 
     239                # Pass the frame to the underlying protocol. 
     240                ProtocolWrapper.dataReceived(self, data) 
     241            elif opcode == CLOSE: 
     242                # The other side wants us to close. I wonder why? 
     243                reason, text = data 
     244                log.msg("Closing connection: %r (%d)" % (text, reason)) 
     245 
     246                # Close the connection. 
     247                self.close() 
     248 
     249    def sendFrames(self): 
     250        """ 
     251        Send all pending frames. 
     252        """ 
     253 
     254        for frame in self.pending_frames: 
     255            # Encode the frame before sending it. 
     256            if self.codec: 
     257                frame = encoders[self.codec](frame) 
     258            packet = make_hybi07_frame(frame) 
     259            self.transport.write(packet) 
     260        self.pending_frames = [] 
     261 
     262    def dataReceived(self, data): 
     263        self.buf += data 
     264 
     265        self.parseFrames() 
     266 
     267        # Kick any pending frames. This is needed because frames might have 
     268        # started piling up early; we can get write()s from our protocol above 
     269        # when they makeConnection() immediately, before our browser client 
     270        # actually sends any data. In those cases, we need to manually kick 
     271        # pending frames. 
     272        if self.pending_frames: 
     273            self.sendFrames() 
     274 
     275    def write(self, data): 
     276        """ 
     277        Write to the transport. 
     278 
     279        This method will only be called by the underlying protocol. 
     280        """ 
     281 
     282        self.pending_frames.append(data) 
     283        self.sendFrames() 
     284 
     285    def writeSequence(self, data): 
     286        """ 
     287        Write a sequence of data to the transport. 
     288 
     289        This method will only be called by the underlying protocol. 
     290        """ 
     291 
     292        self.pending_frames.extend(data) 
     293        self.sendFrames() 
     294 
     295    def close(self, reason=""): 
     296        """ 
     297        Close the connection. 
     298 
     299        This includes telling the other side we're closing the connection. 
     300 
     301        If the other side didn't signal that the connection is being closed, 
     302        then we might not see their last message, but since their last message 
     303        should, according to the spec, be a simple acknowledgement, it 
     304        shouldn't be a problem. 
     305        """ 
     306 
     307        # Send a closing frame. It's only polite. (And might keep the browser 
     308        # from hanging.) 
     309        frame = make_hybi07_frame(reason, opcode=0x8) 
     310        self.transport.write(frame) 
     311 
     312        self.loseConnection() 
     313 
     314class WebSocketsFactory(WrappingFactory): 
     315    """ 
     316    Factory which wraps another factory to provide WebSockets frames for all 
     317    of its protocols. 
     318 
     319    This factory does not provide the HTTP headers required to perform a 
     320    WebSockets handshake; see C{WebSocketsResource}. 
     321    """ 
     322 
     323    protocol = WebSocketsProtocol 
     324 
     325class WebSocketsResource(object): 
     326 
     327    implements(IResource) 
     328 
     329    isLeaf = True 
     330 
     331    def __init__(self, factory): 
     332        self._factory = WebSocketsFactory(factory) 
     333 
     334    def getChildWithDefault(self, name, request): 
     335        return NoResource("No such child resource.") 
     336 
     337    def putChild(self, path, child): 
     338        pass 
     339 
     340    def render(self, request): 
     341        """ 
     342        Render a request. 
     343 
     344        We're not actually rendering a request. We are secretly going to 
     345        handle a WebSockets connection instead. 
     346        """ 
     347 
     348        # If we fail at all, we're gonna fail with 400 and no response. 
     349        # You might want to pop open the RFC and read along. 
     350        failed = False 
     351 
     352        if request.method != "GET": 
     353            # 4.2.1.1 GET is required. 
     354            failed = True 
     355 
     356        upgrade = request.getHeader("Upgrade") 
     357        if upgrade is None or "websocket" not in upgrade.lower(): 
     358            # 4.2.1.3 Upgrade: WebSocket is required. 
     359            failed = True 
     360 
     361        connection = request.getHeader("Connection") 
     362        if connection is None or "upgrade" not in connection.lower(): 
     363            # 4.2.1.4 Connection: Upgrade is required. 
     364            failed = True 
     365 
     366        key = request.getHeader("Sec-WebSocket-Key") 
     367        if key is None: 
     368            # 4.2.1.5 The challenge key is required. 
     369            failed = True 
     370 
     371        version = request.getHeader("Sec-WebSocket-Version") 
     372        if version is None or version != "13": 
     373            # 4.2.1.6 Only version 13 works. 
     374            failed = True 
     375            # 4.4 Forward-compatible version checking. 
     376            request.setHeader("Sec-WebSocket-Version", "13") 
     377 
     378        # Stash host and origin for those browsers that care about it. 
     379        host = request.getHeader("Host") 
     380        origin = request.getHeader("Origin") 
     381 
     382        # Check whether a codec is needed. WS calls this a "protocol" for 
     383        # reasons I cannot fathom. 
     384        protocol = request.getHeader("Sec-WebSocket-Protocol") 
     385 
     386        if protocol: 
     387            if protocol not in encoders or protocol not in decoders: 
     388                log.msg("Protocol %s is not implemented" % protocol) 
     389                failed = True 
     390 
     391        if failed: 
     392            request.setResponseCode(400) 
     393            return "" 
     394 
     395        # We are going to finish this handshake. We will return a valid status 
     396        # code. 
     397        # 4.2.2.5.1 101 Switching Protocols 
     398        request.setResponseCode(101) 
     399        # 4.2.2.5.2 Upgrade: websocket 
     400        request.setHeader("Upgrade", "WebSocket") 
     401        # 4.2.2.5.3 Connection: Upgrade 
     402        request.setHeader("Connection", "Upgrade") 
     403        # 4.2.2.5.4 Response to the key challenge 
     404        request.setHeader("Sec-WebSocket-Accept", make_accept(key)) 
     405 
     406        # Provoke request into flushing headers and finishing the handshake. 
     407        request.write("") 
     408 
     409        # And now take matters into our own hands. We shall manage the 
     410        # transport's lifecycle. 
     411        transport, request.transport = request.transport, None 
     412 
     413        # Connect the transport to our factory, and make things go. We need to 
     414        # do some stupid stuff here; see #3204, which could fix it. 
     415        protocol = self._factory.buildProtocol(transport.getPeer()) 
     416        transport.protocol = protocol 
     417        protocol.makeConnection(transport) 
     418 
     419        return NOT_DONE_YET 
  • twisted/web/test/test_websockets.py

     
     1from twisted.trial import unittest 
     2 
     3from twisted.web.websockets import (make_accept, mask, CLOSE, NORMAL, PING, 
     4    PONG, parse_hybi07_frames) 
     5 
     6class TestKeys(unittest.TestCase): 
     7 
     8    def test_make_accept_rfc(self): 
     9        """ 
     10        Test ``make_accept()`` using the keys listed in the RFC for HyBi-07 
     11        through HyBi-10. 
     12        """ 
     13 
     14        key = "dGhlIHNhbXBsZSBub25jZQ==" 
     15 
     16        self.assertEqual(make_accept(key), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=") 
     17 
     18    def test_make_accept_wikipedia(self): 
     19        """ 
     20        Test ``make_accept()`` using the keys listed on Wikipedia. 
     21        """ 
     22 
     23        key = "x3JJHMbDL1EzLkh9GBhXDw==" 
     24 
     25        self.assertEqual(make_accept(key), "HSmrc0sMlYUkAGmm5OPpG2HaGWk=") 
     26 
     27class TestHyBi07Helpers(unittest.TestCase): 
     28    """ 
     29    HyBi-07 is best understood as a large family of helper functions which 
     30    work together, somewhat dysfunctionally, to produce a mediocre 
     31    Thanksgiving every other year. 
     32    """ 
     33 
     34    def test_mask_noop(self): 
     35        key = "\x00\x00\x00\x00" 
     36        self.assertEqual(mask("Test", key), "Test") 
     37 
     38    def test_mask_noop_long(self): 
     39        key = "\x00\x00\x00\x00" 
     40        self.assertEqual(mask("LongTest", key), "LongTest") 
     41 
     42    def test_parse_hybi07_unmasked_text(self): 
     43        """ 
     44        From HyBi-10, 4.7. 
     45        """ 
     46 
     47        frame = "\x81\x05Hello" 
     48        frames, buf = parse_hybi07_frames(frame) 
     49        self.assertEqual(len(frames), 1) 
     50        self.assertEqual(frames[0], (NORMAL, "Hello")) 
     51        self.assertEqual(buf, "") 
     52 
     53    def test_parse_hybi07_masked_text(self): 
     54        """ 
     55        From HyBi-10, 4.7. 
     56        """ 
     57 
     58        frame = "\x81\x857\xfa!=\x7f\x9fMQX" 
     59        frames, buf = parse_hybi07_frames(frame) 
     60        self.assertEqual(len(frames), 1) 
     61        self.assertEqual(frames[0], (NORMAL, "Hello")) 
     62        self.assertEqual(buf, "") 
     63 
     64    def test_parse_hybi07_unmasked_text_fragments(self): 
     65        """ 
     66        We don't care about fragments. We are totally unfazed. 
     67 
     68        From HyBi-10, 4.7. 
     69        """ 
     70 
     71        frame = "\x01\x03Hel\x80\x02lo" 
     72        frames, buf = parse_hybi07_frames(frame) 
     73        self.assertEqual(len(frames), 2) 
     74        self.assertEqual(frames[0], (NORMAL, "Hel")) 
     75        self.assertEqual(frames[1], (NORMAL, "lo")) 
     76        self.assertEqual(buf, "") 
     77 
     78    def test_parse_hybi07_ping(self): 
     79        """ 
     80        From HyBi-10, 4.7. 
     81        """ 
     82 
     83        frame = "\x89\x05Hello" 
     84        frames, buf = parse_hybi07_frames(frame) 
     85        self.assertEqual(len(frames), 1) 
     86        self.assertEqual(frames[0], (PING, "Hello")) 
     87        self.assertEqual(buf, "") 
     88 
     89    def test_parse_hybi07_pong(self): 
     90        """ 
     91        From HyBi-10, 4.7. 
     92        """ 
     93 
     94        frame = "\x8a\x05Hello" 
     95        frames, buf = parse_hybi07_frames(frame) 
     96        self.assertEqual(len(frames), 1) 
     97        self.assertEqual(frames[0], (PONG, "Hello")) 
     98        self.assertEqual(buf, "") 
     99 
     100    def test_parse_hybi07_close_empty(self): 
     101        """ 
     102        A HyBi-07 close packet may have no body. In that case, it should use 
     103        the generic error code 1000, and have no reason. 
     104        """ 
     105 
     106        frame = "\x88\x00" 
     107        frames, buf = parse_hybi07_frames(frame) 
     108        self.assertEqual(len(frames), 1) 
     109        self.assertEqual(frames[0], (CLOSE, (1000, "No reason given"))) 
     110        self.assertEqual(buf, "") 
     111 
     112    def test_parse_hybi07_close_reason(self): 
     113        """ 
     114        A HyBi-07 close packet must have its first two bytes be a numeric 
     115        error code, and may optionally include trailing text explaining why 
     116        the connection was closed. 
     117        """ 
     118 
     119        frame = "\x88\x0b\x03\xe8No reason" 
     120        frames, buf = parse_hybi07_frames(frame) 
     121        self.assertEqual(len(frames), 1) 
     122        self.assertEqual(frames[0], (CLOSE, (1000, "No reason"))) 
     123        self.assertEqual(buf, "") 
     124 
     125    def test_parse_hybi07_partial_no_length(self): 
     126        frame = "\x81" 
     127        frames, buf = parse_hybi07_frames(frame) 
     128        self.assertFalse(frames) 
     129        self.assertEqual(buf, "\x81") 
     130 
     131    def test_parse_hybi07_partial_truncated_length_int(self): 
     132        frame = "\x81\xfe" 
     133        frames, buf = parse_hybi07_frames(frame) 
     134        self.assertFalse(frames) 
     135        self.assertEqual(buf, "\x81\xfe") 
     136 
     137    def test_parse_hybi07_partial_truncated_length_double(self): 
     138        frame = "\x81\xff" 
     139        frames, buf = parse_hybi07_frames(frame) 
     140        self.assertFalse(frames) 
     141        self.assertEqual(buf, "\x81\xff") 
     142 
     143    def test_parse_hybi07_partial_no_data(self): 
     144        frame = "\x81\x05" 
     145        frames, buf = parse_hybi07_frames(frame) 
     146        self.assertFalse(frames) 
     147        self.assertEqual(buf, "\x81\x05") 
     148 
     149    def test_parse_hybi07_partial_truncated_data(self): 
     150        frame = "\x81\x05Hel" 
     151        frames, buf = parse_hybi07_frames(frame) 
     152        self.assertFalse(frames) 
     153        self.assertEqual(buf, "\x81\x05Hel")