Ticket #4173: 4173-3.patch

File 4173-3.patch, 18.7 KB (added by MostAwesomeDude, 4 years ago)

#4173, with tests, hybrid approach

  • twisted/web/websockets.py

     
     1# Copyright (c) 2011-2012 Oregon State University Open Source Lab
     2#
     3# Permission is hereby granted, free of charge, to any person obtaining a copy
     4# of this software and associated documentation files (the "Software"), to
     5# deal in the Software without restriction, including without limitation the
     6# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
     7# sell copies of the Software, and to permit persons to whom the Software is
     8# furnished to do so, subject to the following conditions:
     9#
     10#    The above copyright notice and this permission notice shall be included
     11#    in all copies or substantial portions of the Software.
     12#
     13#    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
     14#    OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     15#    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
     16#    NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
     17#    DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
     18#    OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
     19#    USE OR OTHER DEALINGS IN THE SOFTWARE.
     20
     21"""
     22The WebSockets protocol (RFC 6455), provided as a resource which wraps a
     23protocol.
     24"""
     25
     26from base64 import b64encode, b64decode
     27from hashlib import sha1
     28from struct import pack, unpack
     29
     30from twisted.protocols.policies import ProtocolWrapper, WrappingFactory
     31from twisted.python import log
     32from twisted.web.error import NoResource
     33from twisted.web.resource import IResource
     34from twisted.web.server import NOT_DONE_YET
     35from zope.interface import implements
     36
     37class WSException(Exception):
     38    """
     39    Something stupid happened here.
     40
     41    If this class escapes txWS, then something stupid happened in multiple
     42    places.
     43    """
     44
     45# Control frame specifiers. Some versions of WS have control signals sent
     46# in-band. Adorable, right?
     47
     48NORMAL, CLOSE, PING, PONG = range(4)
     49
     50opcode_types = {
     51    0x0: NORMAL,
     52    0x1: NORMAL,
     53    0x2: NORMAL,
     54    0x8: CLOSE,
     55    0x9: PING,
     56    0xa: PONG,
     57}
     58
     59encoders = {
     60    "base64": b64encode,
     61}
     62
     63decoders = {
     64    "base64": b64decode,
     65}
     66
     67# Authentication for WS.
     68
     69def make_accept(key):
     70    """
     71    Create an "accept" response for a given key.
     72
     73    This dance is expected to somehow magically make WebSockets secure.
     74    """
     75
     76    guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
     77
     78    return sha1("%s%s" % (key, guid)).digest().encode("base64").strip()
     79
     80# Frame helpers.
     81# Separated out to make unit testing a lot easier.
     82# Frames are bonghits in newer WS versions, so helpers are appreciated.
     83
     84def mask(buf, key):
     85    """
     86    Mask or unmask a buffer of bytes with a masking key.
     87
     88    The key must be exactly four bytes long.
     89    """
     90
     91    # This is super-secure, I promise~
     92    key = [ord(i) for i in key]
     93    buf = list(buf)
     94    for i, char in enumerate(buf):
     95        buf[i] = chr(ord(char) ^ key[i % 4])
     96    return "".join(buf)
     97
     98def make_hybi07_frame(buf, opcode=0x1):
     99    """
     100    Make a HyBi-07 frame.
     101
     102    This function always creates unmasked frames, and attempts to use the
     103    smallest possible lengths.
     104    """
     105
     106    if len(buf) > 0xffff:
     107        length = "\x7f%s" % pack(">Q", len(buf))
     108    elif len(buf) > 0x7d:
     109        length = "\x7e%s" % pack(">H", len(buf))
     110    else:
     111        length = chr(len(buf))
     112
     113    # Always make a normal packet.
     114    header = chr(0x80 | opcode)
     115    frame = "%s%s%s" % (header, length, buf)
     116    return frame
     117
     118def parse_hybi07_frames(buf):
     119    """
     120    Parse HyBi-07 frames in a highly compliant manner.
     121    """
     122
     123    start = 0
     124    frames = []
     125
     126    while True:
     127        # If there's not at least two bytes in the buffer, bail.
     128        if len(buf) - start < 2:
     129            break
     130
     131        # Grab the header. This single byte holds some flags nobody cares
     132        # about, and an opcode which nobody cares about.
     133        header = ord(buf[start])
     134        if header & 0x70:
     135            # At least one of the reserved flags is set. Pork chop sandwiches!
     136            raise WSException("Reserved flag in HyBi-07 frame (%d)" % header)
     137            frames.append(("", CLOSE))
     138            return frames, buf
     139
     140        # Get the opcode, and translate it to a local enum which we actually
     141        # care about.
     142        opcode = header & 0xf
     143        try:
     144            opcode = opcode_types[opcode]
     145        except KeyError:
     146            raise WSException("Unknown opcode %d in HyBi-07 frame" % opcode)
     147
     148        # Get the payload length and determine whether we need to look for an
     149        # extra length.
     150        length = ord(buf[start + 1])
     151        masked = length & 0x80
     152        length &= 0x7f
     153
     154        # The offset we're gonna be using to walk through the frame. We use
     155        # this because the offset is variable depending on the length and
     156        # mask.
     157        offset = 2
     158
     159        # Extra length fields.
     160        if length == 0x7e:
     161            if len(buf) - start < 4:
     162                break
     163
     164            length = buf[start + 2:start + 4]
     165            length = unpack(">H", length)[0]
     166            offset += 2
     167        elif length == 0x7f:
     168            if len(buf) - start < 10:
     169                break
     170
     171            # Protocol bug: The top bit of this long long *must* be cleared;
     172            # that is, it is expected to be interpreted as signed. That's
     173            # fucking stupid, if you don't mind me saying so, and so we're
     174            # interpreting it as unsigned anyway. If you wanna send exabytes
     175            # of data down the wire, then go ahead!
     176            length = buf[start + 2:start + 10]
     177            length = unpack(">Q", length)[0]
     178            offset += 8
     179
     180        if masked:
     181            if len(buf) - (start + offset) < 4:
     182                break
     183
     184            key = buf[start + offset:start + offset + 4]
     185            offset += 4
     186
     187        if len(buf) - (start + offset) < length:
     188            break
     189
     190        data = buf[start + offset:start + offset + length]
     191
     192        if masked:
     193            data = mask(data, key)
     194
     195        if opcode == CLOSE:
     196            if len(data) >= 2:
     197                # Gotta unpack the opcode and return usable data here.
     198                data = unpack(">H", data[:2])[0], data[2:]
     199            else:
     200                # No reason given; use generic data.
     201                data = 1000, "No reason given"
     202
     203        frames.append((opcode, data))
     204        start += offset + length
     205
     206    return frames, buf[start:]
     207
     208class WebSocketsProtocol(ProtocolWrapper):
     209    """
     210    Protocol which wraps another protocol to provide a WebSockets transport
     211    layer.
     212    """
     213
     214    buf = ""
     215    codec = None
     216
     217    def __init__(self, *args, **kwargs):
     218        ProtocolWrapper.__init__(self, *args, **kwargs)
     219        self.pending_frames = []
     220
     221    def parseFrames(self):
     222        """
     223        Find frames in incoming data and pass them to the underlying protocol.
     224        """
     225
     226        try:
     227            frames, self.buf = parse_hybi07_frames(self.buf)
     228        except WSException, wse:
     229            # Couldn't parse all the frames, something went wrong, let's bail.
     230            self.close(wse.args[0])
     231            return
     232
     233        for frame in frames:
     234            opcode, data = frame
     235            if opcode == NORMAL:
     236                # Business as usual. Decode the frame, if we have a decoder.
     237                if self.codec:
     238                    data = decoders[self.codec](data)
     239                # Pass the frame to the underlying protocol.
     240                ProtocolWrapper.dataReceived(self, data)
     241            elif opcode == CLOSE:
     242                # The other side wants us to close. I wonder why?
     243                reason, text = data
     244                log.msg("Closing connection: %r (%d)" % (text, reason))
     245
     246                # Close the connection.
     247                self.close()
     248
     249    def sendFrames(self):
     250        """
     251        Send all pending frames.
     252        """
     253
     254        for frame in self.pending_frames:
     255            # Encode the frame before sending it.
     256            if self.codec:
     257                frame = encoders[self.codec](frame)
     258            packet = make_hybi07_frame(frame)
     259            self.transport.write(packet)
     260        self.pending_frames = []
     261
     262    def dataReceived(self, data):
     263        self.buf += data
     264
     265        self.parseFrames()
     266
     267        # Kick any pending frames. This is needed because frames might have
     268        # started piling up early; we can get write()s from our protocol above
     269        # when they makeConnection() immediately, before our browser client
     270        # actually sends any data. In those cases, we need to manually kick
     271        # pending frames.
     272        if self.pending_frames:
     273            self.sendFrames()
     274
     275    def write(self, data):
     276        """
     277        Write to the transport.
     278
     279        This method will only be called by the underlying protocol.
     280        """
     281
     282        self.pending_frames.append(data)
     283        self.sendFrames()
     284
     285    def writeSequence(self, data):
     286        """
     287        Write a sequence of data to the transport.
     288
     289        This method will only be called by the underlying protocol.
     290        """
     291
     292        self.pending_frames.extend(data)
     293        self.sendFrames()
     294
     295    def close(self, reason=""):
     296        """
     297        Close the connection.
     298
     299        This includes telling the other side we're closing the connection.
     300
     301        If the other side didn't signal that the connection is being closed,
     302        then we might not see their last message, but since their last message
     303        should, according to the spec, be a simple acknowledgement, it
     304        shouldn't be a problem.
     305        """
     306
     307        # Send a closing frame. It's only polite. (And might keep the browser
     308        # from hanging.)
     309        frame = make_hybi07_frame(reason, opcode=0x8)
     310        self.transport.write(frame)
     311
     312        self.loseConnection()
     313
     314class WebSocketsFactory(WrappingFactory):
     315    """
     316    Factory which wraps another factory to provide WebSockets frames for all
     317    of its protocols.
     318
     319    This factory does not provide the HTTP headers required to perform a
     320    WebSockets handshake; see C{WebSocketsResource}.
     321    """
     322
     323    protocol = WebSocketsProtocol
     324
     325class WebSocketsResource(object):
     326
     327    implements(IResource)
     328
     329    isLeaf = True
     330
     331    def __init__(self, factory):
     332        self._factory = WebSocketsFactory(factory)
     333
     334    def getChildWithDefault(self, name, request):
     335        return NoResource("No such child resource.")
     336
     337    def putChild(self, path, child):
     338        pass
     339
     340    def render(self, request):
     341        """
     342        Render a request.
     343
     344        We're not actually rendering a request. We are secretly going to
     345        handle a WebSockets connection instead.
     346        """
     347
     348        # If we fail at all, we're gonna fail with 400 and no response.
     349        # You might want to pop open the RFC and read along.
     350        failed = False
     351
     352        if request.method != "GET":
     353            # 4.2.1.1 GET is required.
     354            failed = True
     355
     356        upgrade = request.getHeader("Upgrade")
     357        if upgrade is None or "websocket" not in upgrade.lower():
     358            # 4.2.1.3 Upgrade: WebSocket is required.
     359            failed = True
     360
     361        connection = request.getHeader("Connection")
     362        if connection is None or "upgrade" not in connection.lower():
     363            # 4.2.1.4 Connection: Upgrade is required.
     364            failed = True
     365
     366        key = request.getHeader("Sec-WebSocket-Key")
     367        if key is None:
     368            # 4.2.1.5 The challenge key is required.
     369            failed = True
     370
     371        version = request.getHeader("Sec-WebSocket-Version")
     372        if version is None or version != "13":
     373            # 4.2.1.6 Only version 13 works.
     374            failed = True
     375            # 4.4 Forward-compatible version checking.
     376            request.setHeader("Sec-WebSocket-Version", "13")
     377
     378        # Stash host and origin for those browsers that care about it.
     379        host = request.getHeader("Host")
     380        origin = request.getHeader("Origin")
     381
     382        # Check whether a codec is needed. WS calls this a "protocol" for
     383        # reasons I cannot fathom.
     384        protocol = request.getHeader("Sec-WebSocket-Protocol")
     385
     386        if protocol:
     387            if protocol not in encoders or protocol not in decoders:
     388                log.msg("Protocol %s is not implemented" % protocol)
     389                failed = True
     390
     391        if failed:
     392            request.setResponseCode(400)
     393            return ""
     394
     395        # We are going to finish this handshake. We will return a valid status
     396        # code.
     397        # 4.2.2.5.1 101 Switching Protocols
     398        request.setResponseCode(101)
     399        # 4.2.2.5.2 Upgrade: websocket
     400        request.setHeader("Upgrade", "WebSocket")
     401        # 4.2.2.5.3 Connection: Upgrade
     402        request.setHeader("Connection", "Upgrade")
     403        # 4.2.2.5.4 Response to the key challenge
     404        request.setHeader("Sec-WebSocket-Accept", make_accept(key))
     405
     406        # Provoke request into flushing headers and finishing the handshake.
     407        request.write("")
     408
     409        # And now take matters into our own hands. We shall manage the
     410        # transport's lifecycle.
     411        transport, request.transport = request.transport, None
     412
     413        # Connect the transport to our factory, and make things go. We need to
     414        # do some stupid stuff here; see #3204, which could fix it.
     415        protocol = self._factory.buildProtocol(transport.getPeer())
     416        transport.protocol = protocol
     417        protocol.makeConnection(transport)
     418
     419        return NOT_DONE_YET
  • twisted/web/test/test_websockets.py

     
     1from twisted.trial import unittest
     2
     3from twisted.web.websockets import (make_accept, mask, CLOSE, NORMAL, PING,
     4    PONG, parse_hybi07_frames)
     5
     6class TestKeys(unittest.TestCase):
     7
     8    def test_make_accept_rfc(self):
     9        """
     10        Test ``make_accept()`` using the keys listed in the RFC for HyBi-07
     11        through HyBi-10.
     12        """
     13
     14        key = "dGhlIHNhbXBsZSBub25jZQ=="
     15
     16        self.assertEqual(make_accept(key), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=")
     17
     18    def test_make_accept_wikipedia(self):
     19        """
     20        Test ``make_accept()`` using the keys listed on Wikipedia.
     21        """
     22
     23        key = "x3JJHMbDL1EzLkh9GBhXDw=="
     24
     25        self.assertEqual(make_accept(key), "HSmrc0sMlYUkAGmm5OPpG2HaGWk=")
     26
     27class TestHyBi07Helpers(unittest.TestCase):
     28    """
     29    HyBi-07 is best understood as a large family of helper functions which
     30    work together, somewhat dysfunctionally, to produce a mediocre
     31    Thanksgiving every other year.
     32    """
     33
     34    def test_mask_noop(self):
     35        key = "\x00\x00\x00\x00"
     36        self.assertEqual(mask("Test", key), "Test")
     37
     38    def test_mask_noop_long(self):
     39        key = "\x00\x00\x00\x00"
     40        self.assertEqual(mask("LongTest", key), "LongTest")
     41
     42    def test_parse_hybi07_unmasked_text(self):
     43        """
     44        From HyBi-10, 4.7.
     45        """
     46
     47        frame = "\x81\x05Hello"
     48        frames, buf = parse_hybi07_frames(frame)
     49        self.assertEqual(len(frames), 1)
     50        self.assertEqual(frames[0], (NORMAL, "Hello"))
     51        self.assertEqual(buf, "")
     52
     53    def test_parse_hybi07_masked_text(self):
     54        """
     55        From HyBi-10, 4.7.
     56        """
     57
     58        frame = "\x81\x857\xfa!=\x7f\x9fMQX"
     59        frames, buf = parse_hybi07_frames(frame)
     60        self.assertEqual(len(frames), 1)
     61        self.assertEqual(frames[0], (NORMAL, "Hello"))
     62        self.assertEqual(buf, "")
     63
     64    def test_parse_hybi07_unmasked_text_fragments(self):
     65        """
     66        We don't care about fragments. We are totally unfazed.
     67
     68        From HyBi-10, 4.7.
     69        """
     70
     71        frame = "\x01\x03Hel\x80\x02lo"
     72        frames, buf = parse_hybi07_frames(frame)
     73        self.assertEqual(len(frames), 2)
     74        self.assertEqual(frames[0], (NORMAL, "Hel"))
     75        self.assertEqual(frames[1], (NORMAL, "lo"))
     76        self.assertEqual(buf, "")
     77
     78    def test_parse_hybi07_ping(self):
     79        """
     80        From HyBi-10, 4.7.
     81        """
     82
     83        frame = "\x89\x05Hello"
     84        frames, buf = parse_hybi07_frames(frame)
     85        self.assertEqual(len(frames), 1)
     86        self.assertEqual(frames[0], (PING, "Hello"))
     87        self.assertEqual(buf, "")
     88
     89    def test_parse_hybi07_pong(self):
     90        """
     91        From HyBi-10, 4.7.
     92        """
     93
     94        frame = "\x8a\x05Hello"
     95        frames, buf = parse_hybi07_frames(frame)
     96        self.assertEqual(len(frames), 1)
     97        self.assertEqual(frames[0], (PONG, "Hello"))
     98        self.assertEqual(buf, "")
     99
     100    def test_parse_hybi07_close_empty(self):
     101        """
     102        A HyBi-07 close packet may have no body. In that case, it should use
     103        the generic error code 1000, and have no reason.
     104        """
     105
     106        frame = "\x88\x00"
     107        frames, buf = parse_hybi07_frames(frame)
     108        self.assertEqual(len(frames), 1)
     109        self.assertEqual(frames[0], (CLOSE, (1000, "No reason given")))
     110        self.assertEqual(buf, "")
     111
     112    def test_parse_hybi07_close_reason(self):
     113        """
     114        A HyBi-07 close packet must have its first two bytes be a numeric
     115        error code, and may optionally include trailing text explaining why
     116        the connection was closed.
     117        """
     118
     119        frame = "\x88\x0b\x03\xe8No reason"
     120        frames, buf = parse_hybi07_frames(frame)
     121        self.assertEqual(len(frames), 1)
     122        self.assertEqual(frames[0], (CLOSE, (1000, "No reason")))
     123        self.assertEqual(buf, "")
     124
     125    def test_parse_hybi07_partial_no_length(self):
     126        frame = "\x81"
     127        frames, buf = parse_hybi07_frames(frame)
     128        self.assertFalse(frames)
     129        self.assertEqual(buf, "\x81")
     130
     131    def test_parse_hybi07_partial_truncated_length_int(self):
     132        frame = "\x81\xfe"
     133        frames, buf = parse_hybi07_frames(frame)
     134        self.assertFalse(frames)
     135        self.assertEqual(buf, "\x81\xfe")
     136
     137    def test_parse_hybi07_partial_truncated_length_double(self):
     138        frame = "\x81\xff"
     139        frames, buf = parse_hybi07_frames(frame)
     140        self.assertFalse(frames)
     141        self.assertEqual(buf, "\x81\xff")
     142
     143    def test_parse_hybi07_partial_no_data(self):
     144        frame = "\x81\x05"
     145        frames, buf = parse_hybi07_frames(frame)
     146        self.assertFalse(frames)
     147        self.assertEqual(buf, "\x81\x05")
     148
     149    def test_parse_hybi07_partial_truncated_data(self):
     150        frame = "\x81\x05Hel"
     151        frames, buf = parse_hybi07_frames(frame)
     152        self.assertFalse(frames)
     153        self.assertEqual(buf, "\x81\x05Hel")